Commit 7f1e6463 authored by Administrator's avatar Administrator
Browse files

initial version of using local Hugging Face model for embed and retrieve

parent bb8a7037
%% Cell type:code id:d14ad3f1 tags:
``` python
## 准备部分
# 指定Embed对象
embed_object = "oracle-scm-planning"
embed_object = "labors-law"
```
%% Cell type:code id:29688979-89c0-47c7-84ab-2b4b182d2bd7 tags:
``` python
......@@ -62,15 +62,15 @@
# 加载PDF文档
loader = PyPDFLoader(file_path)
documents = loader.load()
# 将所有页内容合并成一个字符串
full_text = "\n".join([doc.page_content for doc in documents])
full_text = " ".join([doc.page_content for doc in documents])
import re
# 使用正则表达式去除跨页的页码或分隔符,例如 "\n—6—\n"
full_text = re.sub(r"\n—\d+—\n", "\n", full_text)
# full_text = re.sub(r"\n—\d+—\n", "\n", full_text)
# 去掉多余换行符,将换行符替换为空格
full_text = re.sub(r"\s*\n\s*", "", full_text)
# 分割文档
......@@ -110,42 +110,17 @@
``` python
from IPython.display import clear_output
```
%% Cell type:code id:13e802ab-0f76-4aa9-a7dc-7cde4d7851fb tags:
%% Cell type:code id:5926d5d9 tags:
``` python
# 执行嵌入查询
def get_embeddings(texts: list):
# 装载本地的Hugging Face模型
from sentence_transformers import SentenceTransformer
retry_cnt = 0
while True:
try:
response = requests.post(
API_URL,
proxies=custom_proxies,
timeout=(60, 6),
headers={"Authorization": "Bearer {0}".format(cohere_access_token)},
json={
"model": model_name,
"texts": texts,
"input_type": "classification",
"truncate": "NONE"
}
)
response.raise_for_status() # 确保请求成功
break
except Exception as e:
time.sleep(1)
print("error", e)
if retry_cnt > 5:
return None
retry_cnt += 1
return response.json() # 返回嵌入向量
model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
```
%% Cell type:code id:0b23e16a-c069-4be2-8109-166f5cbb722c tags:
``` python
......@@ -155,12 +130,14 @@
text = doc.page_content
request_docs.append(text)
embedded_docs = []
for i in range(0, len(request_docs), model_batch_size):
embeddings_response = get_embeddings(request_docs[i:i+model_batch_size])
embedded_docs.extend(embeddings_response['embeddings'])
# embeddings_response = get_embeddings(request_docs[i:i+model_batch_size])
# embedded_docs.extend(embeddings_response['embeddings'])
embedding = model.encode(request_docs[i:i+model_batch_size])
embedded_docs.extend(embedding)
```
%% Cell type:code id:fb491aa1-4fdc-423f-ba76-861b68959777 tags:
``` python
......@@ -175,5 +152,24 @@
)
```
%%%% Output: display_data
%% Cell type:code id:e7d24e7f tags:
``` python
# Free memory: drop the reference to the embedding model, then run a
# garbage-collection pass.
del model
import gc
gc.collect()
```
%%%% Output: execute_result
14438
%% Cell type:code id:ac99b839 tags:
``` python
```
......
%% Cell type:code id:b1f2ff89 tags:
``` python
## Preparation
# Select the embed object; it is used as the key under "docs" in the config file
embed_object = "oracle-scm-planning"
# The user's question to search for
query_text = "How AI can help SCM planning?"
```
%% Cell type:code id:4b9d9fa0-1c73-4cef-abc4-397458215159 tags:
``` python
# Import the required packages
import time
import os
import sys
import json
import requests
from chromadb import Client

# Working directory of the notebook process (all paths below are relative to it)
current_dir = os.getcwd()

# Location of the settings file
config_file_path = os.path.join(current_dir, "..", "setup", "config_embed.json")

