Gradio 代理 & MCP 黑客松

获奖者
Gradio logo
  1. 其他教程
  2. 使用 GAN 创建你自己的朋友

使用 GAN 创建你自己的朋友

引言

加密货币、NFT 和 web3 运动似乎是当今的热门话题!数字资产在市场上以惊人的价格挂牌出售,几乎所有名人都推出了自己的 NFT 系列。虽然你的加密资产可能需要纳税,例如在加拿大,但今天我们将探索一些有趣且免税的方法来生成你自己的各种程序化生成的加密朋克(CryptoPunks)

生成对抗网络,通常简称为 GANs,是一类特殊的深度学习模型,旨在从输入数据集中学习,以创建(生成!)与原始训练集元素非常相似的新材料。著名的网站thispersondoesnotexist.com因使用名为 StyleGAN2 的模型生成逼真但合成的人物图像而走红。GANs 在机器学习领域获得了关注,现在被用于生成各种图像、文本,甚至音乐

今天我们将简要了解 GANs 背后的高级直觉,然后我们将围绕一个预训练的 GAN 构建一个小演示,看看它到底有什么神奇之处。这是我们即将组建的预览

先决条件

请确保您已安装gradio Python 包。要使用预训练模型,还需要安装 torchtorchvision

GANs:非常简短的介绍

GANs 最初由Goodfellow 等人于 2014 年提出,由相互竞争以智取对方的神经网络组成。其中一个网络,称为生成器,负责生成图像。另一个网络,判别器,每次从生成器接收一张图像,同时接收一张来自训练数据集的真实图像。然后判别器必须猜测:哪张图像是假的?

生成器不断训练以创建更难被判别器识别的图像,而判别器每次正确检测到假图像时,都会提高对生成器的要求。随着网络进入这种竞争(对抗!)关系,生成的图像质量不断提高,直到它们对人眼来说变得难以区分!

要更深入地了解 GANs,您可以查阅Analytics Vidhya 上的这篇优秀文章或此PyTorch 教程。不过,现在我们将深入演示!

步骤 1 — 创建生成器模型

要使用 GAN 生成新图像,您只需要生成器模型。生成器可以使用许多不同的架构,但在此演示中,我们将使用具有以下架构的预训练 GAN 生成器模型:

from torch import nn

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

我们正在从 @teddykoker 的这个仓库中获取生成器,您也可以在那里查看原始的判别器模型结构。

实例化模型后,我们将从 Hugging Face Hub 加载权重,权重存储在nateraw/cryptopunks-gan

from huggingface_hub import hf_hub_download
import torch

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

步骤 2 — 定义 predict 函数

predict 函数是 Gradio 工作的关键!我们通过 Gradio 界面选择的任何输入都将通过我们的 predict 函数,该函数应对输入进行操作并生成我们可以用 Gradio 输出组件显示的输出。对于 GANs,通常将随机噪声作为输入传递给我们的模型,因此我们将生成一个随机数张量并将其通过模型。然后我们可以使用 torchvisionsave_image 函数将模型输出保存为 png 文件,并返回文件名称。

from torchvision.utils import save_image

def predict(seed):
    num_punks = 4
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

我们给 predict 函数一个 seed 参数,这样我们就可以用一个种子固定随机张量生成。如果我们想再次看到相同的朋克,只需传入相同的种子即可重现它们。

注意! 我们的模型需要维度为 100x1x1 的输入张量才能进行单次推理,或者需要 (BatchSize)x100x1x1 的张量才能生成一批图像。在此演示中,我们将从一次生成 4 个朋克开始。

步骤 3 — 创建 Gradio 界面

此时,您甚至可以使用 predict(<SOME_NUMBER>) 运行您的代码,您将在文件系统中 ./punks.png 处找到新生成的朋克。然而,为了制作一个真正的交互式演示,我们将使用 Gradio 构建一个简单的界面。我们的目标是:

  • 设置一个滑块输入,让用户可以选择“种子”值
  • 使用图像组件作为输出,展示生成的朋克
  • 使用我们的 predict() 函数接收种子并生成图像

通过 gr.Interface(),我们可以通过一个函数调用来定义所有这些。

import gradio as gr

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
    ],
    outputs="image",
).launch()

步骤 4 — 更多朋克!

一次生成 4 个朋克是一个好的开始,但也许我们想控制每次生成的数量。向 Gradio 界面添加更多输入就像向我们传递给 gr.Interfaceinputs 列表中添加另一个项一样简单。

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10), # Adding another slider!
    ],
    outputs="image",
).launch()

新输入将传递给我们的 predict() 函数,因此我们必须对该函数进行一些更改以接受新参数。

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

当您重新启动界面时,您应该会看到第二个滑块,让您能够控制朋克的数量!

步骤 5 - 优化和完善

您的 Gradio 应用基本可以使用了,但您可以添加一些额外内容,使其真正准备好在聚光灯下闪耀 ✨

我们可以通过将此添加到 gr.Interface 来添加一些用户可以轻松尝试的示例

gr.Interface(
    # ...
    # keep everything as it is, and then add
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True) # cache_examples is optional

examples 参数接受一个列表的列表,其中子列表中的每个项都按照我们列出 inputs 的相同顺序排列。因此,在我们的例子中,是 [seed, num_punks]。试一试吧!

您还可以尝试向 gr.Interface 添加 titledescriptionarticle。这些参数都接受字符串,所以试试看会发生什么 👀 article 也接受 HTML,如之前的指南中所探讨

完成后,您可能会得到类似这样的东西。

供参考,这是我们的完整代码

import torch
from torch import nn
from huggingface_hub import hf_hub_download
from torchvision.utils import save_image
import gradio as gr

class Generator(nn.Module):
    # Refer to the link below for explanations about nc, nz, and ngf
    # https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html#inputs
    def __init__(self, nc=4, nz=100, ngf=64):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        output = self.network(input)
        return output

model = Generator()
weights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available

def predict(seed, num_punks):
    torch.manual_seed(seed)
    z = torch.randn(num_punks, 100, 1, 1)
    punks = model(z)
    save_image(punks, "punks.png", normalize=True)
    return 'punks.png'

gr.Interface(
    predict,
    inputs=[
        gr.Slider(0, 1000, label='Seed', default=42),
        gr.Slider(4, 64, label='Number of Punks', step=1, default=10),
    ],
    outputs="image",
    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],
).launch(cache_examples=True)

恭喜!您已经构建了自己基于 GAN 的 CryptoPunks 生成器,并拥有一个精美的 Gradio 界面,使任何人都可以轻松使用。现在您可以在 Hub 上搜索更多 GANs(或训练您自己的 GAN),并继续制作更棒的演示 🤗