In [None]:
# ## Setup
# Key of the document to embed. It selects the per-document settings stored
# under dict_config["docs"][embed_object] in config_embed.json.
embed_object: str = "labors-law"

In [None]:
# 导入必要的包
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb
import time
from tqdm.notebook import tqdm
from langchain.docstore.document import Document
import os
import sys
import json
import requests
from chromadb import Client

# Working directory of the kernel (Jupyter usually starts it in the notebook's
# folder). Note this is os.getcwd(), not the location of this file; every path
# below is resolved relative to it.
current_dir = os.getcwd()

# The settings file lives in ../setup/ relative to the working directory.
config_file_path = os.path.join(current_dir, "..", "setup", "config_embed.json")

# Read the shared settings plus the per-document settings selected by
# `embed_object`. Any failure (missing file, invalid JSON, missing key, ...)
# is printed and the run is stopped.
try:
    with open(config_file_path, "r", encoding="utf-8") as config_file:
        dict_config = json.load(config_file)

    # Shared settings.
    # NOTE(review): API_URL, cohere_access_token and custom_proxies are not
    # referenced by any other visible cell of this notebook.
    API_URL = dict_config["API_URL"]
    cohere_access_token = dict_config["cohere_access_token"]
    custom_proxies = dict_config["custom_proxies"]
    list_chroma_dir = dict_config["list_chroma_dir"]
    list_embed_file_path = dict_config["list_embed_file_path"]

    # Settings for the document being embedded.
    # NOTE(review): `model_name` is read here, but the model-loading cell uses
    # a hard-coded checkpoint. Confirm which one is intended.
    doc_config = dict_config["docs"][embed_object]
    model_name = doc_config["model_name"]
    embed_file_name = doc_config["file_name"]
    split_chunk_size = doc_config["split_chunk_size"]
    split_overlap = doc_config["split_overlap"]
    model_batch_size = doc_config["model_batch_size"]
    collection_ids_prefix = doc_config["collection_ids_prefix"]
    embed_collection_name = doc_config["collection_name"]
except Exception as e:
    print("配置文件读取失败", e)
    sys.exit("Stop processing")

# list_chroma_dir / list_embed_file_path are lists of path components
# (presumably starting with ".." to step out of the working directory),
# joined onto the working directory.
chroma_dir = os.path.join(current_dir, *list_chroma_dir)
embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)

In [None]:
def load_and_split_pdf(file_path, chunk_size=500, chunk_overlap=100, newline_replacement=""):
    """Load a PDF, merge all pages into one string and split it into chunks.

    Every line break, together with the whitespace around it, is replaced by
    ``newline_replacement``. The default ``""`` deletes line breaks outright.
    That keeps CJK sentences that wrap across PDF lines intact. Pass ``" "``
    for space-separated languages so words on adjacent lines are not glued
    together.

    Args:
        file_path: Path of the PDF to load with PyPDFLoader.
        chunk_size: ``chunk_size`` passed to RecursiveCharacterTextSplitter.
        chunk_overlap: ``chunk_overlap`` passed to RecursiveCharacterTextSplitter.
        newline_replacement: Literal text substituted for each line-break run.

    Returns:
        A list of Document objects, one per chunk, with only ``page_content`` set.
    """
    import re

    pages = PyPDFLoader(file_path).load()

    # Merge the pages; page boundaries become a single space.
    full_text = " ".join(page.page_content for page in pages)

    # TODO: page-number markers such as "\n—6—\n" are not stripped. If they
    # pollute the chunks, add a re.sub(r"\n—\d+—\n", "\n", ...) pass here,
    # before the line breaks are collapsed below.

    # Collapse each line break plus surrounding whitespace. A function
    # replacement inserts newline_replacement literally (no backslash escapes).
    full_text = re.sub(r"\s*\n\s*", lambda _match: newline_replacement, full_text)

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = splitter.split_text(full_text)

    # Wrap each chunk in a Document so callers get LangChain-style objects.
    return [Document(page_content=chunk) for chunk in chunks]

In [None]:
# Chunk the configured PDF using the per-document split settings.
docs = load_and_split_pdf(
    embed_file_path,
    chunk_size=split_chunk_size,
    chunk_overlap=split_overlap,
)

In [None]:
# Open (or create) the persistent Chroma store in chroma_dir.
client = chromadb.PersistentClient(path=chroma_dir)

# Rebuild the collection from scratch by dropping any previous version first.
# On the very first run the collection does not exist yet, so a failure here
# is reported and otherwise ignored.
try:
    client.delete_collection(embed_collection_name)
except Exception as e:
    print(f"Collection '{embed_collection_name}' was not deleted (it may not exist yet): {e}")

# Use create_collection rather than get_or_create_collection. If the delete
# above failed for any reason other than "does not exist", this raises
# instead of silently reusing the stale collection and mixing old and new data.
collection = client.create_collection(name=embed_collection_name)

In [None]:
from IPython.display import clear_output

In [None]:
# Load the embedding model through sentence-transformers. It is fetched from
# the Hugging Face Hub, or from the local cache if already downloaded.
# trust_remote_code=True allows the repository's own modeling code to run,
# which jina-embeddings-v3 relies on; only enable it for repositories you trust.
# NOTE(review): the checkpoint is hard-coded, while the config cell reads a
# per-document `model_name` that is never used. Confirm which is intended.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "jinaai/jina-embeddings-v3",
    trust_remote_code=True,
)

In [None]:
# Plain chunk texts in the same order as `docs`. Position i is later used as
# the chunk index / Chroma id suffix.
request_docs = [doc.page_content for doc in docs]

# A non-positive batch size would otherwise fail inside range() (0) or
# silently produce no embeddings at all (negative).
if model_batch_size < 1:
    raise ValueError(f"model_batch_size must be >= 1, got {model_batch_size}")

# Encode in batches of `model_batch_size`, with a progress bar.
# `embedded_docs` ends up with one embedding vector per chunk, aligned with
# `request_docs`.
embedded_docs = []
for start in tqdm(range(0, len(request_docs), model_batch_size), desc="Embedding batches"):
    batch_embeddings = model.encode(request_docs[start:start + model_batch_size])
    embedded_docs.extend(batch_embeddings)

assert len(embedded_docs) == len(request_docs), "embedding count does not match chunk count"

In [None]:
# Store each chunk's text, embedding and chunk index in Chroma.
# Rows are sent in batches rather than one `add` call per chunk, which pays
# the per-call overhead for every single chunk.
# NOTE(review): Chroma caps the number of rows per call, and the cap depends
# on the version/backend. If this batch size is rejected, check
# `client.get_max_batch_size()`.
CHROMA_ADD_BATCH_SIZE = 1000

assert len(embedded_docs) == len(request_docs), "embeddings and texts are out of sync"

for start in tqdm(range(0, len(embedded_docs), CHROMA_ADD_BATCH_SIZE), desc="Processing documents"):
    stop = min(start + CHROMA_ADD_BATCH_SIZE, len(embedded_docs))
    chunk_indices = range(start, stop)
    collection.add(
        ids=[collection_ids_prefix + str(idx) for idx in chunk_indices],
        documents=request_docs[start:stop],
        metadatas=[{"chunk": idx} for idx in chunk_indices],
        embeddings=embedded_docs[start:stop],
    )

In [None]:
# 释放内存
del model
import gc
gc.collect()

In [None]:
# 导入 umap 和 matplotlib
import umap
import matplotlib.pyplot as plt
import numpy as np
import mplcursors

# Project the chunk embeddings to 2-D with UMAP for visual inspection.
# A fixed random_state makes the layout reproducible across runs. UMAP then
# gives up parallelism: slower, but deterministic.
UMAP_RANDOM_STATE = 42
reducer = umap.UMAP(random_state=UMAP_RANDOM_STATE)

# `embedded_docs` is a list of per-chunk vectors; stack it into one 2-D array.
embedding = reducer.fit_transform(np.asarray(embedded_docs))

In [None]:
# 绘制二维数据点
scatter = plt.scatter(embedding[:, 0], embedding[:, 1], s=10)

# 添加交互式标签
cursor = mplcursors.cursor(scatter, hover=True)
cursor.connect("add", lambda sel: sel.annotation.set_text(request_docs[sel.index]))

plt.gca().set_aspect('equal', 'datalim')
plt.title('UMAP Projection', fontsize=24)

# 显示图形
plt.show()