# 该代码用于将本地知识库中的文档导入到ChromaDB中,并使用远程嵌入模型进行向量化 import os from pathlib import Path import chromadb # from chromadb.utils import embedding_functions - 不再需要 from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable from unstructured.partition.auto import partition from rich.progress import track import logging import requests # 导入requests import json # 导入json模块 # 配置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # --- 配置 --- # 获取脚本所在目录,确保路径的正确性 SCRIPT_DIR = Path(__file__).resolve().parent KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base" VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store" COLLECTION_NAME = "drone_docs" # EMBEDDING_MODEL_NAME = "bge-small-zh-v1.5" # 不再需要,模型名在函数内部处理 # --- 自定义远程嵌入函数 --- class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]): """ 一个使用远程、兼容OpenAI API的嵌入服务的嵌入函数。 """ def __init__(self, api_url: str): self._api_url = api_url logging.info(f"自定义嵌入函数已初始化,将连接到: {self._api_url}") def __call__(self, input: Embeddable) -> Embeddings: """ 对输入的文档进行嵌入。 """ # 我们的服务只能处理文本,所以检查输入是否为字符串列表 if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input): logging.error("此嵌入函数仅支持字符串列表(文档)作为输入。") return [] try: # 移除 "model" 参数,因为embedding服务可能不需要它 response = requests.post( self._api_url, json={"input": input}, headers={"Content-Type": "application/json"} ) response.raise_for_status() # 如果请求失败则抛出HTTPError # 按照OpenAI API的格式解析返回的嵌入向量 data = response.json().get("data", []) if not data: raise ValueError("API响应中没有找到'data'字段或'data'为空") embeddings = [item['embedding'] for item in data] return embeddings except requests.RequestException as e: logging.error(f"调用嵌入API失败: {e}") # 返回一个空列表或根据需要处理错误 return [] except (ValueError, KeyError) as e: logging.error(f"解析API响应失败: {e}") logging.error(f"收到的响应内容: {response.text}") return [] def get_documents(directory: Path): """从知识库目录加载所有文档并进行切分""" documents = [] logging.info(f"从 '{directory}' 加载文档...") for file_path in directory.rglob("*"): if file_path.is_file() and not file_path.name.startswith('.'): try: # 对简单文本文件直接读取 if file_path.suffix in ['.txt', '.md']: text = file_path.read_text(encoding='utf-8') documents.append({ "text": text, "metadata": {"source": str(file_path.name)} }) logging.info(f"成功处理文本文件: {file_path.name}") # 特别处理常规的JSON文件 elif file_path.suffix == '.json': with open(file_path, 'r', encoding='utf-8') as f: data = json.load(f) if 'elements' in data and isinstance(data['elements'], list): for element in data['elements']: documents.append({ "text": json.dumps(element, ensure_ascii=False), "metadata": {"source": str(file_path.name)} }) logging.info(f"成功处理JSON文件: {file_path.name}, 提取了 {len(data['elements'])} 个元素。") else: documents.append({ "text": json.dumps(data, ensure_ascii=False), "metadata": {"source": str(file_path.name)} }) logging.info(f"成功处理JSON文件: {file_path.name} (作为单个文档)") # 新增:专门处理我们生成的 NDJSON 文件 elif file_path.suffix == '.ndjson': with open(file_path, 'r', encoding='utf-8') as f: count = 0 for line in f: try: record = json.loads(line) if 'text' in record and isinstance(record['text'], str): documents.append({ "text": record['text'], "metadata": {"source": str(file_path.name)} }) count += 1 except json.JSONDecodeError: logging.warning(f"跳过无效的JSON行: {line.strip()}") if count > 0: logging.info(f"成功处理NDJSON文件: {file_path.name}, 提取了 {count} 个文档。") # 对其他所有文件类型,使用unstructured else: elements = partition(filename=str(file_path)) for element in elements: documents.append({ "text": element.text, "metadata": {"source": str(file_path.name)} }) logging.info(f"成功处理文件: {file_path.name} (使用unstructured)") except Exception as e: logging.error(f"处理文件 {file_path.name} 失败: {e}") return documents def main(): """主函数,执行文档入库流程""" if not KNOWLEDGE_BASE_DIR.exists(): KNOWLEDGE_BASE_DIR.mkdir(parents=True) logging.warning(f"知识库目录不存在,已自动创建: {KNOWLEDGE_BASE_DIR}") logging.warning("请向该目录中添加您的知识文件(如 .txt, .pdf, .md)。") return # 1. 加载并切分文档 docs_to_ingest = get_documents(KNOWLEDGE_BASE_DIR) if not docs_to_ingest: logging.warning("在知识库中未找到可处理的文档。") return # 2. 初始化ChromaDB客户端和远程嵌入函数 orin_ip = os.getenv("ORIN_IP", "172.101.1.117") embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings" logging.info(f"正在初始化远程嵌入函数,目标服务地址: {embedding_api_url}") embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url) client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR)) # 3. 创建或获取集合 logging.info(f"正在访问ChromaDB集合: {COLLECTION_NAME}") collection = client.get_or_create_collection( name=COLLECTION_NAME, embedding_function=embedding_func ) # 4. 将文档向量化并存入数据库 logging.info(f"开始将 {len(docs_to_ingest)} 个文档块入库...") # 为了避免重复添加,可以先检查 # (这里为了简单,我们每次都重新添加,生产环境需要更复杂的逻辑) doc_texts = [doc['text'] for doc in docs_to_ingest] metadatas = [doc['metadata'] for doc in docs_to_ingest] ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{i}" for i in range(len(doc_texts))] try: # ChromaDB的add方法会自动处理嵌入 collection.add( documents=doc_texts, metadatas=metadatas, ids=ids ) logging.info("所有文档块已成功入库!") except Exception as e: logging.error(f"向ChromaDB添加文档时出错: {e}") # 验证一下 count = collection.count() logging.info(f"数据库中现在有 {count} 个条目。") print("\n✅ 数据入库完成!") print(f"知识库位于: {KNOWLEDGE_BASE_DIR}") print(f"向量数据库位于: {VECTOR_STORE_DIR}") if __name__ == "__main__": main()