代码内容转移

This commit is contained in:
2025-08-17 22:41:54 +08:00
commit 0b50022af1
38 changed files with 72624 additions and 0 deletions

194
tools/ingest.py Normal file
View File

@@ -0,0 +1,194 @@
# 该代码用于将本地知识库中的文档导入到ChromaDB中并使用远程嵌入模型进行向量化
import os
from pathlib import Path
import chromadb
# from chromadb.utils import embedding_functions - 不再需要
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
from unstructured.partition.auto import partition
from rich.progress import track
import logging
import requests # 导入requests
import json # 导入json模块
# Configure root logging: timestamped INFO-level messages for the whole ingest run.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Configuration ---
# Resolve paths relative to this script's directory so the tool works from any CWD.
SCRIPT_DIR = Path(__file__).resolve().parent
KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base"   # input documents live here
VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store"       # ChromaDB persistent storage
COLLECTION_NAME = "drone_docs"
# EMBEDDING_MODEL_NAME = "bge-small-zh-v1.5"  # no longer needed; the model is handled server-side
# --- Custom remote embedding function ---
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """Embedding function backed by a remote, OpenAI-API-compatible service.

    Sends the document texts to ``api_url`` as ``{"input": [...]}`` and parses
    the OpenAI-style response (``{"data": [{"embedding": [...]}, ...]}``).
    On any failure it logs the error and returns an empty list, matching the
    original best-effort contract (callers should check for an empty result).
    """

    def __init__(self, api_url: str, timeout: float = 60.0):
        """
        Args:
            api_url: Full URL of the ``/v1/embeddings`` endpoint.
            timeout: Per-request timeout in seconds. Fix: the original issued
                the POST with no timeout, so an unreachable or wedged service
                would hang ingestion indefinitely.
        """
        self._api_url = api_url
        self._timeout = timeout
        logging.info(f"自定义嵌入函数已初始化,将连接到: {self._api_url}")

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed a list of document strings; returns [] on invalid input or API failure."""
        # The remote service only handles text, so require a list of strings.
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            logging.error("此嵌入函数仅支持字符串列表(文档)作为输入。")
            return []
        try:
            # No "model" field: the embedding service selects the model itself.
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                timeout=self._timeout,  # bug fix: bound the wait instead of hanging forever
            )
            response.raise_for_status()  # raise HTTPError on non-2xx status
            # Parse embeddings in the OpenAI API response format.
            data = response.json().get("data", [])
            if not data:
                raise ValueError("API响应中没有找到'data'字段或'data'为空")
            embeddings = [item['embedding'] for item in data]
            return embeddings
        except requests.RequestException as e:
            logging.error(f"调用嵌入API失败: {e}")
            # Best-effort contract: signal failure with an empty list.
            return []
        except (ValueError, KeyError) as e:
            # response is always bound here: these errors can only arise after the POST succeeded.
            logging.error(f"解析API响应失败: {e}")
            logging.error(f"收到的响应内容: {response.text}")
            return []
def _load_plain_text(file_path: Path):
    """Read a .txt/.md file verbatim as a single document."""
    text = file_path.read_text(encoding='utf-8')
    logging.info(f"成功处理文本文件: {file_path.name}")
    return [{"text": text, "metadata": {"source": str(file_path.name)}}]


def _load_json(file_path: Path):
    """Load a .json file; an 'elements' list yields one document per element, else one document."""
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if 'elements' in data and isinstance(data['elements'], list):
        docs = [
            {"text": json.dumps(element, ensure_ascii=False),
             "metadata": {"source": str(file_path.name)}}
            for element in data['elements']
        ]
        logging.info(f"成功处理JSON文件: {file_path.name}, 提取了 {len(data['elements'])} 个元素。")
        return docs
    docs = [{"text": json.dumps(data, ensure_ascii=False),
             "metadata": {"source": str(file_path.name)}}]
    logging.info(f"成功处理JSON文件: {file_path.name} (作为单个文档)")
    return docs


def _load_ndjson(file_path: Path):
    """Load an NDJSON file: one JSON record per line; keep records with a string 'text' field."""
    docs = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                logging.warning(f"跳过无效的JSON行: {line.strip()}")
                continue
            if 'text' in record and isinstance(record['text'], str):
                docs.append({
                    "text": record['text'],
                    "metadata": {"source": str(file_path.name)}
                })
    if docs:
        logging.info(f"成功处理NDJSON文件: {file_path.name}, 提取了 {len(docs)} 个文档。")
    return docs


def _load_with_unstructured(file_path: Path):
    """Fallback for all other file types: partition with unstructured, one document per element."""
    elements = partition(filename=str(file_path))
    docs = [
        {"text": element.text, "metadata": {"source": str(file_path.name)}}
        for element in elements
    ]
    logging.info(f"成功处理文件: {file_path.name} (使用unstructured)")
    return docs


def get_documents(directory: Path):
    """Recursively load every file under *directory* and split it into document dicts.

    Returns a list of ``{"text": str, "metadata": {"source": filename}}`` records.
    Hidden files (dotfiles) are skipped. A failure on one file is logged and does
    not abort the scan of the remaining files.
    """
    documents = []
    logging.info(f"'{directory}' 加载文档...")
    for file_path in directory.rglob("*"):
        if not file_path.is_file() or file_path.name.startswith('.'):
            continue
        try:
            # Generalized: match extensions case-insensitively so e.g. ".TXT"
            # is read directly instead of falling through to unstructured.
            suffix = file_path.suffix.lower()
            if suffix in ('.txt', '.md'):
                documents.extend(_load_plain_text(file_path))
            elif suffix == '.json':
                documents.extend(_load_json(file_path))
            elif suffix == '.ndjson':
                documents.extend(_load_ndjson(file_path))
            else:
                documents.extend(_load_with_unstructured(file_path))
        except Exception as e:
            logging.error(f"处理文件 {file_path.name} 失败: {e}")
    return documents
def main():
    """Entry point: load the local knowledge base and ingest it into ChromaDB
    using the remote embedding service for vectorization."""
    # Guard: on a fresh checkout, create the knowledge-base directory and bail
    # out so the user can drop documents into it first.
    if not KNOWLEDGE_BASE_DIR.exists():
        KNOWLEDGE_BASE_DIR.mkdir(parents=True)
        logging.warning(f"知识库目录不存在,已自动创建: {KNOWLEDGE_BASE_DIR}")
        logging.warning("请向该目录中添加您的知识文件(如 .txt, .pdf, .md)。")
        return

    # Step 1: load and split the documents.
    docs = get_documents(KNOWLEDGE_BASE_DIR)
    if not docs:
        logging.warning("在知识库中未找到可处理的文档。")
        return

    # Step 2: wire up the remote embedding endpoint and a persistent Chroma client.
    orin_ip = os.getenv("ORIN_IP", "172.101.1.117")
    embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings"
    logging.info(f"正在初始化远程嵌入函数,目标服务地址: {embedding_api_url}")
    embed_fn = RemoteEmbeddingFunction(api_url=embedding_api_url)
    client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR))

    # Step 3: create the collection if needed (or reuse the existing one).
    logging.info(f"正在访问ChromaDB集合: {COLLECTION_NAME}")
    collection = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embed_fn
    )

    # Step 4: embed and persist the documents.
    logging.info(f"开始将 {len(docs)} 个文档块入库...")
    # NOTE: no dedup check here — documents are re-added on every run; a
    # production pipeline would need smarter upsert logic.
    texts = [entry['text'] for entry in docs]
    metas = [entry['metadata'] for entry in docs]
    ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{idx}" for idx in range(len(texts))]
    try:
        # collection.add() invokes the embedding function internally.
        collection.add(
            documents=texts,
            metadatas=metas,
            ids=ids
        )
        logging.info("所有文档块已成功入库!")
    except Exception as e:
        logging.error(f"向ChromaDB添加文档时出错: {e}")

    # Sanity check: report the collection size after ingestion.
    total = collection.count()
    logging.info(f"数据库中现在有 {total} 个条目。")
    print("\n✅ 数据入库完成!")
    print(f"知识库位于: {KNOWLEDGE_BASE_DIR}")
    print(f"向量数据库位于: {VECTOR_STORE_DIR}")


if __name__ == "__main__":
    main()