Gradio 月活用户突破 100 万!

阅读更多
Gradio logo
  1. 使用 Blocks 构建
  2. Blocks 中的状态

管理状态

当使用 gr.Blocks() 构建 Gradio 应用程序时,您可能希望在用户之间共享某些值(例如,页面访问者计数),或者为单个用户在某些交互中持久保存值(例如,聊天记录)。这被称为**状态**,并且在 Gradio 应用程序中有三种管理状态的通用方法

  • **全局状态**:在 Gradio 应用程序运行时,在 Gradio 应用程序的所有用户之间持久保存和共享值
  • **会话状态**:在单个会话中,为 Gradio 应用程序的每个用户持久保存值。如果他们刷新页面,会话状态将被重置。
  • **浏览器状态**:在浏览器的 localStorage 中为 Gradio 应用程序的每个用户持久保存值,即使在页面刷新或关闭后数据仍然存在。

全局状态

Gradio 应用程序中的全局状态非常简单:在函数外部创建的任何变量都在所有用户之间全局共享。

这使得全局状态的管理非常简单,无需外部服务。例如,在此应用程序中,visitor_count 变量在所有用户之间共享

import gradio as gr

# Shared between all users
visitor_count = 0

def increment_counter():
    global visitor_count
    visitor_count += 1
    return visitor_count

with gr.Blocks() as demo:    
    number = gr.Textbox(label="Total Visitors", value="Counting...")
    demo.load(increment_counter, inputs=None, outputs=number)

demo.launch()

这意味着,任何时候您想在用户之间共享值时,都应该在函数内部声明它。但是,如果您需要在函数调用之间共享值,例如聊天记录怎么办?在这种情况下,您应该使用后续方法之一来管理状态。

会话状态

Gradio 支持会话状态,其中数据在页面会话中的多次提交之间保持持久性。重申一下,会话数据在模型的不同用户之间共享,并且如果用户刷新页面以重新加载 Gradio 应用程序,则会持久保存。要将会话状态中的数据存储,您需要执行三件事

  1. 创建一个 gr.State() 对象。如果此有状态对象具有默认值,请将其传递给构造函数。请注意,gr.State 对象必须是 可深度复制 的,否则您将需要使用下面描述的不同方法。
  2. 在事件监听器中,根据需要将 State 对象作为输入和输出放置。
  3. 在事件监听器函数中,将变量添加到输入参数和返回值中。

让我们看一个简单的例子。我们在下面有一个简单的结账应用程序,您可以在其中将商品添加到购物车。您还可以看到购物车的大小。

import gradio as gr

with gr.Blocks() as demo:
    cart = gr.State([])
    items_to_add = gr.CheckboxGroup(["Cereal", "Milk", "Orange Juice", "Water"])

    def add_items(new_items, previous_cart):
        cart = previous_cart + new_items
        return cart

    gr.Button("Add Items").click(add_items, [items_to_add, cart], cart)

    cart_size = gr.Number(label="Cart Size")
    cart.change(lambda cart: len(cart), cart, cart_size)

demo.launch()

请注意我们如何使用状态来实现这一点

  1. 我们将购物车商品存储在 gr.State() 对象中,此处初始化为空列表。
  2. 当向购物车添加商品时,事件监听器将购物车用作输入和输出 - 它返回更新后的购物车,其中包含所有商品。
  3. 我们可以将 .change 监听器附加到购物车,该监听器也将状态变量用作输入。

您可以将 gr.State 视为一个不可见的 Gradio 组件,它可以存储任何类型的值。在这里,cart 在前端不可见,但用于计算。

在任何事件监听器更改状态变量的值后,状态变量的 .change 监听器都会触发。如果状态变量保存一个序列(如 listsetdict),则当内部的任何元素更改时,将触发更改。如果它保存一个对象或原始类型,则当值的 **哈希值** 更改时,将触发更改。因此,如果您定义了一个自定义类并创建了一个 gr.State 变量,该变量是该类的实例,请确保该类包含一个合理的 __hash__ 实现。