# Load the settings; any failure is reported and stops processing
try:
    with open(config_file_path, "r", encoding="utf-8") as f:
        dict_config = json.load(f)

    # General settings shared by every embed object
    API_URL = dict_config["API_URL"]
    cohere_access_token = dict_config["cohere_access_token"]
    custom_proxies = dict_config["custom_proxies"]
    list_chroma_dir = dict_config["list_chroma_dir"]
    list_embed_file_path = dict_config["list_embed_file_path"]

    # Settings of the selected embed object
    doc_config = dict_config["docs"][embed_object]
    model_name = doc_config["model_name"]
    embed_file_name = doc_config["file_name"]
    split_chunk_size = doc_config["split_chunk_size"]
    split_overlap = doc_config["split_overlap"]
    model_batch_size = doc_config["model_batch_size"]
    collection_ids_prefix = doc_config["collection_ids_prefix"]
    embed_collection_name = doc_config["collection_name"]
except Exception as e:
    print("配置文件读取失败", e)
    sys.exit("Stop processing")

# Build the data paths from the path segments listed in the config
chroma_dir = os.path.join(current_dir, *list_chroma_dir)
embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)
```
%% Cell type:code id:630b5215-b3ad-43e6-8f48-83b78abf9c10 tags:
``` python
import chromadb
# Initialize the Chroma client on the existing on-disk data directory
client = chromadb.PersistentClient(path=chroma_dir)
# Fetch the collection whose name is held in embed_collection_name
collection = client.get_collection(embed_collection_name)
```
%% Cell type:code id:953b58a8-3c70-4774-ba91-100cfd744337 tags:
``` python
# Request embeddings from the remote embedding API
def get_embeddings(texts: list, max_attempts: int = 7, retry_delay: float = 1.0):
    """Request embeddings for ``texts`` from the HTTP embedding endpoint.

    Posts to ``API_URL`` using the bearer token, proxies and model name loaded
    from the notebook configuration, and retries when the request fails.

    Args:
        texts: Strings to embed; sent as the ``texts`` field of the payload.
        max_attempts: Total number of requests to try before giving up. The
            default of 7 matches the original loop (one try plus six retries).
        retry_delay: Seconds to wait between two attempts.

    Returns:
        The decoded JSON response body (callers read its ``"embeddings"``
        key), or ``None`` when every attempt failed.
    """
    headers = {"Authorization": "Bearer {0}".format(cohere_access_token)}
    payload = {
        "model": model_name,
        "texts": texts,
        "input_type": "classification",
        "truncate": "NONE"
    }
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.post(
                API_URL,
                proxies=custom_proxies,
                timeout=(60, 6),  # (connect timeout, read timeout) in seconds
                headers=headers,
                json=payload
            )
            response.raise_for_status()  # treat HTTP error statuses as failures
        except requests.RequestException as e:
            # Only request/HTTP failures are retried; programming errors such
            # as an undefined configuration name propagate immediately.
            print("error", e)
            if attempt < max_attempts:
                # No point in sleeping after the final attempt.
                time.sleep(retry_delay)
        else:
            return response.json()
    return None
```
%% Cell type:code id:667253d4-1a55-406a-b3d4-51f4b9957aca tags:
``` python
# Embed the query text with get_embeddings (defined in an earlier cell)
query_embedding = get_embeddings([query_text])
# get_embeddings returns None once its retries are exhausted; stop here with a
# clear message instead of failing later on query_embedding["embeddings"].
if query_embedding is None:
    raise RuntimeError("Embedding request failed: get_embeddings returned None")
```
%% Cell type:code id:f3df429f-b4d6-4e66-803b-ae5e9ed95922 tags:
``` python
# Query the collection for the 5 closest documents to the query embedding
results = collection.query(
    query_embeddings=query_embedding["embeddings"],
    n_results=5
)
# Show the results. Chroma returns one inner list per query embedding, so
# walk both levels to print each matched document next to its own metadata.
for docs, metadatas in zip(results['documents'], results['metadatas']):
    for doc, metadata in zip(docs, metadatas):
        print("Document:", doc)
        print("Metadata:", metadata)
