# NOTE(review): the lines below are web-UI chrome that was scraped together
# with the source (file listing, size, "Raw Permalink Blame History", and a
# Unicode-confusables warning). They are not Python; kept as comments so the
# module can actually be parsed and imported.
# Files
# DronePlanning/tools/ingest.py
# 2025-08-25 16:43:37 +08:00
#
# 194 lines
# 8.4 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other
# characters. If you think that this is intentional, you can safely ignore
# this warning. Use the Escape button to reveal them.
# 该代码用于将本地知识库中的文档导入到ChromaDB中并使用远程嵌入模型进行向量化
import os
from pathlib import Path
import chromadb
# from chromadb.utils import embedding_functions - 不再需要
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
from unstructured.partition.auto import partition
from rich.progress import track
import logging
import requests # 导入requests
import json # 导入json模块
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 配置 ---
# 获取脚本所在目录,确保路径的正确性
SCRIPT_DIR = Path(__file__).resolve().parent
KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base"
VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store"
COLLECTION_NAME = "drone_docs"
# EMBEDDING_MODEL_NAME = "bge-small-zh-v1.5" # 不再需要,模型名在函数内部处理
# --- 自定义远程嵌入函数 ---
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """Embedding function backed by a remote, OpenAI-compatible embedding API.

    Posts the documents to ``api_url`` as ``{"input": [...]}`` and parses the
    OpenAI-style response (``{"data": [{"embedding": [...]}, ...]}``).
    """

    def __init__(self, api_url: str, timeout: float = 60.0):
        """
        Args:
            api_url: Full URL of the ``/v1/embeddings`` endpoint.
            timeout: Per-request timeout in seconds. Added so a hung remote
                service cannot block ingestion forever; the default keeps
                existing call sites working unchanged.
        """
        self._api_url = api_url
        self._timeout = timeout
        logging.info(f"自定义嵌入函数已初始化,将连接到: {self._api_url}")

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed a list of document strings.

        Returns the list of embedding vectors, or ``[]`` on any failure
        (invalid input, network error, or unparseable response). Errors are
        logged rather than raised to preserve the original best-effort
        contract.
        """
        # The remote service only handles text, so require a list of strings.
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            logging.error("此嵌入函数仅支持字符串列表(文档)作为输入。")
            return []
        response = None
        try:
            # "model" is intentionally omitted from the payload: the embedding
            # service does not need it.
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                # Bug fix: the original call had no timeout and could hang
                # indefinitely if the remote service stalled.
                timeout=self._timeout,
            )
            response.raise_for_status()  # raise HTTPError on non-2xx status
            # Parse the embeddings in the OpenAI API response format.
            data = response.json().get("data", [])
            if not data:
                raise ValueError("API响应中没有找到'data'字段或'data'为空")
            return [item['embedding'] for item in data]
        except requests.RequestException as e:
            logging.error(f"调用嵌入API失败: {e}")
            # Return an empty list; callers should treat this as a hard error.
            return []
        except (ValueError, KeyError) as e:
            logging.error(f"解析API响应失败: {e}")
            # Guard: only log the body when the request actually completed.
            if response is not None:
                logging.error(f"收到的响应内容: {response.text}")
            return []
def _source_meta(file_path: Path) -> dict:
    """Metadata dict attached to every chunk extracted from *file_path*."""
    return {"source": str(file_path.name)}


def _load_json_file(file_path: Path, documents: list) -> None:
    """Load a .json file into *documents*.

    If the top-level object has an ``elements`` list, each element becomes one
    document (serialized back to JSON); otherwise the whole object is stored
    as a single document.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        if 'elements' in data and isinstance(data['elements'], list):
            for element in data['elements']:
                documents.append({
                    "text": json.dumps(element, ensure_ascii=False),
                    "metadata": _source_meta(file_path)
                })
            logging.info(f"成功处理JSON文件: {file_path.name}, 提取了 {len(data['elements'])} 个元素。")
        else:
            documents.append({
                "text": json.dumps(data, ensure_ascii=False),
                "metadata": _source_meta(file_path)
            })
            logging.info(f"成功处理JSON文件: {file_path.name} (作为单个文档)")


