1. 其他教程
  2. 使用标记

使用标记

引言

当你演示一个机器学习模型时,你可能想从尝试该模型的用户那里收集数据,特别是模型行为不如预期的数据点。捕获这些“困难”的数据点非常有价值,因为它可以让你改进机器学习模型,使其更可靠和健壮。

Gradio 通过在每个 Interface 中包含一个 **Flag** 按钮来简化此数据的收集。这允许用户或测试人员轻松地将数据发送回运行演示的机器。在本指南中,我们将讨论如何使用标记功能,包括 gradio.Interfacegradio.Blocks

gradio.Interface 中的 **Flag** 按钮

使用 Gradio 的 Interface 进行标记非常简单。默认情况下,在输出组件下方有一个标有 **Flag** 的按钮。当测试您的模型用户看到有趣的输出时,他们可以点击标记按钮将输入和输出数据发送回运行演示的机器。样本(默认情况下)保存到一个 CSV 日志文件中。如果演示涉及图像、音频、视频或其他类型的文件,这些文件将单独保存在一个并行目录中,并且这些文件的路径将保存在 CSV 文件中。

gradio.Interface 中有 四个参数 控制标记的工作方式。我们将详细介绍它们。

  • flagging_mode: 此参数可以设置为 "manual"(默认值)、"auto""never"
    • manual: 用户将看到一个标记按钮,并且只有在点击该按钮时才会标记样本。
    • auto: 用户将看不到标记按钮,但每个样本都会自动标记。
    • never: 用户将看不到标记按钮,并且不会标记任何样本。
  • flagging_options: 此参数可以是 None(默认值)或一个字符串列表。
    • 如果为 None,则用户只需点击 **Flag** 按钮,不显示其他选项。
    • 如果提供了一个字符串列表,则用户会看到几个按钮,对应于提供的每个字符串。例如,如果此参数的值为 ["Incorrect", "Ambiguous"],则会出现标有 **Flag as Incorrect** 和 **Flag as Ambiguous** 的按钮。这仅适用于 flagging_mode"manual" 的情况。
    • 然后选择的选项将与输入和输出一起记录。
  • flagging_dir: 此参数接受一个字符串。
    • 它表示存储标记数据的目录名称。
  • flagging_callback: 此参数接受 FlaggingCallback 类的子类的实例
    • 使用此参数,您可以在点击标记按钮时运行自定义代码
    • 默认情况下,这设置为 gr.JSONLogger 的实例

标记的数据会发生什么?

flagging_dir 参数提供的目录中,一个 JSON 文件将记录标记的数据。

例如:下面的代码创建了下面嵌入的计算器界面

import gradio as gr


def calculator(num1, operation, num2):
    if operation == "add":
        return num1 + num2
    elif operation == "subtract":
        return num1 - num2
    elif operation == "multiply":
        return num1 * num2
    elif operation == "divide":
        return num1 / num2


iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual"
)

iface.launch()

当您点击上面的标记按钮时,启动界面的目录将包含一个新的 flagged 子文件夹,里面有一个 csv 文件。这个 csv 文件包含所有被标记的数据。

+-- flagged/
|   +-- logs.csv

flagged/logs.csv

num1,operation,num2,Output,timestamp
5,add,7,12,2022-01-31 11:40:51.093412
6,subtract,1.5,4.5,2022-01-31 03:25:32.023542

如果界面涉及文件数据,例如图像和音频组件,也将创建文件夹来存储这些标记的数据。例如,一个 image 输入到 image 输出的界面将创建以下结构。

+-- flagged/
|   +-- logs.csv
|   +-- image/
|   |   +-- 0.png
|   |   +-- 1.png
|   +-- Output/
|   |   +-- 0.png
|   |   +-- 1.png

flagged/logs.csv

im,Output timestamp
im/0.png,Output/0.png,2022-02-04 19:49:58.026963
im/1.png,Output/1.png,2022-02-02 10:40:51.093412

如果您希望用户提供标记原因,可以将字符串列表传递给 Interface 的 flagging_options 参数。用户在标记时必须选择其中一个选项,该选项将作为附加列保存到 CSV 中。

如果我们回到计算器示例,下面的代码将创建下面嵌入的界面。

iface = gr.Interface(
    calculator,
    ["number", gr.Radio(["add", "subtract", "multiply", "divide"]), "number"],
    "number",
    flagging_mode="manual",
    flagging_options=["wrong sign", "off by one", "other"]
)

iface.launch()

当用户点击标记按钮时,csv 文件现在将包含一列指示所选选项。

flagged/logs.csv

num1,operation,num2,Output,flag,timestamp
5,add,7,-12,wrong sign,2022-02-04 11:40:51.093412
6,subtract,1.5,3.5,off by one,2022-02-04 11:42:32.062512

使用 Blocks 进行标记

如果你使用的是 gradio.Blocks 呢?一方面,Blocks 提供了更大的灵活性——你可以编写任何你想要的 Python 代码,并在点击按钮时使用 Blocks 中的内置事件来运行它。

同时,你可能想使用现有的 FlaggingCallback 来避免编写额外的代码。这需要两个步骤

  1. 你必须在第一次标记数据之前在代码中的某个位置运行回调的 .setup()
  2. 当点击标记按钮时,你触发回调的 .flag() 方法,确保正确收集参数并禁用典型的预处理。

这里有一个图像棕褐色滤镜 Blocks 演示的示例,它允许你使用默认的 CSVLogger 标记数据

import numpy as np
import gradio as gr

def sepia(input_img, strength):
    sepia_filter = strength * np.array(
        [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
    ) + (1-strength) * np.identity(3)
    sepia_img = input_img.dot(sepia_filter.T)
    sepia_img /= sepia_img.max()
    return sepia_img

callback = gr.CSVLogger()

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img_input = gr.Image()
            strength = gr.Slider(0, 1, 0.5)
        img_output = gr.Image()
    with gr.Row():
        btn = gr.Button("Flag")

    # This needs to be called at some point prior to the first call to callback.flag()
    callback.setup([img_input, strength, img_output], "flagged_data_points")

    img_input.change(sepia, [img_input, strength], img_output)
    strength.change(sepia, [img_input, strength], img_output)

    # We can choose which components to flag -- in this case, we'll flag all of them
    btn.click(lambda *args: callback.flag(list(args)), [img_input, strength, img_output], None, preprocess=False)

demo.launch()

隐私

重要提示:请确保您的用户了解他们提交的数据何时被保存以及您打算如何处理它。当您使用 flagging_mode=auto(通过演示提交的所有数据都被标记时)时,这一点尤为重要

就这些了!愉快的构建吧:)