Gradio 月活用户突破 100 万!

阅读更多
Gradio logo
  1. 其他教程
  2. 使用 GAN 创建你自己的朋友

使用 GAN 创建你自己的朋友

简介

似乎加密货币、NFT 和 web3 运动现在非常流行!数字资产在市场上以惊人的价格被列出,几乎每位名人都推出了自己的 NFT 系列。虽然你的加密资产可能需要缴税,例如在加拿大,但今天我们将探索一些有趣且免税的方式来生成你自己的程序生成的 CryptoPunks 合集。

生成对抗网络,通常简称为 GAN,是一类特定的深度学习模型,旨在从输入数据集中学习,以创建(生成!)与原始训练集元素非常相似的新材料。最著名的是,网站 thispersondoesnotexist.com 因使用名为 StyleGAN2 的模型生成的逼真但合成的人物图像而走红。GAN 在机器学习领域获得了关注,现在正被用于生成各种图像、文本,甚至 音乐

今天,我们将简要了解 GAN 背后的高层次直觉,然后我们将围绕预训练的 GAN 构建一个小演示,看看这一切到底是怎么回事。这是一个 预览,看看我们将要组合在一起的东西。

先决条件

确保你已经安装gradio Python 包。要使用预训练模型,还需要安装 torchtorchvision

GANs:非常简短的介绍

最初在 Goodfellow 等人 2014 中提出,GAN 由神经网络组成,这些神经网络相互竞争,意图胜过对方。一个网络,称为生成器,负责生成图像。另一个网络,判别器,每次从生成器接收一个图像以及来自训练数据集的真实图像。然后,判别器必须猜测:哪个图像是假的?

生成器不断训练以创建对于判别器来说更难识别的图像,而判别器每次正确检测到假图像时都会提高生成器的标准。随着网络参与这种竞争性(对抗性!)关系,生成的图像得到改进,以至于它们变得与人眼无法区分!

要更深入地了解 GAN,你可以查看 Analytics Vidhya 上的这篇优秀文章 或这篇 PyTorch 教程。不过,现在我们将深入研究一个演示!

步骤 1 — 创建生成器模型

要使用 GAN 生成新图像,你只需要生成器模型。生成器可以使用许多不同的架构,但对于本演示,我们将使用具有以下架构的预训练 GAN 生成器模型

from torch import nn

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

我们从 @teddykoker 的这个 repo 中获取生成器,你也可以在那里看到原始的判别器模型结构。

在实例化模型后,我们将从 Hugging Face Hub 加载权重,存储在 nateraw/cryptopunks-gan

from huggingface_hub import hf_hub_download
import torch

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

步骤 2 — 定义 predict 函数

predict 函数是使 Gradio 工作的关键!无论我们通过 Gradio 界面选择什么输入,都将通过我们的 predict 函数传递,该函数应该对输入进行操作并生成我们可以使用 Gradio 输出组件显示的输出。对于 GAN,通常将随机噪声作为输入传递到我们的模型中,因此我们将生成一个随机数张量,并通过模型传递它。然后我们可以使用 torchvisionsave_image 函数将模型的输出保存为 png 文件,并返回文件名

from torchvision.utils import save_image

def predict(seed):
    num_punks = 4
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

我们正在为我们的 predict 函数提供一个 seed 参数,以便我们可以使用种子固定随机张量生成。然后,如果我们想再次看到它们,就可以通过传入相同的种子来重现 punks。

注意! 我们的模型需要一个维度为 100x1x1 的输入张量才能进行单次推理,或者 (BatchSize)x100x1x1 用于生成一批图像。在本演示中,我们将首先一次生成 4 个 punks。

步骤 3 — 创建 Gradio 界面

此时,您甚至可以使用 predict(<SOME_NUMBER>) 运行您拥有的代码,您将在文件系统的 ./punks.png 中找到新生成的 punks。但是,为了制作一个真正交互式的演示,我们将使用 Gradio 构建一个简单的界面。我们的目标是:

  • 设置一个滑块输入,以便用户可以选择 “seed” 值
  • 使用图像组件作为我们的输出,以展示生成的 punks
  • 使用我们的 predict() 函数来获取种子并生成图像

使用 gr.Interface(),我们可以用一个函数调用定义所有这些

import gradio as gr

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
    ],
    outputs="image",
).launch()

步骤 4 — 更多 punks!

一次生成 4 个 punks 是一个好的开始,但也许我们想控制每次想要制作多少个。向 Gradio 界面添加更多输入就像向我们传递给 gr.Interfaceinputs 列表添加另一个项目一样简单

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10), # Adding another slider!
    ],
    outputs="image",
).launch()

新的输入将传递到我们的 predict() 函数,因此我们必须对该函数进行一些更改以接受新的参数

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

当您重新启动您的界面时,您应该会看到第二个滑块,它将让您控制 punks 的数量!

步骤 5 - 润色

您的 Gradio 应用程序几乎可以正常运行了,但是您可以添加一些额外的功能,使其真正准备好成为焦点 ✨

我们可以添加一些示例,用户可以通过将其添加到 gr.Interface 中轻松试用:

gr.Interface(
    # ...
    # keep everything as it is, and then add
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True) # cache_examples is optional

examples 参数接受列表的列表,其中子列表中的每个项目都按照我们列出 inputs 的相同顺序排序。所以在我们的例子中,[seed, num_punks]。试一试!

您还可以尝试向 gr.Interface 添加 titledescriptionarticle。 这些参数中的每一个都接受一个字符串,所以试一试,看看会发生什么 👀 article 也将接受 HTML,正如之前的指南中探讨的那样

当您完成所有操作后,您最终可能会得到类似这样的东西。

作为参考,这是我们的完整代码

import torch
from torch import nn
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
import gradio as gr

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10),
    ],
    outputs="image",
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True)

恭喜!您已经构建了您自己的 GAN 驱动的 CryptoPunks 生成器,以及一个精美的 Gradio 界面,使任何人都可以轻松使用。现在您可以在 Hub 上搜索更多 GAN(或训练您自己的),并继续制作更精彩的演示 🤗