代码内容转移
This commit is contained in:
194
tools/ingest.py
Normal file
194
tools/ingest.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# 该代码用于将本地知识库中的文档导入到ChromaDB中,并使用远程嵌入模型进行向量化
|
||||
import os
|
||||
from pathlib import Path
|
||||
import chromadb
|
||||
# from chromadb.utils import embedding_functions - 不再需要
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
|
||||
from unstructured.partition.auto import partition
|
||||
from rich.progress import track
|
||||
import logging
|
||||
import requests # 导入requests
|
||||
import json # 导入json模块
|
||||
|
||||
# Configure the root logger once for the whole script: INFO level, with a
# timestamp and the level name in front of every message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Configuration ---
# Resolve every path relative to this script's own directory so the script
# behaves the same regardless of the current working directory.
SCRIPT_DIR = Path(__file__).resolve().parent
# Directory that is scanned for source documents.
KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base"
# Directory handed to ChromaDB's persistent client.
VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store"
# Name of the ChromaDB collection the documents are stored in.
COLLECTION_NAME = "drone_docs"
# NOTE: no embedding model name is configured here; the embedding request
# built in this file does not send a "model" field.
|
||||
|
||||
# --- 自定义远程嵌入函数 ---
|
||||
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """Embedding function backed by a remote, OpenAI-API-compatible service.

    Each call POSTs ``{"input": [...]}`` as JSON to ``api_url`` and reads the
    vectors from the ``data[*].embedding`` entries of the JSON response. No
    ``model`` field is sent.

    Failures are logged and reported to the caller as an empty list; this
    class never raises for network or parsing problems.
    """

    def __init__(self, api_url: str, timeout: float = 300.0):
        """Remember the endpoint to call.

        Args:
            api_url: Full URL of the embeddings endpoint.
            timeout: Seconds to wait for the service before giving up. Without
                a timeout ``requests`` waits indefinitely, so a stalled
                service would hang the whole ingestion run.
        """
        self._api_url = api_url
        self._timeout = timeout
        logging.info(f"自定义嵌入函数已初始化,将连接到: {self._api_url}")

    def __call__(self, input: Embeddable) -> Embeddings:
        """Embed a list of text documents.

        Args:
            input: The documents to embed. Only a list of ``str`` is
                supported. (The parameter name shadows the builtin, but it is
                kept because callers may pass it by keyword.)

        Returns:
            One embedding per input document, in input order, or ``[]`` when
            the input is unsupported, the request fails, or the response
            cannot be interpreted.
        """
        # The remote service only handles text, so reject anything else.
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            logging.error("此嵌入函数仅支持字符串列表(文档)作为输入。")
            return []
        if not input:
            # Nothing to embed; avoid a pointless network round trip.
            return []

        try:
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"},
                timeout=self._timeout,
            )
            response.raise_for_status()  # turn 4xx/5xx answers into HTTPError
        except requests.RequestException as e:
            logging.error(f"调用嵌入API失败: {e}")
            return []

        try:
            # Parse the OpenAI-style payload: {"data": [{"embedding": [...], "index": 0}, ...]}
            data = response.json().get("data", [])
            if not data:
                raise ValueError("API响应中没有找到'data'字段或'data'为空")

            # When every item carries an integer "index", use it to restore
            # input order instead of trusting the order of the response.
            if all(isinstance(item.get("index"), int) for item in data):
                data = sorted(data, key=lambda item: item["index"])

            embeddings = [item['embedding'] for item in data]
            # A short or long answer would silently pair vectors with the
            # wrong documents, so treat it as a parse failure.
            if len(embeddings) != len(input):
                raise ValueError(
                    f"API返回的嵌入数量({len(embeddings)})与输入文档数量({len(input)})不一致"
                )
            return embeddings
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # TypeError/AttributeError cover payloads that are not shaped like
            # a dict holding a list of dicts.
            logging.error(f"解析API响应失败: {e}")
            logging.error(f"收到的响应内容: {response.text}")
            return []
|
||||
|
||||
|
||||
def _json_texts(file_path: Path) -> list:
    """Return one JSON string per item of a top-level "elements" list, else the whole file as one string."""
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Only a JSON object can carry an "elements" key; checking the type first
    # keeps top-level lists, strings and numbers from raising TypeError.
    if isinstance(data, dict) and isinstance(data.get('elements'), list):
        texts = [json.dumps(element, ensure_ascii=False) for element in data['elements']]
        logging.info(f"成功处理JSON文件: {file_path.name}, 提取了 {len(texts)} 个元素。")
        return texts
    logging.info(f"成功处理JSON文件: {file_path.name} (作为单个文档)")
    return [json.dumps(data, ensure_ascii=False)]


def _ndjson_texts(file_path: Path) -> list:
    """Return the string "text" field of every JSON-object line in an NDJSON file."""
    texts = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) are not errors.
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                logging.warning(f"跳过无效的JSON行: {line.strip()}")
                continue
            # Lines holding valid JSON that is not an object are skipped
            # instead of aborting the rest of the file.
            if isinstance(record, dict) and isinstance(record.get('text'), str):
                texts.append(record['text'])
    if texts:
        logging.info(f"成功处理NDJSON文件: {file_path.name}, 提取了 {len(texts)} 个文档。")
    return texts


