使用 gr.Blocks() 构建 Gradio 应用程序时,您可能希望在用户之间共享某些值(例如,页面的访问者计数),或者在特定交互中为单个用户保留值(例如,聊天记录)。这被称为 状态,在 Gradio 应用程序中有三种通用的方法来管理状态
Gradio 应用程序中的全局状态非常简单:在函数外部创建的任何变量都会在所有用户之间全局共享。
这使得管理全局状态非常简单,无需外部服务。例如,在这个应用程序中,visitor_count 变量在所有用户之间共享
import gradio as gr
# Shared between all users
visitor_count = 0
def increment_counter():
global visitor_count
visitor_count += 1
return visitor_count
with gr.Blocks() as demo:
number = gr.Textbox(label="Total Visitors", value="Counting...")
demo.load(increment_counter, inputs=None, outputs=number)
demo.launch()这意味着任何时候您不希望在用户之间共享值时,您都应该在函数内部声明它。但是,如果您需要在函数调用之间共享值,例如聊天记录,该怎么办?在这种情况下,您应该使用后续的方法之一来管理状态。
Gradio 支持会话状态,其中数据在页面会话中的多次提交之间保留。重申一下,会话数据不在模型的不同用户之间共享,并且如果用户刷新页面重新加载 Gradio 应用程序,数据也不会保留。要将会话数据存储在会话状态中,您需要做三件事
gr.State() 对象。如果此有状态对象有默认值,请将其传递给构造函数。请注意,gr.State 对象必须是 可深度复制(deepcopy-able)的,否则您需要使用如下所述的不同方法。State 对象作为输入和输出。让我们看一个简单的例子。下面我们有一个简单的结账应用程序,您可以在其中将商品添加到购物车。您还可以查看购物车的大小。
import gradio as gr
with gr.Blocks() as demo:
cart = gr.State([])
items_to_add = gr.CheckboxGroup(["Cereal", "Milk", "Orange Juice", "Water"])
def add_items(new_items, previous_cart):
cart = previous_cart + new_items
return cart
gr.Button("Add Items").click(add_items, [items_to_add, cart], cart)
cart_size = gr.Number(label="Cart Size")
cart.change(lambda cart: len(cart), cart, cart_size)
demo.launch()请注意我们如何使用状态来实现这一点
gr.State() 对象中,这里初始化为一个空列表。.change 监听器附加到购物车,该监听器也将状态变量用作输入。您可以将 gr.State 视为一个不可见的 Gradio 组件,它可以存储任何类型的值。在这里,cart 在前端不可见,但用于计算。
状态变量的 .change 监听器会在任何事件监听器更改状态变量值后触发。如果状态变量包含序列(如 list、set 或 dict),则如果其中任何元素发生更改,则会触发更改。如果它包含对象或原始类型,则如果值的 哈希值 发生更改,则会触发更改。因此,如果您定义一个自定义类并创建一个作为该类实例的 gr.State 变量,请确保该类包含一个合理的 __hash__ 实现。
当用户刷新页面时,会话状态变量的值将被清除。在用户关闭选项卡后,该值会在应用程序后端存储 60 分钟(这可以通过 gr.Blocks 中的 delete_cache 参数进行配置)。
在文档中了解有关 State 的更多信息。
对于不可深度复制的对象怎么办?
如前所述,存储在 gr.State 中的值必须是 可深度复制(deepcopy-able)的。如果您正在处理一个无法深度复制的复杂对象,您可以采取不同的方法来手动读取用户的 session_hash 并存储一个全局 dictionary,其中包含每个用户的对象实例。您可以这样做
import gradio as gr
class NonDeepCopyable:
def __init__(self):
from threading import Lock
self.counter = 0
self.lock = Lock() # Lock objects cannot be deepcopied
def increment(self):
with self.lock:
self.counter += 1
return self.counter
# Global dictionary to store user-specific instances
instances = {}
def initialize_instance(request: gr.Request):
instances[request.session_hash] = NonDeepCopyable()
return "Session initialized!"
def cleanup_instance(request: gr.Request):
if request.session_hash in instances:
del instances[request.session_hash]
def increment_counter(request: gr.Request):
if request.session_hash in instances:
instance = instances[request.session_hash]
return instance.increment()
return "Error: Session not initialized"
with gr.Blocks() as demo:
output = gr.Textbox(label="Status")
counter = gr.Number(label="Counter Value")
increment_btn = gr.Button("Increment Counter")
increment_btn.click(increment_counter, inputs=None, outputs=counter)
# Initialize instance when page loads
demo.load(initialize_instance, inputs=None, outputs=output)
# Clean up instance when page is closed/refreshed
demo.unload(cleanup_instance)
demo.launch()Gradio 还支持浏览器状态,其中数据保留在浏览器的本地存储(localStorage)中,即使在页面刷新或关闭后也是如此。这对于存储用户偏好、设置、API 密钥或其他应跨会话保留的数据非常有用。要使用本地状态
gr.BrowserState 对象。您可以选择提供初始默认值和用于在浏览器本地存储中标识数据的密钥。gr.State 组件一样在事件监听器中将其用作输入和输出。这是一个简单的示例,用于跨会话保存用户的用户名和密码
import random
import string
import gradio as gr
import time
with gr.Blocks() as demo:
gr.Markdown("Your Username and Password will get saved in the browser's local storage. "
"If you refresh the page, the values will be retained.")
username = gr.Textbox(label="Username")
password = gr.Textbox(label="Password", type="password")
btn = gr.Button("Generate Randomly")
local_storage = gr.BrowserState(["", ""])
saved_message = gr.Markdown("✅ Saved to local storage", visible=False)
@btn.click(outputs=[username, password])
def generate_randomly():
u = "".join(random.choices(string.ascii_letters + string.digits, k=10))
p = "".join(random.choices(string.ascii_letters + string.digits, k=10))
return u, p
@demo.load(inputs=[local_storage], outputs=[username, password])
def load_from_local_storage(saved_values):
print("loading from local storage", saved_values)
return saved_values[0], saved_values[1]
@gr.on([username.change, password.change], inputs=[username, password], outputs=[local_storage])
def save_to_local_storage(username, password):
return [username, password]
@gr.on(local_storage.change, outputs=[saved_message])
def show_saved_message():
timestamp = time.strftime("%I:%M:%S %p")
return gr.Markdown(
f"✅ Saved to local storage at {timestamp}",
visible=True
)
demo.launch()
注意:如果 Gradio 应用程序重新启动,存储在 gr.BrowserState 中的值不会保留。要保留它,您可以硬编码 storage_key 和 secret 的特定值到 gr.BrowserState 组件中,并在相同的服务器名称和服务器端口上重新启动 Gradio 应用程序。但是,只有当您运行受信任的 Gradio 应用程序时才应执行此操作,因为原则上,这可能允许一个 Gradio 应用程序访问由不同 Gradio 应用程序创建的本地存储数据。