```
%% Cell type:code id:01ed14d1-71b5-42b3-82d6-47a053894a74 tags:
``` python
import json
# Pretty-print the matched documents; ensure_ascii=False keeps non-ASCII text readable
print(json.dumps(results["documents"], indent=4, ensure_ascii=False))
```
%% Cell type:code id:b1f2ff89 tags:
``` python
## Preparation
# Select the embed object; it is used as the key under "docs" in the config file
embed_object = "labors-law"
# The user's question to search for
query_text = "对于劳动合同的解除有什么规定?"
```
%% Cell type:code id:4b9d9fa0-1c73-4cef-abc4-397458215159 tags:
``` python
# Import the required packages
import time
import os
import sys
import json
from chromadb import Client

# Working directory of the notebook process (all paths below are relative to it)
current_dir = os.getcwd()

# Location of the settings file
config_file_path = os.path.join(current_dir, "..", "setup", "config_embed.json")

# Load the settings; any failure is reported and stops processing
try:
    with open(config_file_path, "r", encoding="utf-8") as f:
        dict_config = json.load(f)

    # General settings shared by every embed object
    API_URL = dict_config["API_URL"]
    cohere_access_token = dict_config["cohere_access_token"]
    custom_proxies = dict_config["custom_proxies"]
    list_chroma_dir = dict_config["list_chroma_dir"]
    list_embed_file_path = dict_config["list_embed_file_path"]

    # Settings of the selected embed object
    doc_config = dict_config["docs"][embed_object]
    model_name = doc_config["model_name"]
    embed_file_name = doc_config["file_name"]
    split_chunk_size = doc_config["split_chunk_size"]
    split_overlap = doc_config["split_overlap"]
    model_batch_size = doc_config["model_batch_size"]
    collection_ids_prefix = doc_config["collection_ids_prefix"]
    embed_collection_name = doc_config["collection_name"]
except Exception as e:
    print("配置文件读取失败", e)
    sys.exit("Stop processing")

# Build the data paths from the path segments listed in the config
chroma_dir = os.path.join(current_dir, *list_chroma_dir)
embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)
```
%% Cell type:code id:630b5215-b3ad-43e6-8f48-83b78abf9c10 tags:
``` python
import chromadb
# Initialize the Chroma client on the existing on-disk data directory
client = chromadb.PersistentClient(path=chroma_dir)
# Fetch the collection whose name is held in embed_collection_name
collection = client.get_collection(embed_collection_name)
```
%% Cell type:code id:8fc38844 tags:
``` python
# Load the Hugging Face embedding model through sentence-transformers
from sentence_transformers import SentenceTransformer
# NOTE(review): trust_remote_code=True presumably allows Python code shipped in
# the model repository to run on this machine — confirm the source is trusted.
model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)
```
%% Cell type:code id:667253d4-1a55-406a-b3d4-51f4b9957aca tags:
``` python
# Embed the query text with the loaded model
# Previous remote-API call, kept for reference: query_embedding = get_embeddings([query_text])
query_embedding = model.encode([query_text])
```
%% Cell type:code id:f3df429f-b4d6-4e66-803b-ae5e9ed95922 tags:
``` python
# Query the collection for the 5 closest documents to the query embedding
results = collection.query(
    query_embeddings=query_embedding,
    n_results=5
)
# Show the results. Chroma returns one inner list per query embedding, so
# walk both levels to print each matched document next to its own metadata.
for docs, metadatas in zip(results['documents'], results['metadatas']):
    for doc, metadata in zip(docs, metadatas):
        print("Document:", doc)
        print("Metadata:", metadata)
```
%% Cell type:code id:01ed14d1-71b5-42b3-82d6-47a053894a74 tags:
``` python
import json
# Pretty-print the matched documents; ensure_ascii=False keeps non-ASCII text readable
print(json.dumps(results["documents"], indent=4, ensure_ascii=False))
```
%% Cell type:code id:3aa24e59 tags:
``` python
# Free memory: drop the reference to the embedding model, then run a
# garbage-collection pass.
del model
import gc
gc.collect()
```
%%%% Output: execute_result
28863
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment