Gradio 代理 & MCP 黑客松

获奖者
Gradio logo
  1. 其他教程
  2. 使用数据标记

使用数据标记

引言

当您演示机器学习模型时,您可能希望从尝试该模型的用户那里收集数据,特别是那些模型行为不如预期的数据点。捕获这些“难点”数据点很有价值,因为这能让您改进机器学习模型,使其更可靠、更稳健。

Gradio 通过在每个 Interface 中包含一个标记按钮,简化了数据收集。这使得用户或测试人员可以轻松地将数据发送回运行演示的机器。在本指南中,我们将详细讨论如何在 gradio.Interfacegradio.Blocks 中使用数据标记功能。

gradio.Interface 中的标记按钮

使用 Gradio 的 Interface 进行数据标记非常容易。默认情况下,在输出组件下方有一个标有标记的按钮。当测试您模型的用户看到具有有趣输出的输入时,他们可以点击标记按钮,将输入和输出数据发送回运行演示的机器。这些样本(默认情况下)保存到一个 JSON 日志文件中。如果演示涉及图像、音频、视频或其他类型的文件,这些文件将单独保存在一个并行目录中,文件路径则保存在 JSON 文件中。

gradio.Interface 中有四个参数控制数据标记的工作方式。我们将详细介绍它们。

  • flagging_mode:此参数可设置为 "manual"(默认)、"auto""never"
    • manual(手动):用户将看到一个标记按钮,只有在点击按钮时才标记样本。
    • auto(自动):用户不会看到标记按钮,但每个样本都会自动标记。
    • never(从不):用户不会看到标记按钮,并且不会标记任何样本。
  • flagging_options:此参数可为 None(默认)或字符串列表。
    • 如果为 None,用户只需点击标记按钮,不会显示其他选项。
    • 如果提供字符串列表,则用户会看到多个按钮,每个按钮对应提供的一个字符串。例如,如果此参数的值为 ["Incorrect", "Ambiguous"],则会出现标有标记为不正确标记为模糊的按钮。这仅在 flagging_mode"manual" 时适用。
    • 选择的选项随后会与输入和输出一起记录下来。
  • flagging_dir:此参数接受一个字符串。
    • 它表示存储标记数据的目录名称。
  • flagging_callback:此参数接受 FlaggingCallback 类子类的一个实例。
    • 使用此参数,您可以编写在标记按钮被点击时运行的自定义代码。
    • 默认情况下,这被设置为 gr.JSONLogger 的一个实例。

标记数据会发生什么?

在由 flagging_dir 参数提供的目录中,一个 JSON 文件将记录标记数据。

以下是一个示例:下面的代码创建了嵌入在其下方的计算器界面。

import gradio as gr


def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        return num1 / num2


iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual"
)

iface.launch()

当您点击上方的标记按钮时,启动界面的目录中将包含一个新的 `flagged` 子文件夹,其中有一个 CSV 文件。这个 CSV 文件包含了所有被标记的数据。

+-- flagged/
|   +-- logs.csv

flagged/logs.csv

num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542

如果界面涉及文件数据,例如图像和音频组件,也会创建文件夹来存储这些标记数据。例如,一个从 image 输入到 image 输出的界面将创建以下结构。

+-- flagged/
|   +-- logs.csv
|   +-- image/
|   |   +-- 0.png
|   |   +-- 1.png
|   +-- Output/
|   |   +-- 0.png
|   |   +-- 1.png

flagged/logs.csv

im,Output timestamp
im/0.png,Output/0.png,2022-02-04 19:49:58.026963
im/1.png,Output/1.png,2022-02-02 10:40:51.093412

如果您希望用户提供标记原因,可以将一个字符串列表传递给 Interface 的 flagging_options 参数。用户在标记时必须选择其中一个选项,该选项将作为额外的一列保存到 CSV 中。

如果我们回到计算器示例,以下代码将创建嵌入在其下方的界面。

iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual",
    flagging_options=["wrong sign", "off by one", "other"]
)

iface.launch()

当用户点击标记按钮时,CSV 文件现在将包含一列,指示所选的选项。

flagged/logs.csv

num1,operation,num2,Output,flag,timestamp
5,add,7,-12,wrong sign,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,off by one,2022-02-04 11:42:32.062512

使用 Blocks 进行数据标记

如果您正在使用 gradio.Blocks 又如何呢?一方面,Blocks 提供了更大的灵活性——您可以在点击按钮时运行任何您想编写的 Python 代码,并通过 Blocks 中的内置事件进行分配。

同时,您可能希望使用现有的 FlaggingCallback 来避免编写额外代码。这需要两个步骤:

  1. 您必须在首次标记数据之前,在代码中的某个位置运行回调的 .setup() 方法。
  2. 当标记按钮被点击时,您需要触发回调的 .flag() 方法,确保正确收集参数并禁用典型的预处理。

这里是一个图像棕褐色滤镜 Blocks 演示的例子,它允许您使用默认的 CSVLogger 来标记数据。

import numpy as np
import gradio as gr

def sepia(input_img, strength):
    sepia_filter = strength * np.array(
        [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
    ) + (1-strength) * np.identity(3)
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

callback = gr.CSVLogger()

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img_input = gr.Image()
            strength = gr.Slider(0, 1, 0.5)
        img_output = gr.Image()
    with gr.Row():
        btn = gr.Button("Flag")

    # This needs to be called at some point prior to the first call to callback.flag()
    callback.setup([img_input, strength, img_output], "flagged_data_points")

    img_input.change(sepia, [img_input, strength], img_output)
    strength.change(sepia, [img_input, strength], img_output)

    # We can choose which components to flag -- in this case, we'll flag all of them
    btn.click(lambda *args: callback.flag(list(args)), [img_input, strength, img_output], None, preprocess=False)

demo.launch()

隐私

重要提示:请确保您的用户了解他们提交的数据何时被保存,以及您计划如何处理这些数据。当您使用 flagging_mode=auto(即通过演示提交的所有数据都会被标记)时,这一点尤为重要。

就是这样!愉快的构建吧 :)