126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
from typing import List, Dict, Any, Optional, Union
|
|
import numpy as np
|
|
import pandas as pd
|
|
from chromadb.api.types import QueryResult, GetResult
|
|
|
|
|
|
def _transform_embeddings(
    embeddings: Optional[List[np.ndarray]],  # type: ignore
) -> Optional[Union[List[List[float]], List[np.ndarray]]]:  # type: ignore
    """
    Transform embeddings from numpy arrays to lists of floats.

    This is a shared helper function to avoid duplicating the transformation logic.

    Args:
        embeddings: A sequence of embeddings (a list, or an array whose rows
            are iterated), or ``None``.

    Returns:
        ``None`` if ``embeddings`` is ``None``; an empty list if it is empty;
        ``[emb.tolist() for emb in embeddings]`` if its first element is a
        ``np.ndarray``; otherwise ``embeddings`` unchanged.
    """
    if embeddings is None:
        return None
    # Guard before peeking at embeddings[0]: an empty list (or a 0-row array)
    # would otherwise raise IndexError.
    if len(embeddings) == 0:
        return []
    # Only the first element is inspected; the input is assumed homogeneous.
    if isinstance(embeddings[0], np.ndarray):
        return [emb.tolist() for emb in embeddings]
    return embeddings
|
|
|
|
|
|
def _add_query_fields(
    data_dict: Dict[str, Any],
    query_result: QueryResult,
    query_idx: int,
) -> None:
    """
    Helper function to add fields from a query result to a dictionary.

    Handles the nested array structure specific to query results: each
    included field holds one entry per query, and only the entry at
    ``query_idx`` is copied into ``data_dict``.

    Args:
        data_dict: Dictionary to add the fields to; mutated in place. Keys are
            the included field names with trailing "s" characters stripped.
        query_result: QueryResult containing the data
        query_idx: Index of the current query being processed

    Embeddings for the selected query are converted from a numpy array to
    nested Python lists; other values are stored as-is. Fields whose value
    is ``None`` are skipped.
    """
    for field in query_result["included"]:
        value = query_result.get(field)
        if value is None:
            continue

        key = field.rstrip("s")  # DF naming convention is not plural

        # Select this query's entry first. The outer container may be a list
        # or an ndarray; an empty container is stored unchanged.
        if isinstance(value, (list, np.ndarray)) and len(value) > 0:
            value = value[query_idx]  # type: ignore

        # Convert only the selected query's embeddings. Converting the whole
        # outer container here would redo that work for every query index.
        if field == "embeddings" and isinstance(value, np.ndarray):
            value = value.tolist()

        data_dict[key] = value
|
|
|
|
|
|
def _add_get_fields(
    data_dict: Dict[str, Any],
    get_result: GetResult,
) -> None:
    """
    Copy every included field of a get result into ``data_dict``.

    Get results are flat (one entry per record), so each field's value is
    stored whole, with no per-query indexing.

    Args:
        data_dict: Dictionary that receives the fields; mutated in place.
            Keys are the field names with trailing "s" characters stripped.
        get_result: GetResult containing the data
    """
    for name in get_result["included"]:
        column = get_result.get(name)
        if column is None:
            # Skip fields whose value is None.
            continue
        if name == "embeddings":
            # Convert numpy embeddings to plain lists via the shared helper.
            column = _transform_embeddings(column)  # type: ignore
        # DataFrame column names are singular, e.g. "documents" -> "document".
        data_dict[name.rstrip("s")] = column
|
|
|
|
|
|
def query_result_to_dfs(query_result: QueryResult) -> List["pd.DataFrame"]:
    """
    Convert a QueryResult into one DataFrame per query.

    Query results are nested: each field holds a separate entry for every
    query. The i-th DataFrame is built from the i-th entry of "ids" and of
    each included field, and is indexed by "id". Column order follows the
    order of the fields in the QueryResult's ``included`` list.

    Args:
        query_result: QueryResult to convert to DataFrames.

    Returns:
        List of DataFrames, one per query, in query order.

    Raises:
        ImportError: If pandas cannot be imported.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required to convert query results to DataFrames.")

    frames: List["pd.DataFrame"] = []
    for query_idx, ids_for_query in enumerate(query_result["ids"]):
        columns: Dict[str, Any] = {"id": ids_for_query}
        _add_query_fields(columns, query_result, query_idx)
        frames.append(pd.DataFrame(columns).set_index("id"))
    return frames
|
|
|
|
|
|
def get_result_to_df(get_result: GetResult) -> "pd.DataFrame":
    """
    Convert a GetResult into a single DataFrame indexed by "id".

    Get results are flat, so "ids" and every included field each become one
    column. Column order follows the order of the fields in the GetResult's
    ``included`` list.

    Args:
        get_result: GetResult to convert to a DataFrame.

    Returns:
        DataFrame indexed by "id".

    Raises:
        ImportError: If pandas cannot be imported.
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required to convert get results to a DataFrame.")

    columns: Dict[str, Any] = {"id": get_result["ids"]}
    _add_get_fields(columns, get_result)
    return pd.DataFrame(columns).set_index("id")
|