In [None]:
## Setup
# Key of the per-document entry under "docs" in the embed config that this notebook queries.
embed_object: str = "labors-law"
# The user's question, answered from the retrieved chunks.
query_text: str = "关于工资福利,有什么规定?"

In [None]:
# 导入必要的包
import time
import os
import sys
import json
from chromadb import Client

# Resolve everything relative to the kernel's current working directory.
# Note: os.getcwd() is the directory the kernel runs in, not necessarily this file's folder.
current_dir = os.getcwd()

# Path of the embedding configuration file.
config_file_path = os.path.join(current_dir, "..", "setup", "config_embed.json")

# Read the shared settings plus the settings for the selected `embed_object`.
# Only the failures expected here are caught: missing/unreadable file (OSError),
# malformed JSON (json.JSONDecodeError is a ValueError), and missing or wrongly
# typed keys (KeyError/TypeError). Anything else, e.g. a NameError because the
# setup cell was not run, surfaces with its real traceback instead of being
# reported as a config problem.
try:
    with open(config_file_path, "r", encoding="utf-8") as f:
        dict_config = json.load(f)

    # Shared settings
    API_URL = dict_config["API_URL"]
    cohere_access_token = dict_config["cohere_access_token"]  # secret: avoid printing it
    custom_proxies = dict_config["custom_proxies"]
    list_chroma_dir = dict_config["list_chroma_dir"]
    list_embed_file_path = dict_config["list_embed_file_path"]

    # Settings specific to the selected embed object (looked up once)
    doc_config = dict_config["docs"][embed_object]
    model_name = doc_config["model_name"]
    embed_file_name = doc_config["file_name"]
    split_chunk_size = doc_config["split_chunk_size"]
    split_overlap = doc_config["split_overlap"]
    model_batch_size = doc_config["model_batch_size"]
    collection_ids_prefix = doc_config["collection_ids_prefix"]
    embed_collection_name = doc_config["collection_name"]
except (OSError, ValueError, KeyError, TypeError) as e:
    print("配置文件读取失败", e)
    sys.exit("Stop processing")

# Join the path components listed in the config onto the working directory.
# NOTE(review): the components presumably start with ".." to step out of the notebook folder — confirm.
chroma_dir = os.path.join(current_dir, *list_chroma_dir)
embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)

In [None]:
import chromadb

# Open the persistent Chroma store that was written to `chroma_dir` earlier.
client = chromadb.PersistentClient(path=chroma_dir)

# Fetch the collection whose name comes from the config entry for `embed_object`
# (raises if it does not exist yet, i.e. the documents were never embedded).
collection = client.get_collection(name=embed_collection_name)

In [None]:
# Load the Sentence-Transformers model used to embed the query
# (downloaded from the Hugging Face Hub unless already cached locally).
# NOTE(review): the model id is hardcoded here, while the config cell also reads a
# per-document `model_name`; confirm this is the same model that built the collection,
# otherwise query and stored vectors live in different embedding spaces.
from sentence_transformers import SentenceTransformer

embedding_model_id = "jinaai/jina-embeddings-v3"
# trust_remote_code allows the model repo's own modeling code to be executed.
model = SentenceTransformer(embedding_model_id, trust_remote_code=True)

In [None]:
# Embed the user question. encode() works on a batch, so wrap the single query in a list;
# the result holds one embedding row per input text.
# NOTE(review): jina-embeddings-v3 supports task-specific adapters (e.g. task="retrieval.query");
# confirm whether the stored chunks were embedded with a matching task before changing this.
query_batch = [query_text]
query_embedding = model.encode(query_batch)

In [None]:
# Retrieve the chunks most similar to the query embedding.
top_k = 5
results = collection.query(
    query_embeddings=query_embedding,
    n_results=top_k,
    include=["documents", "metadatas", "distances"],
)

# Chroma returns one list of hits per query embedding. A single query was sent,
# so the hits are at index 0. Zipping the outer lists would print the whole hit
# list once instead of each hit.
for doc, metadata, distance in zip(
    results["documents"][0],
    results["metadatas"][0],
    results["distances"][0],
):
    print("Document:", doc)
    print("Metadata:", metadata)
    print("Distance:", distance)

In [None]:
import json

# Pretty-print the raw retrieved documents; ensure_ascii=False keeps Chinese text readable.
documents_json = json.dumps(results["documents"], ensure_ascii=False, indent=4)
print(documents_json)

In [None]:
# Free the embedding model before loading the QA LLM, so both are not held in memory at once.
import gc

import torch

# Guarded so re-running this cell after `model` was already deleted does not raise NameError.
if "model" in globals():
    del model
gc.collect()

# gc.collect() frees the Python objects, but PyTorch keeps freed GPU blocks in its
# caching allocator; empty_cache() releases them for other consumers.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Load the causal LM that answers the question from the retrieved context.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Dedicated name so the embedding `model_name` read from the config is not overwritten.
qa_model_name = "Qwen/Qwen2.5-1.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    qa_model_name,
    torch_dtype="auto",  # use the dtype recorded in the checkpoint's config
    device_map="auto",   # let accelerate place the weights on the available devices
)
tokenizer = AutoTokenizer.from_pretrained(qa_model_name)

In [None]:
# Build the chat prompt: the retrieved chunks go into a system message as a bulleted list.
prompt = query_text

# Hits for the single query sent to Chroma. Prefix every chunk explicitly:
# "\n - ".join(...) would leave the first chunk without its bullet.
retrieved_docs = results["documents"][0]
context_block = "\n".join(" - " + doc for doc in retrieved_docs)

messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "system", "content": "在回答的过程中,必须根据以下提示回答:\n" + context_block},
    {"role": "user", "content": prompt},
]

# Render the conversation as text (with the assistant turn opened), then tokenize on the model's device.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

In [None]:
# Generate the answer; the returned sequences contain the prompt tokens followed by new tokens.
generated_ids = model.generate(**model_inputs, max_new_tokens=512)

# Drop each row's prompt prefix so only the newly generated tokens are decoded.
generated_ids = [
    full_seq[len(prompt_seq):]
    for prompt_seq, full_seq in zip(model_inputs.input_ids, generated_ids)
]

decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
response = decoded[0]

print(response)