langgraph断点

langgraph断点

起男 23 2025-05-18

langgraph断点

断点是langgraph提供的在流程中供用户暂停修改的机制

基本使用

大模型

from langchain_openai import ChatOpenAI
import os

llm = ChatOpenAI(
    api_key=os.environ.get("ALI_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    model="qwen-max"
)

工具

from langchain_core.tools import tool

@tool
def multiply(a:int,b:int)->int:
    """计算a乘b"""
    return a*b
@tool
def add(a:int,b:int)->int:
    """计算a加b"""
    return a+b
@tool
def divide(a:int,b:int)->float:
    """计算a除b"""
    return a/b
#绑定
tools = [add,multiply,divide]
# parallel_tool_calls关闭工具并行
llm_with_tools = llm.bind_tools(tools,parallel_tool_calls=False)

节点和状态

from typing_extensions import TypedDict
from typing import Annotated
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage

# 提示词
sys_msg = SystemMessage("你是一个乐于助人的助手,负责对一组输入进行算数运算")
#状态
class MessageState(TypedDict):
    # 每次更新进行追加
    messages : Annotated[list[AnyMessage],add_messages]
# 节点
def assistant(state : MessageState):
    return {"messages":[llm_with_tools.invoke([sys_msg]+state["messages"])]}

定义图

from langgraph.graph import StateGraph,START
from langgraph.prebuilt import ToolNode,tools_condition
from langgraph.checkpoint.memory import MemorySaver

builder = StateGraph(MessageState)
# 定义节点
builder.add_node("assistant", assistant)
builder.add_node("tools",ToolNode(tools))
#定义边
builder.add_edge(START,"assistant")
builder.add_conditional_edges("assistant",tools_condition)
# 循环工具到llm的边
builder.add_edge("tools","assistant")
memory = MemorySaver()
#在tools节点之前进行打断
graph = builder.compile(interrupt_before=["tools"],checkpointer=memory)

测试

thread = {"configurable":{"thread_id":"1"}}
# 流式输出
for event in graph.stream({"messages":HumanMessage("3加4是多少")},thread,stream_mode="values"):
    event["messages"][-1].pretty_print()
#查看图中打断的方法
state = graph.get_state(thread)
print(state.next)
#获取用户反馈
user_approval = input("是否调用工具?(是/否):")
if user_approval.lower() == "是":
    #继续执行,通过传None实现
    for event in graph.stream(None,thread,stream_mode="values"):
        event["messages"][-1].pretty_print()
else:
    print("用户操作已取消")

断点时修改状态

首先修改打断的时机,在调用大模型前进行断点

#在assistant节点之前进行打断
graph = builder.compile(interrupt_before=["assistant"],checkpointer=memory)

测试

thread = {"configurable":{"thread_id":"1"}}
# 流式输出
for event in graph.stream({"messages":"3加4是多少"},thread,stream_mode="values"):
    event["messages"][-1].pretty_print()
#修改状态
graph.update_state(
    thread,
    {"messages":[HumanMessage("改成3乘4是多少")]}
)
# 继续执行,通过传None实现
for event in graph.stream(None, thread, stream_mode="values"):
    event["messages"][-1].pretty_print()

用外部方式修改

创建一个空节点专门用来修改状态

空节点

def human_feedback(state : MessagesState):
    pass

定义图

加入这个空节点,并添加断点

builder = StateGraph(MessagesState)
# 定义节点
builder.add_node("assistant", assistant)
builder.add_node("tools",ToolNode(tools))
builder.add_node("human_feedback",human_feedback)
#定义边
builder.add_edge(START,"human_feedback")
builder.add_edge("human_feedback","assistant")
builder.add_conditional_edges("assistant",tools_condition)
builder.add_edge("tools","assistant")
memory = MemorySaver()
#在human_feedback节点之前进行打断
graph = builder.compile(interrupt_before=["human_feedback"],checkpointer=memory)

测试

thread = {"configurable":{"thread_id":"1"}}
# 流式输出
for event in graph.stream({"messages":"3加4是多少"},thread,stream_mode="values"):
    event["messages"][-1].pretty_print()
#用户输出
user_approval = input("输入你的修改")
#修改状态
graph.update_state(
    thread,
    {"messages":user_approval},
    as_node="human_feedback"
)
# 继续执行,通过传None实现
for event in graph.stream(None, thread, stream_mode="values"):
    event["messages"][-1].pretty_print()

动态断点

通过抛出NodeInterrupt异常可用实现动态断点的效果

定义状态和节点

from typing_extensions import TypedDict
from langgraph.errors import NodeInterrupt

#状态
class State(TypedDict):
    input:str
#节点
def step_1(state:State):
    print("step_1")
    return state
def step_2(state:State):
    if len(state["input"]) > 5:
        raise NodeInterrupt(f"收到的长度超过5个字符")
    print("step_2")
    return state
def step_3(state:State):
    print("step_3")
    return state

构建图

builder = StateGraph(State)
builder.add_node("step_1",step_1)
builder.add_node("step_2",step_2)
builder.add_node("step_3",step_3)
builder.add_edge(START,"step_1")
builder.add_edge("step_1","step_2")
builder.add_edge("step_2","step_3")
builder.add_edge("step_3",END)
memory = MemorySaver()
graph = builder.compile(checkpointer=memory)

测试

thread = {"configurable":{"thread_id":"1"}}
for event in graph.stream({"input":"你好123123"},thread,stream_mode="values"):
    print(event)
#查看下一步计划
state = graph.get_state(thread)
print(state.next)
#查看中断状态
print(state.tasks)
#更新状态
graph.update_state(
    thread,
    {"input":"你好"}
)
#重新运行
for event in graph.stream(None,thread,stream_mode="values"):
    print(event)