当用户刷新页面时,会话状态变量的值将被清除。该值在用户关闭选项卡后在应用程序后端存储 60 分钟(可以通过 gr.Blocks 中的 delete_cache 参数配置此时间)。

文档 中了解有关 State 的更多信息。

那么无法深度复制的对象呢?

如前所述,存储在 gr.State 中的值必须是 可深度复制 的。如果您正在处理无法深度复制的复杂对象,您可以采用不同的方法来手动读取用户的 session_hash 并存储一个全局 dictionary,其中包含每个用户的对象实例。以下是如何执行此操作的方法

import gradio as gr

class NonDeepCopyable:
    def __init__(self):
        from threading import Lock
        self.counter = 0
        self.lock = Lock()  # Lock objects cannot be deepcopied
    
    def increment(self):
        with self.lock:
            self.counter += 1
            return self.counter

# Global dictionary to store user-specific instances
instances = {}

def initialize_instance(request: gr.Request):
    instances[request.session_hash] = NonDeepCopyable()
    return "Session initialized!"

def cleanup_instance(request: gr.Request):
    if request.session_hash in instances:
        del instances[request.session_hash]

def increment_counter(request: gr.Request):
    if request.session_hash in instances:
        instance = instances[request.session_hash]
        return instance.increment()
    return "Error: Session not initialized"

with gr.Blocks() as demo:
    output = gr.Textbox(label="Status")
    counter = gr.Number(label="Counter Value")
    increment_btn = gr.Button("Increment Counter")
    increment_btn.click(increment_counter, inputs=None, outputs=counter)
    
    # Initialize instance when page loads
    demo.load(initialize_instance, inputs=None, outputs=output)    
    # Clean up instance when page is closed/refreshed
    demo.close(cleanup_instance)    

demo.launch()

浏览器状态

Gradio 还支持浏览器状态,其中数据持久保存在浏览器的 localStorage 中,即使页面刷新或关闭后也是如此。这对于存储用户偏好、设置、API 密钥或其他应跨会话持久存在的数据非常有用。要使用本地状态

  1. 创建一个 gr.BrowserState 对象。您可以选择提供初始默认值和一个键来标识浏览器 localStorage 中的数据。
  2. 在事件监听器中像常规 gr.State 组件一样使用它作为输入和输出。

这是一个简单的示例,可以跨会话保存用户的用户名和密码

import random
import string
import gradio as gr
import time
with gr.Blocks() as demo:
    gr.Markdown("Your Username and Password will get saved in the browser's local storage. "
                "If you refresh the page, the values will be retained.")
    username = gr.Textbox(label="Username")
    password = gr.Textbox(label="Password", type="password")
    btn = gr.Button("Generate Randomly")
    local_storage = gr.BrowserState(["", ""])
    saved_message = gr.Markdown("✅ Saved to local storage", visible=False)

    @btn.click(outputs=[username, password])
    def generate_randomly():
        u = "".join(random.choices(string.ascii_letters + string.digits, k=10))
        p = "".join(random.choices(string.ascii_letters + string.digits, k=10))
        return u, p

    @demo.load(inputs=[local_storage], outputs=[username, password])
    def load_from_local_storage(saved_values):
        print("loading from local storage", saved_values)
        return saved_values[0], saved_values[1]

    @gr.on([username.change, password.change], inputs=[username, password], outputs=[local_storage])
    def save_to_local_storage(username, password):
        return [username, password]

    @gr.on(local_storage.change, outputs=[saved_message])
    def show_saved_message():
        timestamp = time.strftime("%I:%M:%S %p")
        return gr.Markdown(
            f"✅ Saved to local storage at {timestamp}",
            visible=True
        )

demo.launch()

注意:如果 Grado 应用程序重新启动,存储在 gr.BrowserState 中的值不会持久保存。要使其持久保存,您可以在 gr.BrowserState 组件中硬编码 storage_keysecret 的特定值,并在同一服务器名称和服务器端口上重新启动 Gradio 应用程序。但是,只有在运行受信任的 Gradio 应用程序时才应这样做,因为原则上,这可能允许一个 Gradio 应用程序访问由另一个 Gradio 应用程序创建的 localStorage 数据。