Commit 2aa33054 authored by Clark Lin's avatar Clark Lin
Browse files

added embed and retrieve book

parent 38f6151c
.venv
.idea
.DS_Store
/data
/setup/config_embed.json
\ No newline at end of file
# jupyter-genai
# 依赖包
jupyter
langchain_community
chromadb
pypdf
requests
## 部署说明
# Use the following command to install the necessary Python libraries
pip install jupyter langchain_community chromadb pypdf requests
## 使用说明
# 设定配置文件
/setup/config_embed_sample.json
# 重命名 /setup/config_embed_sample.json -> /setup/config_embed.json
# 编辑每个book的准备部分,执行
%% Cell type:code id:d14ad3f1 tags:
``` python
## Preparation
# Select the document to embed; this value is used as the key under "docs" in the config file
embed_object = "oracle-scm-planning"
```
%% Cell type:code id:29688979-89c0-47c7-84ab-2b4b182d2bd7 tags:
``` python
# 导入必要的包
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb
import time
from tqdm.notebook import tqdm
from langchain.docstore.document import Document
import os
import sys
import json
import requests
from chromadb import Client
# Paths are resolved against the process's current working directory
# (os.getcwd()), which is not necessarily the directory holding this notebook.
current_dir = os.getcwd()
# The settings file is expected one level up, at ../setup/config_embed.json
config_file_path = os.path.join(current_dir, "..", "setup", "config_embed.json")
# Load the settings
try:
    with open(config_file_path, "r", encoding="utf-8") as f:
        dict_config = json.load(f)
    # Settings shared by every document
    API_URL = dict_config["API_URL"]
    cohere_access_token = dict_config["cohere_access_token"]
    custom_proxies = dict_config["custom_proxies"]
    list_chroma_dir = dict_config["list_chroma_dir"]
    list_embed_file_path = dict_config["list_embed_file_path"]
    # Settings of the selected document; look its entry up once
    doc_config = dict_config["docs"][embed_object]
    model_name = doc_config["model_name"]
    embed_file_name = doc_config["file_name"]
    split_chunk_size = doc_config["split_chunk_size"]
    split_overlap = doc_config["split_overlap"]
    model_batch_size = doc_config["model_batch_size"]
    collection_ids_prefix = doc_config["collection_ids_prefix"]
    embed_collection_name = doc_config["collection_name"]
except (OSError, ValueError, KeyError, TypeError) as e:
    # OSError: file missing or unreadable; ValueError: invalid JSON or encoding;
    # KeyError / TypeError: a required entry is missing or has the wrong shape.
    # Other exceptions (e.g. a NameError from a cell run out of order) are
    # deliberately left to surface with their own traceback.
    print("配置文件读取失败", e)
    sys.exit("Stop processing")
# Build the Chroma directory and the source-file path from the configured path segments
chroma_dir = os.path.join(current_dir, *list_chroma_dir)
embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)
```
%% Cell type:code id:f168e7ca-a61a-4e64-9d22-1e95b6f95a4d tags:
``` python
def load_and_split_pdf(file_path, chunk_size=500, chunk_overlap=100, line_joiner=" "):
    """Load a PDF, flatten its text and split it into chunk documents.

    All pages are concatenated, page-number separator lines of the form
    "—6—" are removed, and every line break (with its surrounding
    whitespace) is replaced by ``line_joiner`` before the text is split.

    Args:
        file_path: Path of the PDF file handed to ``PyPDFLoader``.
        chunk_size: ``chunk_size`` passed to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` passed to ``RecursiveCharacterTextSplitter``.
        line_joiner: Text inserted where a line break was removed. The default
            is a single space so that words on adjacent lines are not fused
            together; pass ``""`` for text without word spacing (e.g. Chinese)
            to join lines directly.

    Returns:
        A list of ``Document`` objects, one per chunk, carrying only
        ``page_content``.
    """
    # Local import: the surrounding import cell does not import ``re``.
    import re

    # Load every page of the PDF
    loader = PyPDFLoader(file_path)
    pages = loader.load()

    # Merge the page texts into a single string
    full_text = "\n".join(page.page_content for page in pages)

    # Drop page-number separators that sit on their own line, e.g. "\n—6—\n"
    full_text = re.sub(r"\n—\d+—\n", "\n", full_text)

    # Collapse each line break plus surrounding whitespace into the joiner.
    # A callable replacement keeps ``line_joiner`` literal (no backslash escapes).
    full_text = re.sub(r"\s*\n\s*", lambda _match: line_joiner, full_text)

    # Split the flattened text into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = text_splitter.split_text(full_text)

    # Wrap each chunk in a Document so downstream code can read ``page_content``
    return [Document(page_content=chunk) for chunk in chunks]
```
%% Cell type:code id:e7288509 tags:
``` python
# Load the PDF at embed_file_path and split it using the chunk size and overlap read from the config
docs = load_and_split_pdf(embed_file_path, split_chunk_size, split_overlap)
```
%% Cell type:code id:8d3ed2e8-97c6-4db4-a4e8-d582013a2ba9 tags:
``` python
# Create the persistent Chroma client on the configured directory
client = chromadb.PersistentClient(path=chroma_dir)

# Rebuild the collection from scratch: try to drop an existing one first.
# A failed delete is only reported, so processing continues either way.
try:
    client.delete_collection(embed_collection_name)
except Exception as delete_error:
    print(delete_error)

collection = client.get_or_create_collection(name=embed_collection_name)
```
%% Cell type:code id:c2b660e6-c38b-41be-a429-6a2edf1edd7d tags:
``` python
from IPython.display import clear_output
```
%% Cell type:code id:13e802ab-0f76-4aa9-a7dc-7cde4d7851fb tags:
``` python
# 执行嵌入查询
def get_embeddings(texts: list, input_type: str = "classification"):
    """Request embeddings for ``texts`` from the endpoint configured as ``API_URL``.

    One POST request is sent per call. A failed request (connection problem,
    timeout or an HTTP error status) is reported with ``print`` and retried
    after a one-second pause, for at most seven attempts in total.

    Args:
        texts: The strings to embed in this single request.
        input_type: Value sent as the payload's ``"input_type"`` field.
            Defaults to ``"classification"``, the value this notebook has
            always sent.
            NOTE(review): the token name suggests Cohere's Embed API, whose
            documentation describes dedicated input types for retrieval
            ("search_document" / "search_query") - confirm before switching,
            since stored and query vectors should be produced consistently.

    Returns:
        The decoded JSON body of the response (callers read its
        ``"embeddings"`` entry), or ``None`` if every attempt failed.
    """
    max_attempts = 7  # same number of tries as the previous counter-based loop
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.post(
                API_URL,
                proxies=custom_proxies,
                # requests reads a tuple as (connect timeout, read timeout).
                # NOTE(review): a 6 s read timeout next to a 60 s connect
                # timeout looks swapped - confirm the intended values.
                timeout=(60, 6),
                headers={"Authorization": "Bearer {0}".format(cohere_access_token)},
                json={
                    "model": model_name,
                    "texts": texts,
                    "input_type": input_type,
                    "truncate": "NONE",
                },
            )
            response.raise_for_status()  # turn HTTP error statuses into exceptions
        except requests.RequestException as e:
            print("error", e)
            if attempt == max_attempts:
                return None
            time.sleep(1)  # brief pause before the next attempt
        else:
            return response.json()  # decoded response containing the embeddings
    return None
```
%% Cell type:code id:0b23e16a-c069-4be2-8109-166f5cbb722c tags:
``` python
# Collect the text of every chunk; request_docs[i] is the text of docs[i]
request_docs = [doc.page_content for doc in docs]

# Embed the chunks in batches of model_batch_size, keeping the vectors in
# the same order as request_docs
embedded_docs = []
for i in range(0, len(request_docs), model_batch_size):
    embeddings_response = get_embeddings(request_docs[i:i+model_batch_size])
    if embeddings_response is None:
        # get_embeddings returns None once all of its attempts have failed;
        # stop here with a clear message instead of a TypeError on None.
        raise RuntimeError(
            "Embedding request failed for the batch starting at chunk {0}".format(i)
        )
    embedded_docs.extend(embeddings_response['embeddings'])
```
%% Cell type:code id:fb491aa1-4fdc-423f-ba76-861b68959777 tags:
``` python
# Store every chunk in the Chroma collection together with its embedding.
# embedded_docs[idx] is the vector computed for request_docs[idx].
for idx, embedding in tqdm(enumerate(embedded_docs), total=len(embedded_docs), desc="Processing documents"):
    collection.add(
        # All arguments are passed as one-element lists so that ids,
        # documents, metadatas and embeddings line up one-to-one.
        ids=[collection_ids_prefix + str(idx)],
        documents=[request_docs[idx]],
        metadatas=[{"chunk": idx}],
        embeddings=[embedding],
    )
```
%%%% Output: display_data
%% Cell type:code id:b1f2ff89 tags:
``` python
## Preparation
# Select the document to query; this value is used as the key under "docs" in the config file
embed_object = "oracle-scm-planning"
# The user's question to search for
query_text = "How AI can help SCM planning?"
```
%% Cell type:code id:4b9d9fa0-1c73-4cef-abc4-397458215159 tags:
``` python
# 导入必要的包
import time
import os
import sys
import json
import requests
from chromadb import Client
# Paths are resolved against the process's current working directory
# (os.getcwd()), which is not necessarily the directory holding this notebook.
current_dir = os.getcwd()
# The settings file is expected one level up, at ../setup/config_embed.json
config_file_path = os.path.join(current_dir, "..", "setup", "config_embed.json")
# Load the settings
try:
    with open(config_file_path, "r", encoding="utf-8") as f:
        dict_config = json.load(f)
    # Settings shared by every document
    API_URL = dict_config["API_URL"]
    cohere_access_token = dict_config["cohere_access_token"]
    custom_proxies = dict_config["custom_proxies"]
    list_chroma_dir = dict_config["list_chroma_dir"]
    list_embed_file_path = dict_config["list_embed_file_path"]
    # Settings of the selected document; look its entry up once
    doc_config = dict_config["docs"][embed_object]
    model_name = doc_config["model_name"]
    embed_file_name = doc_config["file_name"]
    split_chunk_size = doc_config["split_chunk_size"]
    split_overlap = doc_config["split_overlap"]
    model_batch_size = doc_config["model_batch_size"]
    collection_ids_prefix = doc_config["collection_ids_prefix"]
    embed_collection_name = doc_config["collection_name"]
except (OSError, ValueError, KeyError, TypeError) as e:
    # OSError: file missing or unreadable; ValueError: invalid JSON or encoding;
    # KeyError / TypeError: a required entry is missing or has the wrong shape.
    # Other exceptions (e.g. a NameError from a cell run out of order) are
    # deliberately left to surface with their own traceback.
    print("配置文件读取失败", e)
    sys.exit("Stop processing")
# Build the Chroma directory and the source-file path from the configured path segments
chroma_dir = os.path.join(current_dir, *list_chroma_dir)
embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)
```
%% Cell type:code id:630b5215-b3ad-43e6-8f48-83b78abf9c10 tags:
``` python
import chromadb
# Initialise the Chroma client on the previously used data directory (chroma_dir)
client = chromadb.PersistentClient(path=chroma_dir)
# Fetch the collection whose name comes from the config (embed_collection_name)
collection = client.get_collection(embed_collection_name)
```
%% Cell type:code id:953b58a8-3c70-4774-ba91-100cfd744337 tags:
``` python
# 执行嵌入查询
def get_embeddings(texts: list, input_type: str = "classification"):
    """Request embeddings for ``texts`` from the endpoint configured as ``API_URL``.

    One POST request is sent per call. A failed request (connection problem,
    timeout or an HTTP error status) is reported with ``print`` and retried
    after a one-second pause, for at most seven attempts in total.

    Args:
        texts: The strings to embed in this single request.
        input_type: Value sent as the payload's ``"input_type"`` field.
            Defaults to ``"classification"``, the value this notebook has
            always sent.
            NOTE(review): the token name suggests Cohere's Embed API, whose
            documentation describes dedicated input types for retrieval
            ("search_document" / "search_query") - confirm before switching,
            since stored and query vectors should be produced consistently.

    Returns:
        The decoded JSON body of the response (callers read its
        ``"embeddings"`` entry), or ``None`` if every attempt failed.
    """
    max_attempts = 7  # same number of tries as the previous counter-based loop
    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.post(
                API_URL,
                proxies=custom_proxies,
                # requests reads a tuple as (connect timeout, read timeout).
                # NOTE(review): a 6 s read timeout next to a 60 s connect
                # timeout looks swapped - confirm the intended values.
                timeout=(60, 6),
                headers={"Authorization": "Bearer {0}".format(cohere_access_token)},
                json={
                    "model": model_name,
                    "texts": texts,
                    "input_type": input_type,
                    "truncate": "NONE",
                },
            )
            response.raise_for_status()  # turn HTTP error statuses into exceptions
        except requests.RequestException as e:
            print("error", e)
            if attempt == max_attempts:
                return None
            time.sleep(1)  # brief pause before the next attempt
        else:
            return response.json()  # decoded response containing the embeddings
    return None
```
%% Cell type:code id:667253d4-1a55-406a-b3d4-51f4b9957aca tags:
``` python
# Embed the query text
query_embedding = get_embeddings([query_text]) # built with the get_embeddings function defined earlier; it returns None when every attempt failed
```
%% Cell type:code id:f3df429f-b4d6-4e66-803b-ae5e9ed95922 tags:
``` python
# get_embeddings returns None when every request attempt failed; fail with a
# clear message instead of a TypeError when indexing None below.
if query_embedding is None:
    raise RuntimeError("Could not obtain an embedding for the query text")

# Query the collection for the 5 closest stored chunks
results = collection.query(
    query_embeddings=query_embedding["embeddings"],
    n_results=5
)

# Show the matches. The result holds one inner list per query embedding, so
# walk the outer lists first and then pair each document with its metadata.
for docs_for_query, metadatas_for_query in zip(results['documents'], results['metadatas']):
    for doc, metadata in zip(docs_for_query, metadatas_for_query):
        print("Document:", doc)
        print("Metadata:", metadata)
```
%% Cell type:code id:01ed14d1-71b5-42b3-82d6-47a053894a74 tags:
``` python
import json
# Dump the matched documents as indented JSON; ensure_ascii=False keeps non-ASCII text unescaped
print(json.dumps(results["documents"], indent=4, ensure_ascii=False))
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment