{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "intro-overview",
   "metadata": {},
   "source": [
    "# Retrieval-augmented question answering\n",
    "\n",
    "Embeds a user question, retrieves the closest chunks from a persisted Chroma collection, and passes them as context to an instruction-tuned language model that writes the answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1f2ff89",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Parameters\n",
    "# Key under \"docs\" in config_embed.json that selects the embedded document set\n",
    "embed_object = \"labors-law\"\n",
    "# User question to answer\n",
    "query_text = \"关于工资福利,有什么规定?\"\n",
    "# Sentence-transformers model used to embed the query\n",
    "# NOTE(review): presumably this must match the model used when the collection was built - confirm\n",
    "EMBED_MODEL_NAME = \"jinaai/jina-embeddings-v3\"\n",
    "# Causal language model used to generate the answer\n",
    "QA_MODEL_NAME = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "# Number of nearest chunks to retrieve from the collection\n",
    "N_RESULTS = 5\n",
    "# Upper bound on the number of tokens generated for the answer\n",
    "MAX_NEW_TOKENS = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b9d9fa0-1c73-4cef-abc4-397458215159",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: standard library first, then third-party\n",
    "import gc\n",
    "import json\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "import chromadb\n",
    "import torch\n",
    "from chromadb import Client\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Working directory of the kernel; every path below is resolved relative to it\n",
    "current_dir = os.getcwd()\n",
    "\n",
    "# Location of the settings file\n",
    "config_file_path = os.path.join(current_dir, \"..\", \"setup\", \"config_embed.json\")\n",
    "\n",
    "# Read the settings; any failure is reported and stops the notebook\n",
    "try:\n",
    "    with open(config_file_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        dict_config = json.load(f)\n",
    "    # General settings\n",
    "    API_URL = dict_config[\"API_URL\"]\n",
    "    cohere_access_token = dict_config[\"cohere_access_token\"]\n",
    "    custom_proxies = dict_config[\"custom_proxies\"]\n",
    "    list_chroma_dir = dict_config[\"list_chroma_dir\"]\n",
    "    list_embed_file_path = dict_config[\"list_embed_file_path\"]\n",
    "    # Settings of the document set selected by embed_object\n",
    "    model_name = dict_config[\"docs\"][embed_object][\"model_name\"]\n",
    "    embed_file_name = dict_config[\"docs\"][embed_object][\"file_name\"]\n",
    "    split_chunk_size = dict_config[\"docs\"][embed_object][\"split_chunk_size\"]\n",
    "    split_overlap = dict_config[\"docs\"][embed_object][\"split_overlap\"]\n",
    "    model_batch_size = dict_config[\"docs\"][embed_object][\"model_batch_size\"]\n",
    "    collection_ids_prefix = dict_config[\"docs\"][embed_object][\"collection_ids_prefix\"]\n",
    "    embed_collection_name = dict_config[\"docs\"][embed_object][\"collection_name\"]\n",
    "except Exception as e:\n",
    "    print(\"配置文件读取失败\", e)\n",
    "    sys.exit(\"Stop processing\")\n",
    "\n",
    "# Build the storage and source-file paths from the configured path components\n",
    "chroma_dir = os.path.join(current_dir, *list_chroma_dir)\n",
    "embed_file_path = os.path.join(current_dir, *list_embed_file_path, embed_file_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-retrieve",
   "metadata": {},
   "source": [
    "## Retrieve matching chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "630b5215-b3ad-43e6-8f48-83b78abf9c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the persistent Chroma client on the configured storage directory\n",
    "client = chromadb.PersistentClient(path=chroma_dir)\n",
    "\n",
    "# Fetch the collection whose name comes from the config\n",
    "collection = client.get_collection(embed_collection_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fc38844",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Hugging Face embedding model\n",
    "# trust_remote_code=True allows code shipped in the model repository to run locally\n",
    "embed_model = SentenceTransformer(EMBED_MODEL_NAME, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667253d4-1a55-406a-b3d4-51f4b9957aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the question; encode() receives a one-element list\n",
    "query_embedding = embed_model.encode([query_text])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3df429f-b4d6-4e66-803b-ae5e9ed95922",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the collection for the N_RESULTS closest chunks\n",
    "results = collection.query(\n",
    "    # Plain nested lists avoid relying on NumPy-array support in the installed Chroma version\n",
    "    query_embeddings=query_embedding.tolist(),\n",
    "    n_results=N_RESULTS\n",
    ")\n",
    "\n",
    "# Chroma returns one inner list per query embedding; one query was sent, so read index 0\n",
    "for doc, metadata in zip(results[\"documents\"][0], results[\"metadatas\"][0]):\n",
    "    print(\"Document:\", doc)\n",
    "    print(\"Metadata:\", metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01ed14d1-71b5-42b3-82d6-47a053894a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretty-print the retrieved documents without escaping non-ASCII characters\n",
    "print(json.dumps(results[\"documents\"], indent=4, ensure_ascii=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aa24e59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Release the embedding model before the QA model is loaded\n",
    "del embed_model\n",
    "gc.collect()\n",
    "# Also release cached GPU memory when a CUDA device is available\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "section-generate",
   "metadata": {},
   "source": [
    "## Generate the answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b9a1869",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the QA model and its tokenizer\n",
    "qa_model = AutoModelForCausalLM.from_pretrained(\n",
    "    QA_MODEL_NAME,\n",
    "    torch_dtype=\"auto\",\n",
    "    device_map=\"auto\"\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "902860e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the chat prompt\n",
    "prompt = query_text\n",
    "# Chunks retrieved for the single query that was sent\n",
    "retrieved_docs = results[\"documents\"][0]\n",
    "# Every retrieved chunk gets its own \" - \" bullet, including the first one\n",
    "context_text = \"在回答的过程中,必须根据以下提示回答:\" + \"\".join(f\"\\n - {doc}\" for doc in retrieved_docs)\n",
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\"},\n",
    "    {\"role\": \"system\", \"content\": context_text},\n",
    "    {\"role\": \"user\", \"content\": prompt}\n",
    "]\n",
    "# Render the messages to a prompt string that ends with the generation prompt\n",
    "text = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    tokenize=False,\n",
    "    add_generation_prompt=True\n",
    ")\n",
    "# Tokenize and move the tensors to the device the model is on\n",
    "model_inputs = tokenizer([text], return_tensors=\"pt\").to(qa_model.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4aa0f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate the answer\n",
    "generated_ids = qa_model.generate(\n",
    "    **model_inputs,\n",
    "    max_new_tokens=MAX_NEW_TOKENS\n",
    ")\n",
    "# Slice off the leading prompt tokens so only the newly generated part is decoded\n",
    "answer_ids = [\n",
    "    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)\n",
    "]\n",
    "\n",
    "response = tokenizer.batch_decode(answer_ids, skip_special_tokens=True)[0]\n",
    "\n",
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}