def _extract_texts(file_path: Path) -> list:
    """Pick the loader for ``file_path`` by its (case-insensitive) suffix and return its texts."""
    suffix = file_path.suffix.lower()
    if suffix in ('.txt', '.md'):
        # Plain-text formats are read verbatim as a single document.
        text = file_path.read_text(encoding='utf-8')
        logging.info(f"成功处理文本文件: {file_path.name}")
        return [text]
    if suffix == '.json':
        return _json_texts(file_path)
    if suffix == '.ndjson':
        return _ndjson_texts(file_path)
    # Every other file type is delegated to unstructured's partition().
    elements = partition(filename=str(file_path))
    logging.info(f"成功处理文件: {file_path.name} (使用unstructured)")
    return [element.text for element in elements]


def get_documents(directory: Path):
    """Load every document under ``directory`` and split it into text chunks.

    Files are visited recursively in sorted path order, so the position of a
    chunk in the returned list is the same on every run for unchanged input.
    Hidden files, and files inside hidden directories (any path component
    below ``directory`` starting with "."), are skipped.

    Handling per suffix (case-insensitive):
      * ``.txt`` / ``.md``: the whole file is one chunk.
      * ``.json``: one chunk per item of a top-level ``"elements"`` list,
        otherwise the whole document as one chunk.
      * ``.ndjson``: one chunk per line whose JSON object has a string
        ``"text"`` field.
      * anything else: one chunk per element returned by ``partition``.

    A file that fails to load is logged and skipped; it never aborts the run.
    Chunks whose text is empty or whitespace-only are dropped.

    Args:
        directory: Root directory of the knowledge base.

    Returns:
        A list of ``{"text": str, "metadata": {"source": <file name>}}`` dicts.
    """
    documents = []
    logging.info(f"从 '{directory}' 加载文档...")
    for file_path in sorted(directory.rglob("*")):
        if not file_path.is_file():
            continue
        # Look only at the components below `directory`, so a hidden
        # directory higher up in the absolute path does not hide everything.
        relative_parts = file_path.relative_to(directory).parts
        if any(part.startswith('.') for part in relative_parts):
            continue

        try:
            texts = _extract_texts(file_path)
        except Exception as e:
            # Best-effort ingestion: one unreadable file must not stop the rest.
            logging.error(f"处理文件 {file_path.name} 失败: {e}")
            continue

        kept = [text for text in texts if isinstance(text, str) and text.strip()]
        skipped = len(texts) - len(kept)
        if skipped:
            logging.warning(f"跳过 {file_path.name} 中的 {skipped} 个空文本块")
        for text in kept:
            documents.append({
                "text": text,
                "metadata": {"source": str(file_path.name)},
            })
    return documents
|
||||
|
||||
|
||||
def main():
    """Run the ingestion pipeline: load documents, embed them remotely, store them in ChromaDB.

    Steps:
      1. Create the knowledge-base directory and stop if it did not exist yet.
      2. Load and split the documents; stop if none were found.
      3. Build the remote embedding function from the ``ORIN_IP`` environment
         variable (falling back to a built-in default address).
      4. Open the persistent ChromaDB collection and write the documents in
         batches.
      5. Report the collection size.

    The success banner is printed only when every batch was written.
    """
    if not KNOWLEDGE_BASE_DIR.exists():
        KNOWLEDGE_BASE_DIR.mkdir(parents=True)
        logging.warning(f"知识库目录不存在,已自动创建: {KNOWLEDGE_BASE_DIR}")
        logging.warning("请向该目录中添加您的知识文件(如 .txt, .pdf, .md)。")
        return

    # 1. Load and split the documents.
    docs_to_ingest = get_documents(KNOWLEDGE_BASE_DIR)
    if not docs_to_ingest:
        logging.warning("在知识库中未找到可处理的文档。")
        return

    # 2. Initialise the remote embedding function and the ChromaDB client.
    orin_ip = os.getenv("ORIN_IP", "172.101.1.117")
    embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings"

    logging.info(f"正在初始化远程嵌入函数,目标服务地址: {embedding_api_url}")
    embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)

    client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR))

    # 3. Create the collection, or reuse it when it already exists.
    logging.info(f"正在访问ChromaDB集合: {COLLECTION_NAME}")
    collection = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_func,
    )

    # 4. Embed the documents and store them.
    logging.info(f"开始将 {len(docs_to_ingest)} 个文档块入库...")

    doc_texts = [doc['text'] for doc in docs_to_ingest]
    metadatas = [doc['metadata'] for doc in docs_to_ingest]
    # IDs are derived from the position in the list, so re-running the script
    # addresses the same IDs again.
    ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{i}" for i in range(len(doc_texts))]

    # Write in modest batches: one giant call would put the whole corpus into
    # a single embedding request and can exceed ChromaDB's per-call batch limit.
    batch_size = 64
    try:
        for start in range(0, len(doc_texts), batch_size):
            end = start + batch_size
            # upsert rather than add: per Chroma's documentation, add ignores
            # records whose ID already exists, so edited documents would never
            # be refreshed on a re-run. ChromaDB computes the embeddings itself
            # through the collection's embedding function.
            collection.upsert(
                documents=doc_texts[start:end],
                metadatas=metadatas[start:end],
                ids=ids[start:end],
            )
    except Exception as e:
        logging.error(f"向ChromaDB添加文档时出错: {e}")
        # Stop here so a failed run is not followed by a success banner.
        return
    logging.info("所有文档块已成功入库!")

    # 5. Sanity check: report how many entries the collection now holds.
    count = collection.count()
    logging.info(f"数据库中现在有 {count} 个条目。")

    print("\n✅ 数据入库完成!")
    print(f"知识库位于: {KNOWLEDGE_BASE_DIR}")
    print(f"向量数据库位于: {VECTOR_STORE_DIR}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user