%% Cell type:markdown id: tags:
# Use this to design an AI agent as a POC
<h2>Initialization</h2>
%% Cell type:code id: tags:
``` python
# Load environment variables
from dotenv import load_dotenv

load_dotenv()  # load the .env file
```
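%% Cell type:markdown id: tags:
For reference, `load_dotenv()` reads key-value pairs from a `.env` file in the working directory. A minimal sketch with a placeholder value (the key name comes from the cells below; never commit real keys):

    DASHSCOPE_API_KEY=sk-xxxx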
%% Cell type:markdown id: tags:
<hr>
<h2>Using Tongyi Qianwen as the LLM</h2>
%% Cell type:code id: tags:
``` python
# Use Tongyi Qianwen (Qwen)
from langchain_community.llms import Tongyi
import os

# Initialize the Tongyi model (qwen-turbo here)
llm_tongyi = Tongyi(
    model_name="qwen-turbo",
    dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
)
```
%% Cell type:markdown id: tags:
<h2>Defining the Agent Workflow</h2>
%% Cell type:code id: tags:
``` python
from typing import Annotated

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

class State(TypedDict):
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list rather than overwriting them).
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)
```
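%% Cell type:markdown id: tags:
To make the reducer concrete: `add_messages` can be called directly, and it merges two message lists by appending rather than overwriting. A minimal sketch:
%% Cell type:code id: tags:
``` python
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import add_messages

merged = add_messages(
    [HumanMessage(content="hi")],        # existing state
    [AIMessage(content="hello there")],  # node update
)
print(len(merged))  # 2 -- appended, not replaced
```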
%% Cell type:markdown id: tags:
<h2>Adding the Agent Node</h2>
%% Cell type:code id: tags:
``` python
def chatbot(state: State):
    return {"messages": [llm_tongyi.invoke(state["messages"])]}

# The first argument is the unique node name.
# The second argument is the function or object that will be called
# whenever the node is used.
graph_builder.add_node("chatbot", chatbot)
```
%% Cell type:markdown id: tags:
<h2>Adding the Workflow Entry and Exit Points</h2>
%% Cell type:code id: tags:
``` python
# Define the entry point
graph_builder.add_edge(START, "chatbot")
# Define the exit point
graph_builder.add_edge("chatbot", END)
# Compile the graph
graph = graph_builder.compile()
```
%% Cell type:markdown id: tags:
<h2>Visualizing the Workflow</h2>
%% Cell type:code id: tags:
``` python
from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    pass
```
%% Cell type:markdown id: tags:
<h2>Running the Workflow</h2>
%% Cell type:code id: tags:
``` python
# Streaming interaction helper (adapted for the Tongyi model)
def stream_graph_updates(user_input: str):
    state = {"messages": [{"role": "user", "content": user_input}]}
    for event in graph.stream(state):
        for value in event.values():
            last_message = value["messages"][-1]
            print(f"Assistant: {last_message}")

while True:
    try:
        user_input = input("User: ")
        print("User: ", user_input)
        if user_input.lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break
        stream_graph_updates(user_input)
    except Exception:
        # Fallback if input() is not available
        user_input = "What do you know about LangGraph?"
        print("User: " + user_input)
        stream_graph_updates(user_input)
        break
```
%% Cell type:markdown id: tags:
# Use this to design an AI agent as a POC
<h2>Initialization</h2>
%% Cell type:code id: tags:
``` python
# Load environment variables
from dotenv import load_dotenv

load_dotenv()  # load the .env file
```
%% Cell type:markdown id: tags:
<hr>
<h2>Defining the Search Engine Tool</h2>
%% Cell type:code id: tags:
``` python
from langchain_community.tools.tavily_search import TavilySearchResults
tool = TavilySearchResults(max_results=3)
tools = [tool]
# tool.invoke("What's a 'node' in LangGraph?")
```
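%% Cell type:markdown id: tags:
As a rough sketch of the output (the exact shape can vary by `langchain_community` version), each result is typically a dict with `url` and `content` keys. Running it requires a `TAVILY_API_KEY` in the environment:
%% Cell type:code id: tags:
``` python
# Hypothetical quick check -- uncomment to run
# results = tool.invoke("What's a 'node' in LangGraph?")
# for r in results:
#     print(r["url"])
#     print(r["content"][:100])
```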
%% Cell type:markdown id: tags:
<hr>
<h2>Using Tongyi Qianwen as the LLM</h2>
%% Cell type:code id: tags:
``` python
# Use Tongyi Qianwen (Qwen)
from langchain_community.llms import Tongyi
import os

# Initialize the Tongyi model (qwen-turbo here)
llm_tongyi = Tongyi(
    model_name="qwen-turbo",
    dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
)
```
%% Cell type:markdown id: tags:
<h2>Manually Implementing a Tool-Calling Flow for Tongyi</h2>
%% Cell type:code id: tags:
``` python
import json
from typing import List

def manual_tool_invocation(llm: Tongyi, tools: List, query: str) -> str:
    # 1. Build the tool descriptions
    tool_descs = "\n".join([
        f"{t.name}: {t.description}\nArgs: {t.args}"
        for t in tools
    ])
    # 2. Construct the routing prompt
    prompt = f"""Decide whether answering the question requires a tool. Available tools:
{tool_descs}
Question: {query}
If a tool is needed, reply strictly in this format:
```json
{{"tool": "tool_name", "args": {{"arg_name": "arg_value"}}}}
```
If no tool is needed, answer the question directly."""
    # 3. Call the model
    response = llm.invoke(prompt)
    # 4. Parse and execute the tool call, if any
    if "```json" in response:
        try:
            tool_call = json.loads(response.split("```json")[1].split("```")[0].strip())
            selected_tool = next(t for t in tools if t.name == tool_call["tool"])
            # Stringify the tool output so it can be stored as a message
            return str(selected_tool.invoke(tool_call["args"]))
        except Exception:
            return f"Tool invocation failed: {response}"
    return response

# Example usage
# result = manual_tool_invocation(llm_tongyi, tools, "What's a 'node' in LangGraph?")
```
%% Cell type:markdown id: tags:
<h2>Defining the Agent Workflow</h2>
%% Cell type:code id: tags:
``` python
from typing import Annotated

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

class State(TypedDict):
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list rather than overwriting them).
    messages: Annotated[list, add_messages]

graph_builder = StateGraph(State)
```
%% Cell type:markdown id: tags:
<h2>Adding the Agent Node</h2>
%% Cell type:code id: tags:
``` python
def chatbot(state: State):
    # return {"messages": [llm_tongyi.invoke(state["messages"])]}
    # Route the latest user message through the manual tool-calling flow
    # instead of hard-coding a query
    user_query = state["messages"][-1].content
    return {"messages": [manual_tool_invocation(llm_tongyi, tools, user_query)]}

# The first argument is the unique node name.
# The second argument is the function or object that will be called
# whenever the node is used.
graph_builder.add_node("chatbot", chatbot)
```
%% Cell type:markdown id: tags:
<h2>Adding the Workflow Entry and Exit Points</h2>
%% Cell type:code id: tags:
``` python
# Define the entry point
graph_builder.add_edge(START, "chatbot")
# Define the exit point
graph_builder.add_edge("chatbot", END)
# Compile the graph
graph = graph_builder.compile()
```
%% Cell type:markdown id: tags:
<h2>Visualizing the Workflow</h2>
%% Cell type:code id: tags:
``` python
from IPython.display import Image, display

try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    pass
```
%% Cell type:markdown id: tags:
<h2>Running the Workflow</h2>
%% Cell type:code id: tags:
``` python
# Streaming interaction helper (adapted for the Tongyi model)
def stream_graph_updates(user_input: str):
    state = {"messages": [{"role": "user", "content": user_input}]}
    for event in graph.stream(state):
        for value in event.values():
            last_message = value["messages"][-1]
            print(f"Assistant: {last_message}")

while True:
    try:
        user_input = input("User: ")
        print("User: ", user_input)
        if user_input.lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break
        stream_graph_updates(user_input)
    except Exception:
        # Fallback if input() is not available
        user_input = "What do you know about LangGraph?"
        print("User: " + user_input)
        stream_graph_updates(user_input)
        break
```
%% Cell type:code id: tags:
``` python
# Load environment variables
from dotenv import load_dotenv

load_dotenv()  # load the .env file
```
%% Cell type:code id: tags:
``` python
# Use Tongyi Qianwen (Qwen)
from langchain_community.llms import Tongyi
import os

# Initialize the Tongyi model (qwen-turbo here)
llm_tongyi = Tongyi(
    model_name="qwen-turbo",
    dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
)
```
%% Cell type:code id: tags:
``` python
# Build a prompt instructing the model to reply in JSON
prompt = """
Given the user request below, produce a query suitable for a web search and explain your reasoning.

User request: "{0}"

Reply in exactly the following JSON format:
{{
    "search_query": "...",
    "justification": "..."
}}
"""
```
%% Cell type:code id: tags:
``` python
# Invoke the augmented LLM
output = llm_tongyi.invoke(prompt.format("How does Calcium CT score relate to high cholesterol?"))
```
%% Cell type:code id: tags:
``` python
print(output)
```
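%% Cell type:markdown id: tags:
A minimal sketch for consuming the reply, assuming the model honors the JSON format above (it may also wrap the JSON in a Markdown code fence, so strip that first):
%% Cell type:code id: tags:
``` python
import json

raw = output.strip()
if raw.startswith("```"):
    # Drop the opening fence line and the trailing fence
    raw = raw.split("\n", 1)[1].rsplit("```", 1)[0].strip()
parsed = json.loads(raw)
print(parsed["search_query"])
print(parsed["justification"])
```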
%% Cell type:code id: tags:
``` python
# Load environment variables
from dotenv import load_dotenv

load_dotenv()  # load the .env file
```
%% Cell type:code id: tags:
``` python
from langsmith import traceable
```
%% Cell type:code id: tags:
``` python
# Use Tongyi Qianwen (Qwen)
from langchain_community.llms import Tongyi
import os

# Initialize the Tongyi model (qwen-turbo here)
llm_tongyi = Tongyi(
    model_name="qwen-turbo",
    dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
)
```
%% Cell type:code id: tags:
``` python
from typing import Callable, List, Dict, Any
import inspect

@traceable
def build_tool_prompt(tools: List[Callable], user_input: str) -> str:
    """
    Build a multi-tool prompt that guides the model to call the right tool.
    """
    tool_descriptions = []
    for tool in tools:
        sig = inspect.signature(tool)
        params = ", ".join(f"{name}: {param.annotation.__name__}"
                           for name, param in sig.parameters.items())
        doc = inspect.getdoc(tool) or "No description"
        tool_descriptions.append(f"""Tool name: {tool.__name__}
Description: {doc}
Args: {params}
""")
    tool_block = "\n\n".join(tool_descriptions)
    prompt = f"""
You are an assistant with the following tools available:
{tool_block}
Based on the user request, choose the most suitable tool and reply with a call instruction in exactly this format (use ASCII parentheses):
Call: <tool_name>(arg1=value1, arg2=value2)
If no tool is needed, reply: No tool needed
User request: {user_input}
"""
    return prompt
```
%% Cell type:code id: tags:
``` python
# Example tools
@traceable
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

@traceable
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

TOOLS = [multiply, add]
TOOL_MAP = {fn.__name__: fn for fn in TOOLS}

# User input
user_input = "What is the sum of 23 and 34?"

# Build the prompt and call the model
prompt = build_tool_prompt(TOOLS, user_input)
response = llm_tongyi.invoke(prompt)
print("🔍 Model response:", response)
```
%% Cell type:code id: tags:
``` python
import ast
import re

@traceable
def extract_tool_call(response: str) -> Dict[str, Any]:
    """
    Extract a tool-call instruction from the model response.
    Return format: { "tool_name": str, "args": dict }
    """
    pattern = r"Call:\s*(\w+)\((.*?)\)"
    match = re.search(pattern, response)
    if not match:
        return {}
    tool_name, args_str = match.groups()
    args = {}
    for part in args_str.split(","):
        if "=" in part:
            key, value = part.split("=", 1)
            key = key.strip()
            value = value.strip()
            try:
                # Parse literals safely; never eval() untrusted model output
                value = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                pass
            args[key] = value
    return {"tool_name": tool_name, "args": args}
```
%% Cell type:code id: tags:
``` python
# Parse the call and dispatch to the matching tool
tool_call = extract_tool_call(response)
if tool_call:
    tool_fn = TOOL_MAP.get(tool_call["tool_name"])
    if tool_fn:
        result = tool_fn(**tool_call["args"])
        print("✅ Tool result:", result)
    else:
        print("⚠️ Unknown tool:", tool_call["tool_name"])
else:
    print("ℹ️ The model did not call any tool")
```
%% Cell type:code id: tags:
``` python
print(tool_call)
```
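%% Cell type:markdown id: tags:
The steps above can be chained into a single helper. A minimal sketch using only the names already defined (`build_tool_prompt`, `extract_tool_call`, `TOOL_MAP`):
%% Cell type:code id: tags:
``` python
@traceable
def run_with_tools(user_input: str):
    """Chain prompt construction, model call, parsing, and dispatch."""
    prompt = build_tool_prompt(TOOLS, user_input)
    response = llm_tongyi.invoke(prompt)
    call = extract_tool_call(response)
    if not call:
        return response  # the model answered directly
    tool_fn = TOOL_MAP.get(call["tool_name"])
    if tool_fn is None:
        return f"Unknown tool: {call['tool_name']}"
    return tool_fn(**call["args"])

# run_with_tools("What is 23 times 34?")
```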
%% Cell type:markdown id: tags:
<H1>Database Queries</H1>
<hr>
<H2>Environment Setup</H2>
%% Cell type:code id: tags:
``` python
# Load environment variables
from dotenv import load_dotenv

load_dotenv()  # load the .env file
```
%% Cell type:markdown id: tags:
<H2>Defining the LLM</H2>
%% Cell type:code id: tags:
``` python
# Use Tongyi Qianwen (Qwen)
# from langchain_community.llms import Tongyi
# Use the ChatTongyi wrapper instead of Tongyi: ChatTongyi adds tool binding via bind_tools.
# from langchain_community.chat_models.tongyi import ChatTongyi
from langchain_openai import ChatOpenAI
import os

# Initialize the model (qwen-plus here) via DashScope's OpenAI-compatible endpoint
# llm_chat_tongyi = ChatTongyi(
#     model_name="qwen-plus",
#     dashscope_api_key=os.getenv("DASHSCOPE_API_KEY")
# )
llm_chat_tongyi = ChatOpenAI(
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    model_name="qwen-plus"
)
```
%% Cell type:markdown id: tags:
<H2>Downloading the SQLite Database Locally</H2>
%% Cell type:code id: tags:
``` python
import requests

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
response = requests.get(url)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")
```
%% Cell type:code id: tags:
``` python
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Step-by-step debugging
# print(db.dialect)
# print(db.get_usable_table_names())
# db.run("SELECT * FROM Artist LIMIT 10;")
```
%% Cell type:markdown id: tags:
<H2>Defining a Tool-Node Error Handler</H2>
%% Cell type:code id: tags:
``` python
from typing import Any

from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode

def handle_tool_error(state) -> dict:
    """Surface a tool execution error back to the agent as a ToolMessage."""
    error = state.get("error")
    tool_calls = state["messages"][-1].tool_calls
    return {
        "messages": [
            ToolMessage(
                content=f"Error: {repr(error)}\nPlease fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in tool_calls
        ]
    }

def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """
    Create a ToolNode with a fallback to handle errors and surface them to the agent.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )
```
%% Cell type:markdown id: tags:
<H2>Defining Query Tools from the SQLite Tables</H2>
%% Cell type:code id: tags:
``` python
from langchain_community.agent_toolkits import SQLDatabaseToolkit

toolkit = SQLDatabaseToolkit(db=db, llm=llm_chat_tongyi)
tools = toolkit.get_tools()

list_tables_tool = next(tool for tool in tools if tool.name == "sql_db_list_tables")
get_schema_tool = next(tool for tool in tools if tool.name == "sql_db_schema")

# Step-by-step debugging
# print(list_tables_tool.invoke(""))
# print(get_schema_tool.invoke("Artist,Customer"))
```
%% Cell type:markdown id: tags:
<H2>Defining the SQLite Query Tool</H2>
%% Cell type:code id: tags:
``` python
from langchain_core.tools import tool

@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    result = db.run_no_throw(query)
    if not result:
        return "Error: Query failed. Please rewrite your query and try again."
    return result

# Step-by-step debugging
# print(db_query_tool.invoke("SELECT * FROM Artist LIMIT 10;"))
```
%% Cell type:markdown id: tags:
<H2>Defining the SQLite SQL-Check Prompt Template for the LLM</H2>
%% Cell type:code id: tags:
``` python
from langchain_core.prompts import ChatPromptTemplate

query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    [("system", query_check_system), ("placeholder", "{messages}")]
)
query_check = query_check_prompt | llm_chat_tongyi.bind_tools(
    [db_query_tool], tool_choice="required"
)

# Step-by-step debugging
# query_check.invoke({"messages": [("user", "SELECT * FROM Artist LIMIT 10;")]})
```
%% Cell type:markdown id: tags:
<H2>Defining the Workflow</H2>
%% Cell type:code id: tags:
``` python
from typing import Annotated, Literal

from langchain_core.messages import AIMessage
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages

# Define the state for the agent
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

# Define a new graph
workflow = StateGraph(State)

# Add a node for the first tool call
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "sql_db_list_tables",
                        "args": {},
                        "id": "tool_abcd123",
                    }
                ],
            )
        ]
    }

def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """
    Use this tool to double-check if your query is correct before executing it.
    """
    return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

workflow.add_node("first_tool_call", first_tool_call)

# Add nodes for the first two tools
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))

# Add a node for a model to choose the relevant tables based on the question and available tables
model_get_schema = llm_chat_tongyi.bind_tools([get_schema_tool])
workflow.add_node(
    "model_get_schema",
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)

# Describe a tool to represent the end state
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""

    final_answer: str = Field(..., description="The final answer to the user")

# Add a node for a model to generate a query based on the question and schema
query_gen_system = """You are a SQL expert with a strong attention to detail.
Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
When generating the query:
Output the SQL query that answers the input question without a tool call.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
If you get an error while executing a query, rewrite the query and try again.
If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""

query_gen_prompt = ChatPromptTemplate.from_messages(
    [("system", query_gen_system), ("placeholder", "{messages}")]
)
query_gen = query_gen_prompt | llm_chat_tongyi.bind_tools([SubmitFinalAnswer])

def query_gen_node(state: State):
    message = query_gen.invoke(state)
    # Sometimes the LLM will hallucinate and call the wrong tool.
    # Catch this and return an error message so the model can correct itself.
    tool_messages = []
    if message.tool_calls:
        for tc in message.tool_calls:
            if tc["name"] != "SubmitFinalAnswer":
                tool_messages.append(
                    ToolMessage(
                        content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
                        tool_call_id=tc["id"],
                    )
                )
    return {"messages": [message] + tool_messages}

workflow.add_node("query_gen", query_gen_node)

# Add a node for the model to check the query before executing it
workflow.add_node("correct_query", model_check_query)

# Add a node for executing the query
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

# Define a conditional edge to decide whether to continue or end the workflow
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    messages = state["messages"]
    last_message = messages[-1]
    # If there is a tool call (SubmitFinalAnswer), we are done
    if getattr(last_message, "tool_calls", None):
        return END
    if last_message.content.startswith("Error:"):
        return "query_gen"
    else:
        return "correct_query"

# Specify the edges between the nodes
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable
app = workflow.compile()
```
%% Cell type:markdown id: tags:
<H2>Visualizing the Workflow</H2>
%% Cell type:code id: tags:
``` python
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod

display(
    Image(
        app.get_graph().draw_mermaid_png(
            draw_method=MermaidDrawMethod.API,
        )
    )
)
```
%% Cell type:markdown id: tags:
<H2>Single-Step Execution</H2>
%% Cell type:code id: tags:
``` python
messages = app.invoke(
    {"messages": [("user", "Which sales agent made the most in sales in 2009?")]}
)
# The answer is submitted via the SubmitFinalAnswer tool call's args
final_answer = messages["messages"][-1].tool_calls[0]["args"]["final_answer"]
final_answer
```
%% Cell type:markdown id: tags:
<H2>Multi-Step Execution</H2>
%% Cell type:code id: tags:
``` python
# for event in app.stream(
#     {"messages": [("user", "Which sales agent made the most in sales in 2009?")]}
# ):
#     print(event)
```