1. 其他教程
  2. 用 GAN 创建你自己的朋友

使用 GAN 创建你自己的朋友

简介

加密货币、NFT 和 web3 运动似乎是当下的热门话题!数字资产正在以惊人的价格在市场上出售,几乎每位名人都推出了自己的 NFT 系列。虽然您的加密资产可能需要缴税,例如在加拿大,但今天我们将探索一些有趣且免税的方法来生成您自己的一系列程序化生成的 CryptoPunks

生成对抗网络(通常简称为 GANs)是一种深度学习模型的特定类别,旨在从输入数据集中学习,以创建(生成!)与原始训练集元素非常相似的新材料。著名的网站 thispersondoesnotexist.com 通过使用名为 StyleGAN2 的模型生成逼真但合成的人物图像而走红。GANs 在机器学习领域获得了广泛关注,现在正被用于生成各种图像、文本,甚至是音乐

今天我们将简要了解 GANs 背后的高级直觉,然后我们将围绕一个预训练的 GAN 构建一个小演示,看看它到底有什么了不起之处。这里是我们将要组装的预览

先决条件

确保您已经安装gradio Python 包。要使用预训练模型,还需要安装 torchtorchvision

GANs:非常简短的介绍

GANs 最初由 Goodfellow 等人于 2014 年提出,由相互竞争的神经网络组成,目的是相互超越。其中一个网络称为*生成器*,负责生成图像。另一个网络称为*判别器*,它会接收来自生成器的一张图像以及来自训练数据集的*真实*图像。然后,判别器必须猜测:哪张图像是假的?

生成器不断训练以创建对判别器来说更难识别的图像,而判别器每次正确检测到假图像时都会提高生成器的标准。随着网络进入这种竞争(*对抗!*)关系,生成的图像质量不断提高,直到肉眼无法分辨真假!

有关 GANs 的更深入了解,您可以查看 Analytics Vidhya 上的这篇优秀文章这份 PyTorch 教程。不过,现在我们将深入了解一个演示!

步骤 1 — 创建生成器模型

要使用 GAN 生成新图像,您只需要生成器模型。生成器可以使用许多不同的架构,但对于这个演示,我们将使用一个具有以下架构的预训练 GAN 生成器模型

from torch import nn

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

我们正在使用 @teddykoker 的这个 repo 中的生成器,您也可以在那里看到原始的判别器模型结构。

实例化模型后,我们将从 Hugging Face Hub 加载权重,存储在 nateraw/cryptopunks-gan

from huggingface_hub import hf_hub_download
import torch

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

步骤 2 — 定义一个 predict 函数

predict 函数是使 Gradio 正常工作的关键!我们通过 Gradio 界面选择的任何输入都将通过我们的 predict 函数传递,该函数应该对输入进行操作并生成我们可以使用 Gradio 输出组件显示的输出。对于 GANs,通常将随机噪声作为输入传递给模型,因此我们将生成一个随机数张量并通过模型传递它。然后我们可以使用 torchvisionsave_image 函数将模型的输出保存为 png 文件,并返回文件名。

from torchvision.utils import save_image

def predict(seed):
    num_punks = 4
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

我们为 predict 函数提供了一个 seed 参数,这样我们就可以用一个种子来固定随机张量生成。然后,如果需要再次查看这些 punks,我们可以通过传入相同的种子来重现它们。

注意!我们的模型需要一个 100x1x1 维度的输入张量来进行单次推理,或者 (BatchSize)x100x1x1 来生成一批图像。在这个演示中,我们将从一次生成 4 个 punks 开始。

步骤 3 — 创建 Gradio 界面

此时,您甚至可以使用 predict(<某个数字>) 运行代码,您会发现新生成的 punks 位于文件系统中的 ./punks.png。不过,为了制作一个真正交互式的演示,我们将使用 Gradio 构建一个简单的界面。我们的目标是:

  • 设置一个滑块输入,让用户选择“种子”值
  • 使用图像组件作为输出,展示生成的 punks
  • 使用我们的 predict() 函数接收种子并生成图像

使用 gr.Interface(),我们可以通过一次函数调用来定义所有这些内容

import gradio as gr

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
    ],
    outputs="image",
).launch()

步骤 4 — 更多 Punks!

一次生成 4 个 punks 是一个不错的开始,但也许我们想控制每次生成的数量。向 Gradio 界面添加更多输入就像向我们传递给 gr.Interfaceinputs 列表添加另一个项一样简单

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10), # Adding another slider!
    ],
    outputs="image",
).launch()

新输入将传递给我们的 predict() 函数,因此我们必须对该函数进行一些更改以接受新参数

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

重新启动界面后,您应该会看到第二个滑块,可以让您控制 punks 的数量!

步骤 5 - 完善它

您的 Gradio 应用基本可以使用了,但您可以添加一些额外的东西来使其真正引人注目 ✨

我们可以通过向 gr.Interface 添加以下内容来增加用户可以轻松尝试的示例

gr.Interface(
    # ...
    # keep everything as it is, and then add
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True) # cache_examples is optional

examples 参数接受一个列表的列表,其中子列表中的每个项的顺序与我们列出的 inputs 顺序相同。所以在我们的例子中是 [seed, num_punks]。试试看吧!

您还可以尝试向 gr.Interface 添加 titledescriptionarticle。这些参数都接受字符串,所以试试看会发生什么吧 👀 article 也接受 HTML,正如在之前的指南中探讨过的那样

完成后,您可能会得到类似这样的东西。

作为参考,这是我们的完整代码

import torch
from torch import nn
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
import gradio as gr

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10),
    ],
    outputs="image",
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True)

恭喜!您已经构建了属于自己的 GAN 驱动的 CryptoPunks 生成器,拥有一个精美的 Gradio 界面,让任何人都可以轻松使用。现在您可以在 Hub 上寻找更多 GAN(或训练自己的 GAN)并继续制作更多出色的演示 🤗

gradio