def _load_ndjson_file(file_path: Path, documents: list) -> None:
    """Load an NDJSON file (one JSON record per line) into *documents*.

    Only records carrying a string ``text`` field are kept; lines that are
    not valid JSON are logged and skipped.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        count = 0
        for line in f:
            try:
                record = json.loads(line)
                if 'text' in record and isinstance(record['text'], str):
                    documents.append({
                        "text": record['text'],
                        "metadata": _source_meta(file_path)
                    })
                    count += 1
            except json.JSONDecodeError:
                logging.warning(f"跳过无效的JSON行: {line.strip()}")
        if count > 0:
            logging.info(f"成功处理NDJSON文件: {file_path.name}, 提取了 {count} 个文档。")


def get_documents(directory: Path):
    """Recursively load every file under *directory* into ingestion chunks.

    Returns a list of ``{"text": str, "metadata": {"source": filename}}``
    dicts. Dispatch is by suffix: ``.txt``/``.md`` are read verbatim,
    ``.json``/``.ndjson`` get dedicated parsing, and every other file type is
    handed to ``unstructured.partition``. Hidden files (dot-prefixed names)
    are skipped; per-file failures are logged and do not abort the walk.
    """
    documents = []
    logging.info(f"'{directory}' 加载文档...")
    for file_path in directory.rglob("*"):
        # Skip directories and hidden files (".DS_Store" and friends).
        if not file_path.is_file() or file_path.name.startswith('.'):
            continue
        try:
            if file_path.suffix in ['.txt', '.md']:
                # Plain text: one document per file.
                documents.append({
                    "text": file_path.read_text(encoding='utf-8'),
                    "metadata": _source_meta(file_path)
                })
                logging.info(f"成功处理文本文件: {file_path.name}")
            elif file_path.suffix == '.json':
                _load_json_file(file_path, documents)
            elif file_path.suffix == '.ndjson':
                _load_ndjson_file(file_path, documents)
            else:
                # Fall back to unstructured for all other file types.
                for element in partition(filename=str(file_path)):
                    documents.append({
                        "text": element.text,
                        "metadata": _source_meta(file_path)
                    })
                logging.info(f"成功处理文件: {file_path.name} (使用unstructured)")
        except Exception as e:
            # Best-effort: one bad file must not abort the whole ingestion.
            logging.error(f"处理文件 {file_path.name} 失败: {e}")
    return documents
def main():
    """Run the ingestion pipeline: load documents from the knowledge base,
    embed them via the remote service, and persist them into ChromaDB.

    Exits early (with a logged message) when the knowledge base is missing or
    empty, or when the database insert fails.
    """
    if not KNOWLEDGE_BASE_DIR.exists():
        # First run: create the directory and tell the user to populate it.
        KNOWLEDGE_BASE_DIR.mkdir(parents=True)
        logging.warning(f"知识库目录不存在,已自动创建: {KNOWLEDGE_BASE_DIR}")
        logging.warning("请向该目录中添加您的知识文件(如 .txt, .pdf, .md)。")
        return
    # 1. Load and split the documents.
    docs_to_ingest = get_documents(KNOWLEDGE_BASE_DIR)
    if not docs_to_ingest:
        logging.warning("在知识库中未找到可处理的文档。")
        return
    # 2. Initialize the ChromaDB client and the remote embedding function.
    #    The embedding service address comes from ORIN_IP (default localhost).
    orin_ip = os.getenv("ORIN_IP", "localhost")
    embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings"
    logging.info(f"正在初始化远程嵌入函数,目标服务地址: {embedding_api_url}")
    embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
    client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR))
    # 3. Create or fetch the collection.
    logging.info(f"正在访问ChromaDB集合: {COLLECTION_NAME}")
    collection = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_func
    )
    # 4. Embed the documents and store them in the database.
    logging.info(f"开始将 {len(docs_to_ingest)} 个文档块入库...")
    # NOTE: ids are positional (doc_<dir>_<index>), so re-running reuses the
    # same ids; production needs content-derived ids for real de-duplication.
    doc_texts = [doc['text'] for doc in docs_to_ingest]
    metadatas = [doc['metadata'] for doc in docs_to_ingest]
    ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{i}" for i in range(len(doc_texts))]
    try:
        # collection.add() drives the embedding function automatically.
        collection.add(
            documents=doc_texts,
            metadatas=metadatas,
            ids=ids
        )
    except Exception as e:
        logging.error(f"向ChromaDB添加文档时出错: {e}")
        # Bug fix: the original fell through after a failed insert and still
        # printed the success banner and entry count. Abort instead.
        return
    logging.info("所有文档块已成功入库!")
    # Sanity check: report how many entries the collection now holds.
    count = collection.count()
    logging.info(f"数据库中现在有 {count} 个条目。")
    print("\n✅ 数据入库完成!")
    print(f"知识库位于: {KNOWLEDGE_BASE_DIR}")
    print(f"向量数据库位于: {VECTOR_STORE_DIR}")
# Script entry point: run the ingestion pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()