chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is roughly 393 MB and contains 12,655 files
@@ -0,0 +1,17 @@
from .anthropic import Anthropic
from .anthropic_async import AsyncAnthropic
from .anthropic_providers import (
    AnthropicBedrock,
    AnthropicVertex,
    AsyncAnthropicBedrock,
    AsyncAnthropicVertex,
)

__all__ = [
    "Anthropic",
    "AsyncAnthropic",
    "AnthropicBedrock",
    "AsyncAnthropicBedrock",
    "AnthropicVertex",
    "AsyncAnthropicVertex",
]
@@ -0,0 +1,217 @@
try:
    import anthropic
    from anthropic.resources import Messages
except ImportError:
    raise ModuleNotFoundError(
        "Please install the Anthropic SDK to use this feature: 'pip install anthropic'"
    )

import time
import uuid
from typing import Any, Dict, Optional

from posthog.ai.utils import (
    call_llm_and_track_usage,
    get_model_params,
    merge_system_prompt,
    with_privacy_mode,
)
from posthog.client import Client as PostHogClient


class Anthropic(anthropic.Anthropic):
    """
    A wrapper around the Anthropic SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        """
        Args:
            posthog_client: PostHog client for tracking usage
            **kwargs: Additional arguments passed to the Anthropic client
        """
        super().__init__(**kwargs)
        self._ph_client = posthog_client
        self.messages = WrappedMessages(self)


class WrappedMessages(Messages):
    _client: Anthropic

    def create(
        self,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Create a message using Anthropic's API while tracking usage in PostHog.

        Args:
            posthog_distinct_id: Optional ID to associate with the usage event
            posthog_trace_id: Optional trace UUID for linking events
            posthog_properties: Optional dictionary of extra properties to include in the event
            posthog_privacy_mode: Whether to redact sensitive information in tracking
            posthog_groups: Optional group analytics properties
            **kwargs: Arguments passed to Anthropic's messages.create
        """
        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        if kwargs.get("stream", False):
            return self._create_streaming(
                posthog_distinct_id,
                posthog_trace_id,
                posthog_properties,
                posthog_privacy_mode,
                posthog_groups,
                **kwargs,
            )

        return call_llm_and_track_usage(
            posthog_distinct_id,
            self._client._ph_client,
            "anthropic",
            posthog_trace_id,
            posthog_properties,
            posthog_privacy_mode,
            posthog_groups,
            self._client.base_url,
            super().create,
            **kwargs,
        )

    def stream(
        self,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        return self._create_streaming(
            posthog_distinct_id,
            posthog_trace_id,
            posthog_properties,
            posthog_privacy_mode,
            posthog_groups,
            **kwargs,
        )

    def _create_streaming(
        self,
        posthog_distinct_id: Optional[str],
        posthog_trace_id: Optional[str],
        posthog_properties: Optional[Dict[str, Any]],
        posthog_privacy_mode: bool,
        posthog_groups: Optional[Dict[str, Any]],
        **kwargs: Any,
    ):
        start_time = time.time()
        usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
        accumulated_content = []
        response = super().create(**kwargs)

        def generator():
            nonlocal usage_stats
            nonlocal accumulated_content  # noqa: F824
            try:
                for event in response:
                    if hasattr(event, "usage") and event.usage:
                        usage_stats = {
                            k: getattr(event.usage, k, 0)
                            for k in [
                                "input_tokens",
                                "output_tokens",
                                "cache_read_input_tokens",
                                "cache_creation_input_tokens",
                            ]
                        }

                    if hasattr(event, "content") and event.content:
                        accumulated_content.append(event.content)

                    yield event

            finally:
                end_time = time.time()
                latency = end_time - start_time
                output = "".join(accumulated_content)

                self._capture_streaming_event(
                    posthog_distinct_id,
                    posthog_trace_id,
                    posthog_properties,
                    posthog_privacy_mode,
                    posthog_groups,
                    kwargs,
                    usage_stats,
                    latency,
                    output,
                )

        return generator()

    def _capture_streaming_event(
        self,
        posthog_distinct_id: Optional[str],
        posthog_trace_id: Optional[str],
        posthog_properties: Optional[Dict[str, Any]],
        posthog_privacy_mode: bool,
        posthog_groups: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
        usage_stats: Dict[str, int],
        latency: float,
        output: str,
    ):
        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        event_properties = {
            "$ai_provider": "anthropic",
            "$ai_model": kwargs.get("model"),
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(
                self._client._ph_client,
                posthog_privacy_mode,
                merge_system_prompt(kwargs, "anthropic"),
            ),
            "$ai_output_choices": with_privacy_mode(
                self._client._ph_client,
                posthog_privacy_mode,
                [{"content": output, "role": "assistant"}],
            ),
            "$ai_http_status": 200,
            "$ai_input_tokens": usage_stats.get("input_tokens", 0),
            "$ai_output_tokens": usage_stats.get("output_tokens", 0),
            "$ai_cache_read_input_tokens": usage_stats.get(
                "cache_read_input_tokens", 0
            ),
            "$ai_cache_creation_input_tokens": usage_stats.get(
                "cache_creation_input_tokens", 0
            ),
            "$ai_latency": latency,
            "$ai_trace_id": posthog_trace_id,
            "$ai_base_url": str(self._client.base_url),
            **(posthog_properties or {}),
        }

        if posthog_distinct_id is None:
            event_properties["$process_person_profile"] = False

        if hasattr(self._client._ph_client, "capture"):
            self._client._ph_client.capture(
                distinct_id=posthog_distinct_id or posthog_trace_id,
                event="$ai_generation",
                properties=event_properties,
                groups=posthog_groups,
            )
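A minimal usage sketch for the synchronous wrapper above, assuming the module ships as posthog.ai.anthropic and that PostHog and Anthropic credentials are available; the keys, model ID, and distinct ID are placeholders:

# Illustrative sketch with placeholder credentials and model ID.
from posthog import Posthog
from posthog.ai.anthropic import Anthropic

posthog = Posthog("phc_your_project_key", host="https://us.i.posthog.com")
client = Anthropic(posthog_client=posthog, api_key="sk-ant-...")

response = client.messages.create(
    model="claude-3-5-sonnet-20241022",  # placeholder model ID
    max_tokens=256,
    messages=[{"role": "user", "content": "Say hello"}],
    posthog_distinct_id="user_123",              # optional: ties the $ai_generation event to a user
    posthog_properties={"feature": "greeting"},  # optional extra event properties
)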
@@ -0,0 +1,217 @@
try:
    import anthropic
    from anthropic.resources import AsyncMessages
except ImportError:
    raise ModuleNotFoundError(
        "Please install the Anthropic SDK to use this feature: 'pip install anthropic'"
    )

import time
import uuid
from typing import Any, Dict, Optional

from posthog.ai.utils import (
    call_llm_and_track_usage_async,
    get_model_params,
    merge_system_prompt,
    with_privacy_mode,
)
from posthog.client import Client as PostHogClient


class AsyncAnthropic(anthropic.AsyncAnthropic):
    """
    An async wrapper around the Anthropic SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        """
        Args:
            posthog_client: PostHog client for tracking usage
            **kwargs: Additional arguments passed to the Anthropic client
        """
        super().__init__(**kwargs)
        self._ph_client = posthog_client
        self.messages = AsyncWrappedMessages(self)


class AsyncWrappedMessages(AsyncMessages):
    _client: AsyncAnthropic

    async def create(
        self,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Create a message using Anthropic's API while tracking usage in PostHog.

        Args:
            posthog_distinct_id: Optional ID to associate with the usage event
            posthog_trace_id: Optional trace UUID for linking events
            posthog_properties: Optional dictionary of extra properties to include in the event
            posthog_privacy_mode: Whether to redact sensitive information in tracking
            posthog_groups: Optional group analytics properties
            **kwargs: Arguments passed to Anthropic's messages.create
        """
        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        if kwargs.get("stream", False):
            return await self._create_streaming(
                posthog_distinct_id,
                posthog_trace_id,
                posthog_properties,
                posthog_privacy_mode,
                posthog_groups,
                **kwargs,
            )

        return await call_llm_and_track_usage_async(
            posthog_distinct_id,
            self._client._ph_client,
            "anthropic",
            posthog_trace_id,
            posthog_properties,
            posthog_privacy_mode,
            posthog_groups,
            self._client.base_url,
            super().create,
            **kwargs,
        )

    async def stream(
        self,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        return await self._create_streaming(
            posthog_distinct_id,
            posthog_trace_id,
            posthog_properties,
            posthog_privacy_mode,
            posthog_groups,
            **kwargs,
        )

    async def _create_streaming(
        self,
        posthog_distinct_id: Optional[str],
        posthog_trace_id: Optional[str],
        posthog_properties: Optional[Dict[str, Any]],
        posthog_privacy_mode: bool,
        posthog_groups: Optional[Dict[str, Any]],
        **kwargs: Any,
    ):
        start_time = time.time()
        usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
        accumulated_content = []
        response = await super().create(**kwargs)

        async def generator():
            nonlocal usage_stats
            nonlocal accumulated_content  # noqa: F824
            try:
                async for event in response:
                    if hasattr(event, "usage") and event.usage:
                        usage_stats = {
                            k: getattr(event.usage, k, 0)
                            for k in [
                                "input_tokens",
                                "output_tokens",
                                "cache_read_input_tokens",
                                "cache_creation_input_tokens",
                            ]
                        }

                    if hasattr(event, "content") and event.content:
                        accumulated_content.append(event.content)

                    yield event

            finally:
                end_time = time.time()
                latency = end_time - start_time
                output = "".join(accumulated_content)

                await self._capture_streaming_event(
                    posthog_distinct_id,
                    posthog_trace_id,
                    posthog_properties,
                    posthog_privacy_mode,
                    posthog_groups,
                    kwargs,
                    usage_stats,
                    latency,
                    output,
                )

        return generator()

    async def _capture_streaming_event(
        self,
        posthog_distinct_id: Optional[str],
        posthog_trace_id: Optional[str],
        posthog_properties: Optional[Dict[str, Any]],
        posthog_privacy_mode: bool,
        posthog_groups: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
        usage_stats: Dict[str, int],
        latency: float,
        output: str,
    ):
        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        event_properties = {
            "$ai_provider": "anthropic",
            "$ai_model": kwargs.get("model"),
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(
                self._client._ph_client,
                posthog_privacy_mode,
                merge_system_prompt(kwargs, "anthropic"),
            ),
            "$ai_output_choices": with_privacy_mode(
                self._client._ph_client,
                posthog_privacy_mode,
                [{"content": output, "role": "assistant"}],
            ),
            "$ai_http_status": 200,
            "$ai_input_tokens": usage_stats.get("input_tokens", 0),
            "$ai_output_tokens": usage_stats.get("output_tokens", 0),
            "$ai_cache_read_input_tokens": usage_stats.get(
                "cache_read_input_tokens", 0
            ),
            "$ai_cache_creation_input_tokens": usage_stats.get(
                "cache_creation_input_tokens", 0
            ),
            "$ai_latency": latency,
            "$ai_trace_id": posthog_trace_id,
            "$ai_base_url": str(self._client.base_url),
            **(posthog_properties or {}),
        }

        if posthog_distinct_id is None:
            event_properties["$process_person_profile"] = False

        if hasattr(self._client._ph_client, "capture"):
            self._client._ph_client.capture(
                distinct_id=posthog_distinct_id or posthog_trace_id,
                event="$ai_generation",
                properties=event_properties,
                groups=posthog_groups,
            )
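A corresponding async streaming sketch, again with placeholder credentials and model ID; with stream=True, create() returns the async generator built above, and the usage event is captured in its finally block once the stream is exhausted:

# Illustrative sketch with placeholder credentials and model ID.
import asyncio

from posthog import Posthog
from posthog.ai.anthropic import AsyncAnthropic


async def main():
    posthog = Posthog("phc_your_project_key", host="https://us.i.posthog.com")
    client = AsyncAnthropic(posthog_client=posthog, api_key="sk-ant-...")
    stream = await client.messages.create(
        model="claude-3-5-sonnet-20241022",  # placeholder model ID
        max_tokens=256,
        messages=[{"role": "user", "content": "Stream a short poem"}],
        stream=True,
        posthog_distinct_id="user_123",
    )
    async for event in stream:
        pass  # consume the stream; the $ai_generation event fires when it is exhausted


asyncio.run(main())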
@@ -0,0 +1,62 @@
try:
    import anthropic
except ImportError:
    raise ModuleNotFoundError(
        "Please install the Anthropic SDK to use this feature: 'pip install anthropic'"
    )

from posthog.ai.anthropic.anthropic import WrappedMessages
from posthog.ai.anthropic.anthropic_async import AsyncWrappedMessages
from posthog.client import Client as PostHogClient


class AnthropicBedrock(anthropic.AnthropicBedrock):
    """
    A wrapper around the Anthropic Bedrock SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        super().__init__(**kwargs)
        self._ph_client = posthog_client
        self.messages = WrappedMessages(self)


class AsyncAnthropicBedrock(anthropic.AsyncAnthropicBedrock):
    """
    A wrapper around the Anthropic Bedrock SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        super().__init__(**kwargs)
        self._ph_client = posthog_client
        self.messages = AsyncWrappedMessages(self)


class AnthropicVertex(anthropic.AnthropicVertex):
    """
    A wrapper around the Anthropic Vertex SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        super().__init__(**kwargs)
        self._ph_client = posthog_client
        self.messages = WrappedMessages(self)


class AsyncAnthropicVertex(anthropic.AsyncAnthropicVertex):
    """
    A wrapper around the Anthropic Vertex SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        super().__init__(**kwargs)
        self._ph_client = posthog_client
        self.messages = AsyncWrappedMessages(self)
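The provider subclasses only swap the underlying transport, so tracking behaves the same; a brief sketch for Bedrock, assuming AWS credentials are resolved from the environment and that the region value is a placeholder:

# Illustrative sketch; region and PostHog key are placeholders.
from posthog import Posthog
from posthog.ai.anthropic import AnthropicBedrock

posthog = Posthog("phc_your_project_key", host="https://us.i.posthog.com")
# The wrapped Messages resource is attached in __init__, so messages.create()
# tracks usage exactly as in the plain Anthropic wrapper.
bedrock_client = AnthropicBedrock(posthog_client=posthog, aws_region="us-east-1")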
@@ -0,0 +1,11 @@
from .gemini import Client


# Create a genai-like module for perfect drop-in replacement
class _GenAI:
    Client = Client


genai = _GenAI()

__all__ = ["Client", "genai"]
@@ -0,0 +1,366 @@
import os
import time
import uuid
from typing import Any, Dict, Optional

try:
    from google import genai
except ImportError:
    raise ModuleNotFoundError(
        "Please install the Google Gemini SDK to use this feature: 'pip install google-genai'"
    )

from posthog.ai.utils import (
    call_llm_and_track_usage,
    get_model_params,
    with_privacy_mode,
)
from posthog.client import Client as PostHogClient


class Client:
    """
    A drop-in replacement for genai.Client that automatically sends LLM usage events to PostHog.

    Usage:
        client = Client(
            api_key="your_api_key",
            posthog_client=posthog_client,
            posthog_distinct_id="default_user",  # Optional defaults
            posthog_properties={"team": "ai"}  # Optional defaults
        )
        response = client.models.generate_content(
            model="gemini-2.0-flash",
            contents=["Hello world"],
            posthog_distinct_id="specific_user"  # Override default
        )
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable
            posthog_client: PostHog client for tracking usage
            posthog_distinct_id: Default distinct ID for all calls (can be overridden per call)
            posthog_properties: Default properties for all calls (can be overridden per call)
            posthog_privacy_mode: Default privacy mode for all calls (can be overridden per call)
            posthog_groups: Default groups for all calls (can be overridden per call)
            **kwargs: Additional arguments (for future compatibility)
        """
        if posthog_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")

        self.models = Models(
            api_key=api_key,
            posthog_client=posthog_client,
            posthog_distinct_id=posthog_distinct_id,
            posthog_properties=posthog_properties,
            posthog_privacy_mode=posthog_privacy_mode,
            posthog_groups=posthog_groups,
            **kwargs,
        )


class Models:
    """
    Models interface that mimics genai.Client().models with PostHog tracking.
    """

    _ph_client: PostHogClient  # Not None after __init__ validation

    def __init__(
        self,
        api_key: Optional[str] = None,
        posthog_client: Optional[PostHogClient] = None,
        posthog_distinct_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: bool = False,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            api_key: Google AI API key. If not provided, will use GOOGLE_API_KEY or API_KEY environment variable
            posthog_client: PostHog client for tracking usage
            posthog_distinct_id: Default distinct ID for all calls
            posthog_properties: Default properties for all calls
            posthog_privacy_mode: Default privacy mode for all calls
            posthog_groups: Default groups for all calls
            **kwargs: Additional arguments (for future compatibility)
        """
        if posthog_client is None:
            raise ValueError("posthog_client is required for PostHog tracking")

        self._ph_client = posthog_client

        # Store default PostHog settings
        self._default_distinct_id = posthog_distinct_id
        self._default_properties = posthog_properties or {}
        self._default_privacy_mode = posthog_privacy_mode
        self._default_groups = posthog_groups

        # Handle API key - try parameter first, then environment variables
        if api_key is None:
            api_key = os.environ.get("GOOGLE_API_KEY") or os.environ.get("API_KEY")

        if api_key is None:
            raise ValueError(
                "API key must be provided either as parameter or via GOOGLE_API_KEY/API_KEY environment variable"
            )

        self._client = genai.Client(api_key=api_key)
        self._base_url = "https://generativelanguage.googleapis.com"

    def _merge_posthog_params(
        self,
        call_distinct_id: Optional[str],
        call_trace_id: Optional[str],
        call_properties: Optional[Dict[str, Any]],
        call_privacy_mode: Optional[bool],
        call_groups: Optional[Dict[str, Any]],
    ):
        """Merge call-level PostHog parameters with client defaults."""
        # Use call-level values if provided, otherwise fall back to defaults
        distinct_id = (
            call_distinct_id
            if call_distinct_id is not None
            else self._default_distinct_id
        )
        privacy_mode = (
            call_privacy_mode
            if call_privacy_mode is not None
            else self._default_privacy_mode
        )
        groups = call_groups if call_groups is not None else self._default_groups

        # Merge properties: default properties + call properties (call properties override)
        properties = dict(self._default_properties)
        if call_properties:
            properties.update(call_properties)

        if call_trace_id is None:
            call_trace_id = str(uuid.uuid4())

        return distinct_id, call_trace_id, properties, privacy_mode, groups

    def generate_content(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """
        Generate content using Gemini's API while tracking usage in PostHog.

        This method signature exactly matches genai.Client().models.generate_content()
        with additional PostHog tracking parameters.

        Args:
            model: The model to use (e.g., 'gemini-2.0-flash')
            contents: The input content for generation
            posthog_distinct_id: ID to associate with the usage event (overrides client default)
            posthog_trace_id: Trace UUID for linking events (auto-generated if not provided)
            posthog_properties: Extra properties to include in the event (merged with client defaults)
            posthog_privacy_mode: Whether to redact sensitive information (overrides client default)
            posthog_groups: Group analytics properties (overrides client default)
            **kwargs: Arguments passed to Gemini's generate_content
        """
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = (
            self._merge_posthog_params(
                posthog_distinct_id,
                posthog_trace_id,
                posthog_properties,
                posthog_privacy_mode,
                posthog_groups,
            )
        )

        kwargs_with_contents = {"model": model, "contents": contents, **kwargs}

        return call_llm_and_track_usage(
            distinct_id,
            self._ph_client,
            "gemini",
            trace_id,
            properties,
            privacy_mode,
            groups,
            self._base_url,
            self._client.models.generate_content,
            **kwargs_with_contents,
        )

    def _generate_content_streaming(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        **kwargs: Any,
    ):
        start_time = time.time()
        usage_stats: Dict[str, int] = {"input_tokens": 0, "output_tokens": 0}
        accumulated_content = []

        kwargs_without_stream = {"model": model, "contents": contents, **kwargs}
        response = self._client.models.generate_content_stream(**kwargs_without_stream)

        def generator():
            nonlocal usage_stats
            nonlocal accumulated_content  # noqa: F824
            try:
                for chunk in response:
                    if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
                        usage_stats = {
                            "input_tokens": getattr(
                                chunk.usage_metadata, "prompt_token_count", 0
                            ),
                            "output_tokens": getattr(
                                chunk.usage_metadata, "candidates_token_count", 0
                            ),
                        }

                    if hasattr(chunk, "text") and chunk.text:
                        accumulated_content.append(chunk.text)

                    yield chunk

            finally:
                end_time = time.time()
                latency = end_time - start_time
                output = "".join(accumulated_content)

                self._capture_streaming_event(
                    model,
                    contents,
                    distinct_id,
                    trace_id,
                    properties,
                    privacy_mode,
                    groups,
                    kwargs,
                    usage_stats,
                    latency,
                    output,
                )

        return generator()

    def _capture_streaming_event(
        self,
        model: str,
        contents,
        distinct_id: Optional[str],
        trace_id: Optional[str],
        properties: Optional[Dict[str, Any]],
        privacy_mode: bool,
        groups: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
        usage_stats: Dict[str, int],
        latency: float,
        output: str,
    ):
        if trace_id is None:
            trace_id = str(uuid.uuid4())

        event_properties = {
            "$ai_provider": "gemini",
            "$ai_model": model,
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(
                self._ph_client,
                privacy_mode,
                self._format_input(contents),
            ),
            "$ai_output_choices": with_privacy_mode(
                self._ph_client,
                privacy_mode,
                [{"content": output, "role": "assistant"}],
            ),
            "$ai_http_status": 200,
            "$ai_input_tokens": usage_stats.get("input_tokens", 0),
            "$ai_output_tokens": usage_stats.get("output_tokens", 0),
            "$ai_latency": latency,
            "$ai_trace_id": trace_id,
            "$ai_base_url": self._base_url,
            **(properties or {}),
        }

        if distinct_id is None:
            event_properties["$process_person_profile"] = False

        if hasattr(self._ph_client, "capture"):
            self._ph_client.capture(
                distinct_id=distinct_id,
                event="$ai_generation",
                properties=event_properties,
                groups=groups,
            )

    def _format_input(self, contents):
        """Format input contents for PostHog tracking"""
        if isinstance(contents, str):
            return [{"role": "user", "content": contents}]
        elif isinstance(contents, list):
            formatted = []
            for item in contents:
                if isinstance(item, str):
                    formatted.append({"role": "user", "content": item})
                elif hasattr(item, "text"):
                    formatted.append({"role": "user", "content": item.text})
                else:
                    formatted.append({"role": "user", "content": str(item)})
            return formatted
        else:
            return [{"role": "user", "content": str(contents)}]

    def generate_content_stream(
        self,
        model: str,
        contents,
        posthog_distinct_id: Optional[str] = None,
        posthog_trace_id: Optional[str] = None,
        posthog_properties: Optional[Dict[str, Any]] = None,
        posthog_privacy_mode: Optional[bool] = None,
        posthog_groups: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ):
        # Merge PostHog parameters
        distinct_id, trace_id, properties, privacy_mode, groups = (
            self._merge_posthog_params(
                posthog_distinct_id,
                posthog_trace_id,
                posthog_properties,
                posthog_privacy_mode,
                posthog_groups,
            )
        )

        return self._generate_content_streaming(
            model,
            contents,
            distinct_id,
            trace_id,
            properties,
            privacy_mode,
            groups,
            **kwargs,
        )
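A short streaming sketch for the Gemini wrapper above, assuming it is importable as posthog.ai.gemini; the API key, model ID, and distinct ID are placeholders. Token counts are read from chunk.usage_metadata, and the $ai_generation event is sent when the generator is exhausted:

# Illustrative sketch with placeholder credentials and model ID.
from posthog import Posthog
from posthog.ai.gemini import Client

posthog = Posthog("phc_your_project_key", host="https://us.i.posthog.com")
client = Client(api_key="your_google_api_key", posthog_client=posthog)

for chunk in client.models.generate_content_stream(
    model="gemini-2.0-flash",  # placeholder model ID
    contents=["Write a haiku about observability"],
    posthog_distinct_id="user_123",
):
    if chunk.text:
        print(chunk.text, end="")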
@@ -0,0 +1,3 @@
from .callbacks import CallbackHandler

__all__ = ["CallbackHandler"]
@@ -0,0 +1,841 @@
|
||||
try:
|
||||
import langchain # noqa: F401
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install LangChain to use this feature: 'pip install langchain'"
|
||||
)
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema.agent import AgentAction, AgentFinish
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from pydantic import BaseModel
|
||||
|
||||
from posthog import default_client
|
||||
from posthog.ai.utils import get_model_params, with_privacy_mode
|
||||
from posthog.client import Client
|
||||
|
||||
log = logging.getLogger("posthog")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpanMetadata:
|
||||
name: str
|
||||
"""Name of the run: chain name, model name, etc."""
|
||||
start_time: float
|
||||
"""Start time of the run."""
|
||||
end_time: Optional[float]
|
||||
"""End time of the run."""
|
||||
input: Optional[Any]
|
||||
"""Input of the run: messages, prompt variables, etc."""
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
if not self.end_time:
|
||||
return 0
|
||||
return self.end_time - self.start_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationMetadata(SpanMetadata):
|
||||
provider: Optional[str] = None
|
||||
"""Provider of the run: OpenAI, Anthropic"""
|
||||
model: Optional[str] = None
|
||||
"""Model used in the run"""
|
||||
model_params: Optional[Dict[str, Any]] = None
|
||||
"""Model parameters of the run: temperature, max_tokens, etc."""
|
||||
base_url: Optional[str] = None
|
||||
"""Base URL of the provider's API used in the run."""
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
"""Tools provided to the model."""
|
||||
|
||||
|
||||
RunMetadata = Union[SpanMetadata, GenerationMetadata]
|
||||
RunMetadataStorage = Dict[UUID, RunMetadata]
|
||||
|
||||
|
||||
class CallbackHandler(BaseCallbackHandler):
|
||||
"""
|
||||
The PostHog LLM observability callback handler for LangChain.
|
||||
"""
|
||||
|
||||
_client: Client
|
||||
"""PostHog client instance."""
|
||||
|
||||
_distinct_id: Optional[Union[str, int, float, UUID]]
|
||||
"""Distinct ID of the user to associate the trace with."""
|
||||
|
||||
_trace_id: Optional[Union[str, int, float, UUID]]
|
||||
"""Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""
|
||||
|
||||
_trace_input: Optional[Any]
|
||||
"""The input at the start of the trace. Any JSON object."""
|
||||
|
||||
_trace_name: Optional[str]
|
||||
"""Name of the trace, exposed in the UI."""
|
||||
|
||||
_properties: Optional[Dict[str, Any]]
|
||||
"""Global properties to be sent with every event."""
|
||||
|
||||
_runs: RunMetadataStorage
|
||||
"""Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""
|
||||
|
||||
_parent_tree: Dict[UUID, UUID]
|
||||
"""
|
||||
A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
|
||||
so the top level can be found from a bottom-level run ID.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Optional[Client] = None,
|
||||
*,
|
||||
distinct_id: Optional[Union[str, int, float, UUID]] = None,
|
||||
trace_id: Optional[Union[str, int, float, UUID]] = None,
|
||||
properties: Optional[Dict[str, Any]] = None,
|
||||
privacy_mode: bool = False,
|
||||
groups: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
client: PostHog client instance.
|
||||
distinct_id: Optional distinct ID of the user to associate the trace with.
|
||||
trace_id: Optional trace ID to use for the event.
|
||||
properties: Optional additional metadata to use for the trace.
|
||||
privacy_mode: Whether to redact the input and output of the trace.
|
||||
groups: Optional additional PostHog groups to use for the trace.
|
||||
"""
|
||||
posthog_client = client or default_client
|
||||
if posthog_client is None:
|
||||
raise ValueError("PostHog client is required")
|
||||
self._client = posthog_client
|
||||
self._distinct_id = distinct_id
|
||||
self._trace_id = trace_id
|
||||
self._properties = properties or {}
|
||||
self._privacy_mode = privacy_mode
|
||||
self._groups = groups or {}
|
||||
self._runs = {}
|
||||
self._parent_tree = {}
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
inputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
|
||||
self._set_parent_of_run(run_id, parent_run_id)
|
||||
self._set_trace_or_span_metadata(
|
||||
serialized, inputs, run_id, parent_run_id, **kwargs
|
||||
)
|
||||
|
||||
def on_chain_end(
|
||||
self,
|
||||
outputs: Dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, outputs)
|
||||
|
||||
def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._log_debug_event(
|
||||
"on_chat_model_start", run_id, parent_run_id, messages=messages
|
||||
)
|
||||
self._set_parent_of_run(run_id, parent_run_id)
|
||||
input = [
|
||||
_convert_message_to_dict(message) for row in messages for message in row
|
||||
]
|
||||
self._set_llm_metadata(serialized, run_id, input, **kwargs)
|
||||
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
prompts: List[str],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
|
||||
self._set_parent_of_run(run_id, parent_run_id)
|
||||
self._set_llm_metadata(serialized, run_id, prompts, **kwargs)
|
||||
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token)
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: LLMResult,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
|
||||
"""
|
||||
self._log_debug_event(
|
||||
"on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs
|
||||
)
|
||||
self._pop_run_and_capture_generation(run_id, parent_run_id, response)
|
||||
|
||||
def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
|
||||
self._pop_run_and_capture_generation(run_id, parent_run_id, error)
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
input_str: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._log_debug_event(
|
||||
"on_tool_start", run_id, parent_run_id, input_str=input_str
|
||||
)
|
||||
self._set_parent_of_run(run_id, parent_run_id)
|
||||
self._set_trace_or_span_metadata(
|
||||
serialized, input_str, run_id, parent_run_id, **kwargs
|
||||
)
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, output)
|
||||
|
||||
def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
|
||||
|
||||
def on_retriever_start(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
query: str,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
|
||||
self._set_parent_of_run(run_id, parent_run_id)
|
||||
self._set_trace_or_span_metadata(
|
||||
serialized, query, run_id, parent_run_id, **kwargs
|
||||
)
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self._log_debug_event(
|
||||
"on_retriever_end", run_id, parent_run_id, documents=documents
|
||||
)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)
|
||||
|
||||
def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever errors."""
|
||||
self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)
|
||||
|
||||
def on_agent_action(
|
||||
self,
|
||||
action: AgentAction,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run on agent action."""
|
||||
self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)
|
||||
self._set_parent_of_run(run_id, parent_run_id)
|
||||
self._set_trace_or_span_metadata(None, action, run_id, parent_run_id, **kwargs)
|
||||
|
||||
def on_agent_finish(
|
||||
self,
|
||||
finish: AgentFinish,
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
|
||||
self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, finish)
|
||||
|
||||
def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
|
||||
"""
|
||||
Set the parent run ID for a chain run. If there is no parent, the run is the root.
|
||||
"""
|
||||
if parent_run_id is not None:
|
||||
self._parent_tree[run_id] = parent_run_id
|
||||
|
||||
def _pop_parent_of_run(self, run_id: UUID):
|
||||
"""
|
||||
Remove the parent run ID for a chain run.
|
||||
"""
|
||||
try:
|
||||
self._parent_tree.pop(run_id)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _find_root_run(self, run_id: UUID) -> UUID:
|
||||
"""
|
||||
Finds the root ID of a chain run.
|
||||
"""
|
||||
id: UUID = run_id
|
||||
while id in self._parent_tree:
|
||||
id = self._parent_tree[id]
|
||||
return id
|
||||
|
||||
def _set_trace_or_span_metadata(
|
||||
self,
|
||||
serialized: Optional[Dict[str, Any]],
|
||||
input: Any,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs,
|
||||
):
|
||||
default_name = "trace" if parent_run_id is None else "span"
|
||||
run_name = _get_langchain_run_name(serialized, **kwargs) or default_name
|
||||
self._runs[run_id] = SpanMetadata(
|
||||
name=run_name, input=input, start_time=time.time(), end_time=None
|
||||
)
|
||||
|
||||
def _set_llm_metadata(
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
run_id: UUID,
|
||||
messages: Union[List[Dict[str, Any]], List[str]],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
invocation_params: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
run_name = _get_langchain_run_name(serialized, **kwargs) or "generation"
|
||||
generation = GenerationMetadata(
|
||||
name=run_name, input=messages, start_time=time.time(), end_time=None
|
||||
)
|
||||
if isinstance(invocation_params, dict):
|
||||
generation.model_params = get_model_params(invocation_params)
|
||||
if tools := invocation_params.get("tools"):
|
||||
generation.tools = tools
|
||||
if isinstance(metadata, dict):
|
||||
if model := metadata.get("ls_model_name"):
|
||||
generation.model = model
|
||||
if provider := metadata.get("ls_provider"):
|
||||
generation.provider = provider
|
||||
try:
|
||||
base_url = serialized["kwargs"]["openai_api_base"]
|
||||
if base_url is not None:
|
||||
generation.base_url = base_url
|
||||
except KeyError:
|
||||
pass
|
||||
self._runs[run_id] = generation
|
||||
|
||||
def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
|
||||
end_time = time.time()
|
||||
try:
|
||||
run = self._runs.pop(run_id)
|
||||
except KeyError:
|
||||
log.warning(f"No run metadata found for run {run_id}")
|
||||
return None
|
||||
run.end_time = end_time
|
||||
return run
|
||||
|
||||
def _get_trace_id(self, run_id: UUID):
|
||||
trace_id = self._trace_id or self._find_root_run(run_id)
|
||||
if not trace_id:
|
||||
return run_id
|
||||
return trace_id
|
||||
|
||||
def _get_parent_run_id(
|
||||
self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]
|
||||
):
|
||||
"""
|
||||
Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set.
|
||||
"""
|
||||
if parent_run_id is not None and parent_run_id not in self._parent_tree:
|
||||
return trace_id
|
||||
return parent_run_id
|
||||
|
||||
def _pop_run_and_capture_trace_or_span(
|
||||
self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any
|
||||
):
|
||||
trace_id = self._get_trace_id(run_id)
|
||||
self._pop_parent_of_run(run_id)
|
||||
run = self._pop_run_metadata(run_id)
|
||||
if not run:
|
||||
return
|
||||
if isinstance(run, GenerationMetadata):
|
||||
log.warning(
|
||||
f"Run {run_id} is a generation, but attempted to be captured as a trace or span."
|
||||
)
|
||||
return
|
||||
self._capture_trace_or_span(
|
||||
trace_id,
|
||||
run_id,
|
||||
run,
|
||||
outputs,
|
||||
self._get_parent_run_id(trace_id, run_id, parent_run_id),
|
||||
)
|
||||
|
||||
def _capture_trace_or_span(
|
||||
self,
|
||||
trace_id: Any,
|
||||
run_id: UUID,
|
||||
run: SpanMetadata,
|
||||
outputs: Any,
|
||||
parent_run_id: Optional[UUID],
|
||||
):
|
||||
event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
|
||||
event_properties = {
|
||||
"$ai_trace_id": trace_id,
|
||||
"$ai_input_state": with_privacy_mode(
|
||||
self._client, self._privacy_mode, run.input
|
||||
),
|
||||
"$ai_latency": run.latency,
|
||||
"$ai_span_name": run.name,
|
||||
"$ai_span_id": run_id,
|
||||
}
|
||||
if parent_run_id is not None:
|
||||
event_properties["$ai_parent_id"] = parent_run_id
|
||||
if self._properties:
|
||||
event_properties.update(self._properties)
|
||||
|
||||
if isinstance(outputs, BaseException):
|
||||
event_properties["$ai_error"] = _stringify_exception(outputs)
|
||||
event_properties["$ai_is_error"] = True
|
||||
elif outputs is not None:
|
||||
event_properties["$ai_output_state"] = with_privacy_mode(
|
||||
self._client, self._privacy_mode, outputs
|
||||
)
|
||||
|
||||
if self._distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
self._client.capture(
|
||||
distinct_id=self._distinct_id or run_id,
|
||||
event=event_name,
|
||||
properties=event_properties,
|
||||
groups=self._groups,
|
||||
)
|
||||
|
||||
def _pop_run_and_capture_generation(
|
||||
self,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID],
|
||||
response: Union[LLMResult, BaseException],
|
||||
):
|
||||
trace_id = self._get_trace_id(run_id)
|
||||
self._pop_parent_of_run(run_id)
|
||||
run = self._pop_run_metadata(run_id)
|
||||
if not run:
|
||||
return
|
||||
if not isinstance(run, GenerationMetadata):
|
||||
log.warning(
|
||||
f"Run {run_id} is not a generation, but attempted to be captured as a generation."
|
||||
)
|
||||
return
|
||||
self._capture_generation(
|
||||
trace_id,
|
||||
run_id,
|
||||
run,
|
||||
response,
|
||||
self._get_parent_run_id(trace_id, run_id, parent_run_id),
|
||||
)
|
||||
|
||||
def _capture_generation(
|
||||
self,
|
||||
trace_id: Any,
|
||||
run_id: UUID,
|
||||
run: GenerationMetadata,
|
||||
output: Union[LLMResult, BaseException],
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
):
|
||||
event_properties = {
|
||||
"$ai_trace_id": trace_id,
|
||||
"$ai_span_id": run_id,
|
||||
"$ai_span_name": run.name,
|
||||
"$ai_parent_id": parent_run_id,
|
||||
"$ai_provider": run.provider,
|
||||
"$ai_model": run.model,
|
||||
"$ai_model_parameters": run.model_params,
|
||||
"$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.input),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_latency": run.latency,
|
||||
"$ai_base_url": run.base_url,
|
||||
}
|
||||
if run.tools:
|
||||
event_properties["$ai_tools"] = with_privacy_mode(
|
||||
self._client,
|
||||
self._privacy_mode,
|
||||
run.tools,
|
||||
)
|
||||
|
||||
if isinstance(output, BaseException):
|
||||
event_properties["$ai_http_status"] = _get_http_status(output)
|
||||
event_properties["$ai_error"] = _stringify_exception(output)
|
||||
event_properties["$ai_is_error"] = True
|
||||
else:
|
||||
# Add usage
|
||||
usage = _parse_usage(output)
|
||||
event_properties["$ai_input_tokens"] = usage.input_tokens
|
||||
event_properties["$ai_output_tokens"] = usage.output_tokens
|
||||
event_properties["$ai_cache_creation_input_tokens"] = (
|
||||
usage.cache_write_tokens
|
||||
)
|
||||
event_properties["$ai_cache_read_input_tokens"] = usage.cache_read_tokens
|
||||
event_properties["$ai_reasoning_tokens"] = usage.reasoning_tokens
|
||||
|
||||
# Generation results
|
||||
generation_result = output.generations[-1]
|
||||
if isinstance(generation_result[-1], ChatGeneration):
|
||||
completions = [
|
||||
_convert_message_to_dict(cast(ChatGeneration, generation).message)
|
||||
for generation in generation_result
|
||||
]
|
||||
else:
|
||||
completions = [
|
||||
_extract_raw_esponse(generation) for generation in generation_result
|
||||
]
|
||||
event_properties["$ai_output_choices"] = with_privacy_mode(
|
||||
self._client, self._privacy_mode, completions
|
||||
)
|
||||
|
||||
if self._properties:
|
||||
event_properties.update(self._properties)
|
||||
|
||||
if self._distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
self._client.capture(
|
||||
distinct_id=self._distinct_id or trace_id,
|
||||
event="$ai_generation",
|
||||
properties=event_properties,
|
||||
groups=self._groups,
|
||||
)
|
||||
|
||||
def _log_debug_event(
|
||||
self,
|
||||
event_name: str,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
**kwargs,
|
||||
):
|
||||
log.debug(
|
||||
f"Event: {event_name}, run_id: {str(run_id)[:5]}, parent_run_id: {str(parent_run_id)[:5]}, kwargs: {kwargs}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_raw_esponse(last_response):
|
||||
"""Extract the response from the last response of the LLM call."""
|
||||
# We return the text of the response if not empty
|
||||
if last_response.text is not None and last_response.text.strip() != "":
|
||||
return last_response.text.strip()
|
||||
elif hasattr(last_response, "message"):
|
||||
# Additional kwargs contains the response in case of tool usage
|
||||
return last_response.message.additional_kwargs
|
||||
else:
|
||||
# Not tool usage, some LLM responses can be simply empty
|
||||
return ""
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
|
||||
# assistant message
|
||||
if isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {"role": "tool", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {"role": "function", "content": message.content}
|
||||
else:
|
||||
message_dict = {"role": message.type, "content": str(message.content)}
|
||||
|
||||
if message.additional_kwargs:
|
||||
message_dict.update(message.additional_kwargs)
|
||||
|
||||
return message_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelUsage:
|
||||
input_tokens: Optional[int]
|
||||
output_tokens: Optional[int]
|
||||
cache_write_tokens: Optional[int]
|
||||
cache_read_tokens: Optional[int]
|
||||
reasoning_tokens: Optional[int]
|
||||
|
||||
|
||||
def _parse_usage_model(
|
||||
usage: Union[BaseModel, dict],
|
||||
) -> ModelUsage:
|
||||
if isinstance(usage, BaseModel):
|
||||
usage = usage.__dict__
|
||||
|
||||
conversion_list = [
|
||||
# https://pypi.org/project/langchain-anthropic/ (works also for Bedrock-Anthropic)
|
||||
("input_tokens", "input"),
|
||||
("output_tokens", "output"),
|
||||
("cache_creation_input_tokens", "cache_write"),
|
||||
("cache_read_input_tokens", "cache_read"),
|
||||
# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/get-token-count
|
||||
("prompt_token_count", "input"),
|
||||
("candidates_token_count", "output"),
|
||||
("cached_content_token_count", "cache_read"),
|
||||
("thoughts_token_count", "reasoning"),
|
||||
# Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-cw.html#runtime-cloudwatch-metrics
|
||||
("inputTokenCount", "input"),
|
||||
("outputTokenCount", "output"),
|
||||
("cacheCreationInputTokenCount", "cache_write"),
|
||||
("cacheReadInputTokenCount", "cache_read"),
|
||||
# Bedrock Anthropic
|
||||
("prompt_tokens", "input"),
|
||||
("completion_tokens", "output"),
|
||||
("cache_creation_input_tokens", "cache_write"),
|
||||
("cache_read_input_tokens", "cache_read"),
|
||||
# langchain-ibm https://pypi.org/project/langchain-ibm/
|
||||
("input_token_count", "input"),
|
||||
("generated_token_count", "output"),
|
||||
]
|
||||
|
||||
parsed_usage = {}
|
||||
for model_key, type_key in conversion_list:
|
||||
if model_key in usage:
|
||||
captured_count = usage[model_key]
|
||||
final_count = (
|
||||
sum(captured_count)
|
||||
if isinstance(captured_count, list)
|
||||
else captured_count
|
||||
) # For Bedrock, the token count is a list when streamed
|
||||
|
||||
parsed_usage[type_key] = final_count
|
||||
|
||||
# Caching (OpenAI & langchain 0.3.9+)
|
||||
if "input_token_details" in usage and isinstance(
|
||||
usage["input_token_details"], dict
|
||||
):
|
||||
parsed_usage["cache_write"] = usage["input_token_details"].get("cache_creation")
|
||||
parsed_usage["cache_read"] = usage["input_token_details"].get("cache_read")
|
||||
|
||||
# Reasoning (OpenAI & langchain 0.3.9+)
|
||||
if "output_token_details" in usage and isinstance(
|
||||
usage["output_token_details"], dict
|
||||
):
|
||||
parsed_usage["reasoning"] = usage["output_token_details"].get("reasoning")
|
||||
|
||||
field_mapping = {
|
||||
"input": "input_tokens",
|
||||
"output": "output_tokens",
|
||||
"cache_write": "cache_write_tokens",
|
||||
"cache_read": "cache_read_tokens",
|
||||
"reasoning": "reasoning_tokens",
|
||||
}
|
||||
return ModelUsage(
|
||||
**{
|
||||
dataclass_key: parsed_usage.get(mapped_key) or 0
|
||||
for mapped_key, dataclass_key in field_mapping.items()
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _parse_usage(response: LLMResult) -> ModelUsage:
|
||||
# langchain-anthropic uses the usage field
|
||||
llm_usage_keys = ["token_usage", "usage"]
|
||||
llm_usage: ModelUsage = ModelUsage(
|
||||
input_tokens=None,
|
||||
output_tokens=None,
|
||||
cache_write_tokens=None,
|
||||
cache_read_tokens=None,
|
||||
reasoning_tokens=None,
|
||||
)
|
||||
|
||||
if response.llm_output is not None:
|
||||
for key in llm_usage_keys:
|
||||
if response.llm_output.get(key):
|
||||
llm_usage = _parse_usage_model(response.llm_output[key])
|
||||
break
|
||||
|
||||
if hasattr(response, "generations"):
|
||||
for generation in response.generations:
|
||||
if "usage" in generation:
|
||||
llm_usage = _parse_usage_model(generation["usage"])
|
||||
break
|
||||
|
||||
for generation_chunk in generation:
|
||||
if generation_chunk.generation_info and (
|
||||
"usage_metadata" in generation_chunk.generation_info
|
||||
):
|
||||
llm_usage = _parse_usage_model(
|
||||
generation_chunk.generation_info["usage_metadata"]
|
||||
)
|
||||
break
|
||||
|
||||
message_chunk = getattr(generation_chunk, "message", {})
|
||||
response_metadata = getattr(message_chunk, "response_metadata", {})
|
||||
|
||||
bedrock_anthropic_usage = (
|
||||
response_metadata.get("usage", None) # for Bedrock-Anthropic
|
||||
if isinstance(response_metadata, dict)
|
||||
else None
|
||||
)
|
||||
bedrock_titan_usage = (
|
||||
response_metadata.get(
|
||||
"amazon-bedrock-invocationMetrics", None
|
||||
) # for Bedrock-Titan
|
||||
if isinstance(response_metadata, dict)
|
||||
else None
|
||||
)
|
||||
ollama_usage = getattr(
|
||||
message_chunk, "usage_metadata", None
|
||||
) # for Ollama
|
||||
|
||||
chunk_usage = (
|
||||
bedrock_anthropic_usage or bedrock_titan_usage or ollama_usage
|
||||
)
|
||||
if chunk_usage:
|
||||
llm_usage = _parse_usage_model(chunk_usage)
|
||||
break
|
||||
|
||||
return llm_usage
|
||||
|
||||
|
||||
def _get_http_status(error: BaseException) -> int:
|
||||
# OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py
|
||||
# Anthropic: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
|
||||
# Google: https://github.com/googleapis/python-api-core/blob/main/google/api_core/exceptions.py
|
||||
status_code = getattr(error, "status_code", getattr(error, "code", 0))
|
||||
return status_code
|
||||
|
||||
|
||||
def _get_langchain_run_name(
|
||||
serialized: Optional[Dict[str, Any]], **kwargs: Any
|
||||
) -> Optional[str]:
|
||||
"""Retrieve the name of a serialized LangChain runnable.
|
||||
|
||||
The prioritization for the determination of the run name is as follows:
|
||||
- The value assigned to the "name" key in `kwargs`.
|
||||
- The value assigned to the "name" key in `serialized`.
|
||||
- The last entry of the value assigned to the "id" key in `serialized`.
|
||||
- "<unknown>".
|
||||
|
||||
Args:
|
||||
serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
|
||||
**kwargs (Any): Additional keyword arguments, potentially including the 'name' override.
|
||||
|
||||
Returns:
|
||||
str: The determined name of the Langchain runnable.
|
||||
"""
|
||||
if "name" in kwargs and kwargs["name"] is not None:
|
||||
return kwargs["name"]
|
||||
if serialized is None:
|
||||
return None
|
||||
try:
|
||||
return serialized["name"]
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
try:
|
||||
return serialized["id"][-1]
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _stringify_exception(exception: BaseException) -> str:
|
||||
description = str(exception)
|
||||
if description:
|
||||
return f"{exception.__class__.__name__}: {description}"
|
||||
return exception.__class__.__name__
|
||||
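For illustration, a quick sketch of the resolution order documented above; the serialized payloads here are invented examples, not real LangChain output:

# Illustrative only; these dictionaries are made-up examples.
_get_langchain_run_name({"name": "my_chain"})                    # -> "my_chain" (from `serialized`)
_get_langchain_run_name({"name": "my_chain"}, name="override")   # -> "override" (kwargs wins)
_get_langchain_run_name({"id": ["langchain", "chat_models", "ChatOpenAI"]})  # -> "ChatOpenAI" (last "id" entry)
_get_langchain_run_name(None)                                    # -> None (no serialized data)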
@@ -0,0 +1,5 @@
from .openai import OpenAI
from .openai_async import AsyncOpenAI
from .openai_providers import AsyncAzureOpenAI, AzureOpenAI

__all__ = ["OpenAI", "AsyncOpenAI", "AzureOpenAI", "AsyncAzureOpenAI"]
@@ -0,0 +1,636 @@
import time
import uuid
from typing import Any, Dict, List, Optional

try:
    import openai
except ImportError:
    raise ModuleNotFoundError(
        "Please install the OpenAI SDK to use this feature: 'pip install openai'"
    )

from posthog.ai.utils import (
    call_llm_and_track_usage,
    get_model_params,
    with_privacy_mode,
)
from posthog.client import Client as PostHogClient


class OpenAI(openai.OpenAI):
    """
    A wrapper around the OpenAI SDK that automatically sends LLM usage events to PostHog.
    """

    _ph_client: PostHogClient

    def __init__(self, posthog_client: PostHogClient, **kwargs):
        """
        Args:
            api_key: OpenAI API key.
            posthog_client: If provided, events will be captured via this client instead
                of the global posthog.
            **openai_config: Any additional keyword args to set on openai (e.g. organization="xxx").
        """
        super().__init__(**kwargs)
        self._ph_client = posthog_client

        # Store original objects after parent initialization (only if they exist)
        self._original_chat = getattr(self, "chat", None)
        self._original_embeddings = getattr(self, "embeddings", None)
        self._original_beta = getattr(self, "beta", None)
        self._original_responses = getattr(self, "responses", None)

        # Replace with wrapped versions (only if originals exist)
        if self._original_chat is not None:
            self.chat = WrappedChat(self, self._original_chat)

        if self._original_embeddings is not None:
            self.embeddings = WrappedEmbeddings(self, self._original_embeddings)

        if self._original_beta is not None:
            self.beta = WrappedBeta(self, self._original_beta)

        if self._original_responses is not None:
            self.responses = WrappedResponses(self, self._original_responses)

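A minimal usage sketch of this wrapper, assuming placeholder PostHog and OpenAI credentials and an example model name; the posthog_* keyword arguments are consumed by the wrapper (they are not forwarded to OpenAI) and end up on the captured $ai_generation event:

from posthog import Posthog
from posthog.ai.openai import OpenAI

# Placeholder credentials and host, purely for illustration.
posthog = Posthog("phc_project_api_key", host="https://us.i.posthog.com")
client = OpenAI(posthog_client=posthog, api_key="sk-...")

completion = client.chat.completions.create(
    model="gpt-4o-mini",  # example model name
    messages=[{"role": "user", "content": "Say hello"}],
    posthog_distinct_id="user_123",              # associates the event with a user
    posthog_properties={"feature": "greeting"},  # extra event properties
)
print(completion.choices[0].message.content)
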
class WrappedResponses:
|
||||
"""Wrapper for OpenAI responses that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_responses):
|
||||
self._client = client
|
||||
self._original = original_responses
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original responses object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
def create(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
if kwargs.get("stream", False):
|
||||
return self._create_streaming(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return call_llm_and_track_usage(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.create,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _create_streaming(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
**kwargs: Any,
|
||||
):
|
||||
start_time = time.time()
|
||||
usage_stats: Dict[str, int] = {}
|
||||
final_content = []
|
||||
response = self._original.create(**kwargs)
|
||||
|
||||
def generator():
|
||||
nonlocal usage_stats
|
||||
nonlocal final_content # noqa: F824
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "type") and chunk.type == "response.completed":
|
||||
res = chunk.response
|
||||
if res.output and len(res.output) > 0:
|
||||
final_content.append(res.output[0])
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_stats = {
|
||||
k: getattr(chunk.usage, k, 0)
|
||||
for k in [
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"total_tokens",
|
||||
]
|
||||
}
|
||||
|
||||
# Add support for cached tokens
|
||||
if hasattr(chunk.usage, "output_tokens_details") and hasattr(
|
||||
chunk.usage.output_tokens_details, "reasoning_tokens"
|
||||
):
|
||||
usage_stats["reasoning_tokens"] = (
|
||||
chunk.usage.output_tokens_details.reasoning_tokens
|
||||
)
|
||||
|
||||
if hasattr(chunk.usage, "input_tokens_details") and hasattr(
|
||||
chunk.usage.input_tokens_details, "cached_tokens"
|
||||
):
|
||||
usage_stats["cache_read_input_tokens"] = (
|
||||
chunk.usage.input_tokens_details.cached_tokens
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
finally:
|
||||
end_time = time.time()
|
||||
latency = end_time - start_time
|
||||
output = final_content
|
||||
self._capture_streaming_event(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
kwargs,
|
||||
usage_stats,
|
||||
latency,
|
||||
output,
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
def _capture_streaming_event(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
usage_stats: Dict[str, int],
|
||||
latency: float,
|
||||
output: Any,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
event_properties = {
|
||||
"$ai_provider": "openai",
|
||||
"$ai_model": kwargs.get("model"),
|
||||
"$ai_model_parameters": get_model_params(kwargs),
|
||||
"$ai_input": with_privacy_mode(
|
||||
self._client._ph_client, posthog_privacy_mode, kwargs.get("input")
|
||||
),
|
||||
"$ai_output_choices": with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
output,
|
||||
),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_input_tokens": usage_stats.get("input_tokens", 0),
|
||||
"$ai_output_tokens": usage_stats.get("output_tokens", 0),
|
||||
"$ai_cache_read_input_tokens": usage_stats.get(
|
||||
"cache_read_input_tokens", 0
|
||||
),
|
||||
"$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0),
|
||||
"$ai_latency": latency,
|
||||
"$ai_trace_id": posthog_trace_id,
|
||||
"$ai_base_url": str(self._client.base_url),
|
||||
**(posthog_properties or {}),
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
event_properties["$ai_tools"] = with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
tool_calls,
|
||||
)
|
||||
|
||||
if posthog_distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
if hasattr(self._client._ph_client, "capture"):
|
||||
self._client._ph_client.capture(
|
||||
distinct_id=posthog_distinct_id or posthog_trace_id,
|
||||
event="$ai_generation",
|
||||
properties=event_properties,
|
||||
groups=posthog_groups,
|
||||
)
|
||||
|
||||
def parse(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Parse structured output using OpenAI's 'responses.parse' method, but also track usage in PostHog.
|
||||
|
||||
Args:
|
||||
posthog_distinct_id: Optional ID to associate with the usage event.
|
||||
posthog_trace_id: Optional trace UUID for linking events.
|
||||
posthog_properties: Optional dictionary of extra properties to include in the event.
|
||||
posthog_privacy_mode: Whether to anonymize the input and output.
|
||||
posthog_groups: Optional dictionary of groups to associate with the event.
|
||||
**kwargs: Any additional parameters for the OpenAI Responses Parse API.
|
||||
|
||||
Returns:
|
||||
The response from OpenAI's responses.parse call.
|
||||
"""
|
||||
return call_llm_and_track_usage(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.parse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class WrappedChat:
|
||||
"""Wrapper for OpenAI chat that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_chat):
|
||||
self._client = client
|
||||
self._original = original_chat
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original chat object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
@property
|
||||
def completions(self):
|
||||
return WrappedCompletions(self._client, self._original.completions)
|
||||
|
||||
|
||||
class WrappedCompletions:
|
||||
"""Wrapper for OpenAI chat completions that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_completions):
|
||||
self._client = client
|
||||
self._original = original_completions
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original completions object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
def create(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
if kwargs.get("stream", False):
|
||||
return self._create_streaming(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return call_llm_and_track_usage(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.create,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _create_streaming(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
**kwargs: Any,
|
||||
):
|
||||
start_time = time.time()
|
||||
usage_stats: Dict[str, int] = {}
|
||||
accumulated_content = []
|
||||
accumulated_tools = {}
|
||||
if "stream_options" not in kwargs:
|
||||
kwargs["stream_options"] = {}
|
||||
kwargs["stream_options"]["include_usage"] = True
|
||||
response = self._original.create(**kwargs)
|
||||
|
||||
def generator():
|
||||
nonlocal usage_stats
|
||||
nonlocal accumulated_content # noqa: F824
|
||||
nonlocal accumulated_tools # noqa: F824
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_stats = {
|
||||
k: getattr(chunk.usage, k, 0)
|
||||
for k in [
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]
|
||||
}
|
||||
|
||||
# Add support for cached tokens
|
||||
if hasattr(chunk.usage, "prompt_tokens_details") and hasattr(
|
||||
chunk.usage.prompt_tokens_details, "cached_tokens"
|
||||
):
|
||||
usage_stats["cache_read_input_tokens"] = (
|
||||
chunk.usage.prompt_tokens_details.cached_tokens
|
||||
)
|
||||
|
||||
if hasattr(chunk.usage, "output_tokens_details") and hasattr(
|
||||
chunk.usage.output_tokens_details, "reasoning_tokens"
|
||||
):
|
||||
usage_stats["reasoning_tokens"] = (
|
||||
chunk.usage.output_tokens_details.reasoning_tokens
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(chunk, "choices")
|
||||
and chunk.choices
|
||||
and len(chunk.choices) > 0
|
||||
):
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
if content:
|
||||
accumulated_content.append(content)
|
||||
|
||||
# Process tool calls
|
||||
tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
index = tool_call.index
|
||||
if index not in accumulated_tools:
|
||||
accumulated_tools[index] = tool_call
|
||||
else:
|
||||
# Append arguments for existing tool calls
|
||||
if hasattr(tool_call, "function") and hasattr(
|
||||
tool_call.function, "arguments"
|
||||
):
|
||||
accumulated_tools[
|
||||
index
|
||||
].function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
finally:
|
||||
end_time = time.time()
|
||||
latency = end_time - start_time
|
||||
output = "".join(accumulated_content)
|
||||
tools = list(accumulated_tools.values()) if accumulated_tools else None
|
||||
self._capture_streaming_event(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
kwargs,
|
||||
usage_stats,
|
||||
latency,
|
||||
output,
|
||||
tools,
|
||||
)
|
||||
|
||||
return generator()
|
||||
|
||||
def _capture_streaming_event(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
usage_stats: Dict[str, int],
|
||||
latency: float,
|
||||
output: Any,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
event_properties = {
|
||||
"$ai_provider": "openai",
|
||||
"$ai_model": kwargs.get("model"),
|
||||
"$ai_model_parameters": get_model_params(kwargs),
|
||||
"$ai_input": with_privacy_mode(
|
||||
self._client._ph_client, posthog_privacy_mode, kwargs.get("messages")
|
||||
),
|
||||
"$ai_output_choices": with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
[{"content": output, "role": "assistant"}],
|
||||
),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_input_tokens": usage_stats.get("prompt_tokens", 0),
|
||||
"$ai_output_tokens": usage_stats.get("completion_tokens", 0),
|
||||
"$ai_cache_read_input_tokens": usage_stats.get(
|
||||
"cache_read_input_tokens", 0
|
||||
),
|
||||
"$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0),
|
||||
"$ai_latency": latency,
|
||||
"$ai_trace_id": posthog_trace_id,
|
||||
"$ai_base_url": str(self._client.base_url),
|
||||
**(posthog_properties or {}),
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
event_properties["$ai_tools"] = with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
tool_calls,
|
||||
)
|
||||
|
||||
if posthog_distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
if hasattr(self._client._ph_client, "capture"):
|
||||
self._client._ph_client.capture(
|
||||
distinct_id=posthog_distinct_id or posthog_trace_id,
|
||||
event="$ai_generation",
|
||||
properties=event_properties,
|
||||
groups=posthog_groups,
|
||||
)
|
||||
|
||||
|
||||
class WrappedEmbeddings:
|
||||
"""Wrapper for OpenAI embeddings that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_embeddings):
|
||||
self._client = client
|
||||
self._original = original_embeddings
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original embeddings object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
def create(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Create an embedding using OpenAI's 'embeddings.create' method, but also track usage in PostHog.
|
||||
|
||||
Args:
|
||||
posthog_distinct_id: Optional ID to associate with the usage event.
|
||||
posthog_trace_id: Optional trace UUID for linking events.
|
||||
posthog_properties: Optional dictionary of extra properties to include in the event.
|
||||
posthog_privacy_mode: Whether to anonymize the input and output.
|
||||
posthog_groups: Optional dictionary of groups to associate with the event.
|
||||
**kwargs: Any additional parameters for the OpenAI Embeddings API.
|
||||
|
||||
Returns:
|
||||
The response from OpenAI's embeddings.create call.
|
||||
"""
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
start_time = time.time()
|
||||
response = self._original.create(**kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
# Extract usage statistics if available
|
||||
usage_stats = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage_stats = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
|
||||
latency = end_time - start_time
|
||||
|
||||
# Build the event properties
|
||||
event_properties = {
|
||||
"$ai_provider": "openai",
|
||||
"$ai_model": kwargs.get("model"),
|
||||
"$ai_input": with_privacy_mode(
|
||||
self._client._ph_client, posthog_privacy_mode, kwargs.get("input")
|
||||
),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_input_tokens": usage_stats.get("prompt_tokens", 0),
|
||||
"$ai_latency": latency,
|
||||
"$ai_trace_id": posthog_trace_id,
|
||||
"$ai_base_url": str(self._client.base_url),
|
||||
**(posthog_properties or {}),
|
||||
}
|
||||
|
||||
if posthog_distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
# Send capture event for embeddings
|
||||
if hasattr(self._client._ph_client, "capture"):
|
||||
self._client._ph_client.capture(
|
||||
distinct_id=posthog_distinct_id or posthog_trace_id,
|
||||
event="$ai_embedding",
|
||||
properties=event_properties,
|
||||
groups=posthog_groups,
|
||||
)
|
||||
|
||||
return response
|
||||
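A minimal sketch of the embeddings path above, which captures an $ai_embedding event with input-token usage and latency; the model name and distinct ID are placeholders, and `client` is assumed to be the wrapped OpenAI client from the earlier sketch:

embedding_response = client.embeddings.create(
    model="text-embedding-3-small",  # example model name
    input="PostHog LLM observability",
    posthog_distinct_id="user_123",
)
vector = embedding_response.data[0].embedding  # the embedding itself is returned unchanged
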
|
||||
|
||||
class WrappedBeta:
|
||||
"""Wrapper for OpenAI beta features that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_beta):
|
||||
self._client = client
|
||||
self._original = original_beta
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original beta object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
@property
|
||||
def chat(self):
|
||||
return WrappedBetaChat(self._client, self._original.chat)
|
||||
|
||||
|
||||
class WrappedBetaChat:
|
||||
"""Wrapper for OpenAI beta chat that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_beta_chat):
|
||||
self._client = client
|
||||
self._original = original_beta_chat
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original beta chat object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
@property
|
||||
def completions(self):
|
||||
return WrappedBetaCompletions(self._client, self._original.completions)
|
||||
|
||||
|
||||
class WrappedBetaCompletions:
|
||||
"""Wrapper for OpenAI beta chat completions that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: OpenAI, original_beta_completions):
|
||||
self._client = client
|
||||
self._original = original_beta_completions
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original beta completions object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
def parse(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
return call_llm_and_track_usage(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.parse,
|
||||
**kwargs,
|
||||
)
|
||||
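A hedged sketch of consuming a wrapped streaming completion from the file above: the wrapper injects stream_options={"include_usage": True}, accumulates content deltas and tool-call arguments, and captures the $ai_generation event in the generator's finally block once the stream is exhausted. The model name and distinct ID are placeholders, and `client` is the wrapped client from the earlier sketch:

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # example model name
    messages=[{"role": "user", "content": "Stream a short haiku"}],
    stream=True,
    posthog_distinct_id="user_123",
)
for chunk in stream:
    # The final usage-only chunk has no choices, so guard before reading the delta.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
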
@@ -0,0 +1,639 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install the OpenAI SDK to use this feature: 'pip install openai'"
|
||||
)
|
||||
|
||||
from posthog.ai.utils import (
|
||||
call_llm_and_track_usage_async,
|
||||
get_model_params,
|
||||
with_privacy_mode,
|
||||
)
|
||||
from posthog.client import Client as PostHogClient
|
||||
|
||||
|
||||
class AsyncOpenAI(openai.AsyncOpenAI):
|
||||
"""
|
||||
An async wrapper around the OpenAI SDK that automatically sends LLM usage events to PostHog.
|
||||
"""
|
||||
|
||||
_ph_client: PostHogClient
|
||||
|
||||
def __init__(self, posthog_client: PostHogClient, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
api_key: OpenAI API key.
|
||||
posthog_client: If provided, events will be captured via this client instead
|
||||
of the global posthog.
|
||||
**openai_config: Any additional keyword args to set on openai (e.g. organization="xxx").
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._ph_client = posthog_client
|
||||
|
||||
# Store original objects after parent initialization (only if they exist)
|
||||
self._original_chat = getattr(self, "chat", None)
|
||||
self._original_embeddings = getattr(self, "embeddings", None)
|
||||
self._original_beta = getattr(self, "beta", None)
|
||||
self._original_responses = getattr(self, "responses", None)
|
||||
|
||||
# Replace with wrapped versions (only if originals exist)
|
||||
if self._original_chat is not None:
|
||||
self.chat = WrappedChat(self, self._original_chat)
|
||||
|
||||
if self._original_embeddings is not None:
|
||||
self.embeddings = WrappedEmbeddings(self, self._original_embeddings)
|
||||
|
||||
if self._original_beta is not None:
|
||||
self.beta = WrappedBeta(self, self._original_beta)
|
||||
|
||||
if self._original_responses is not None:
|
||||
self.responses = WrappedResponses(self, self._original_responses)
|
||||
|
||||
|
||||
class WrappedResponses:
|
||||
"""Async wrapper for OpenAI responses that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_responses):
|
||||
self._client = client
|
||||
self._original = original_responses
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original responses object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
if kwargs.get("stream", False):
|
||||
return await self._create_streaming(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return await call_llm_and_track_usage_async(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.create,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _create_streaming(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
**kwargs: Any,
|
||||
):
|
||||
start_time = time.time()
|
||||
usage_stats: Dict[str, int] = {}
|
||||
final_content = []
|
||||
response = await self._original.create(**kwargs)
|
||||
|
||||
async def async_generator():
|
||||
nonlocal usage_stats
|
||||
nonlocal final_content # noqa: F824
|
||||
|
||||
try:
|
||||
async for chunk in response:
|
||||
if hasattr(chunk, "type") and chunk.type == "response.completed":
|
||||
res = chunk.response
|
||||
if res.output and len(res.output) > 0:
|
||||
final_content.append(res.output[0])
|
||||
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_stats = {
|
||||
k: getattr(chunk.usage, k, 0)
|
||||
for k in [
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"total_tokens",
|
||||
]
|
||||
}
|
||||
|
||||
# Add support for cached tokens
|
||||
if hasattr(chunk.usage, "output_tokens_details") and hasattr(
|
||||
chunk.usage.output_tokens_details, "reasoning_tokens"
|
||||
):
|
||||
usage_stats["reasoning_tokens"] = (
|
||||
chunk.usage.output_tokens_details.reasoning_tokens
|
||||
)
|
||||
|
||||
if hasattr(chunk.usage, "input_tokens_details") and hasattr(
|
||||
chunk.usage.input_tokens_details, "cached_tokens"
|
||||
):
|
||||
usage_stats["cache_read_input_tokens"] = (
|
||||
chunk.usage.input_tokens_details.cached_tokens
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
finally:
|
||||
end_time = time.time()
|
||||
latency = end_time - start_time
|
||||
output = final_content
|
||||
await self._capture_streaming_event(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
kwargs,
|
||||
usage_stats,
|
||||
latency,
|
||||
output,
|
||||
)
|
||||
|
||||
return async_generator()
|
||||
|
||||
async def _capture_streaming_event(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
usage_stats: Dict[str, int],
|
||||
latency: float,
|
||||
output: Any,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
event_properties = {
|
||||
"$ai_provider": "openai",
|
||||
"$ai_model": kwargs.get("model"),
|
||||
"$ai_model_parameters": get_model_params(kwargs),
|
||||
"$ai_input": with_privacy_mode(
|
||||
self._client._ph_client, posthog_privacy_mode, kwargs.get("input")
|
||||
),
|
||||
"$ai_output_choices": with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
output,
|
||||
),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_input_tokens": usage_stats.get("input_tokens", 0),
|
||||
"$ai_output_tokens": usage_stats.get("output_tokens", 0),
|
||||
"$ai_cache_read_input_tokens": usage_stats.get(
|
||||
"cache_read_input_tokens", 0
|
||||
),
|
||||
"$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0),
|
||||
"$ai_latency": latency,
|
||||
"$ai_trace_id": posthog_trace_id,
|
||||
"$ai_base_url": str(self._client.base_url),
|
||||
**(posthog_properties or {}),
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
event_properties["$ai_tools"] = with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
tool_calls,
|
||||
)
|
||||
|
||||
if posthog_distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
if hasattr(self._client._ph_client, "capture"):
|
||||
self._client._ph_client.capture(
|
||||
distinct_id=posthog_distinct_id or posthog_trace_id,
|
||||
event="$ai_generation",
|
||||
properties=event_properties,
|
||||
groups=posthog_groups,
|
||||
)
|
||||
|
||||
async def parse(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Parse structured output using OpenAI's 'responses.parse' method, but also track usage in PostHog.
|
||||
|
||||
Args:
|
||||
posthog_distinct_id: Optional ID to associate with the usage event.
|
||||
posthog_trace_id: Optional trace UUID for linking events.
|
||||
posthog_properties: Optional dictionary of extra properties to include in the event.
|
||||
posthog_privacy_mode: Whether to anonymize the input and output.
|
||||
posthog_groups: Optional dictionary of groups to associate with the event.
|
||||
**kwargs: Any additional parameters for the OpenAI Responses Parse API.
|
||||
|
||||
Returns:
|
||||
The response from OpenAI's responses.parse call.
|
||||
"""
|
||||
return await call_llm_and_track_usage_async(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.parse,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class WrappedChat:
|
||||
"""Async wrapper for OpenAI chat that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_chat):
|
||||
self._client = client
|
||||
self._original = original_chat
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original chat object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
@property
|
||||
def completions(self):
|
||||
return WrappedCompletions(self._client, self._original.completions)
|
||||
|
||||
|
||||
class WrappedCompletions:
|
||||
"""Async wrapper for OpenAI chat completions that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_completions):
|
||||
self._client = client
|
||||
self._original = original_completions
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original completions object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
# If streaming, handle streaming specifically
|
||||
if kwargs.get("stream", False):
|
||||
return await self._create_streaming(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
response = await call_llm_and_track_usage_async(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.create,
|
||||
**kwargs,
|
||||
)
|
||||
return response
|
||||
|
||||
async def _create_streaming(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
**kwargs: Any,
|
||||
):
|
||||
start_time = time.time()
|
||||
usage_stats: Dict[str, int] = {}
|
||||
accumulated_content = []
|
||||
accumulated_tools = {}
|
||||
|
||||
if "stream_options" not in kwargs:
|
||||
kwargs["stream_options"] = {}
|
||||
kwargs["stream_options"]["include_usage"] = True
|
||||
response = await self._original.create(**kwargs)
|
||||
|
||||
async def async_generator():
|
||||
nonlocal usage_stats
|
||||
nonlocal accumulated_content # noqa: F824
|
||||
nonlocal accumulated_tools # noqa: F824
|
||||
|
||||
try:
|
||||
async for chunk in response:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_stats = {
|
||||
k: getattr(chunk.usage, k, 0)
|
||||
for k in [
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]
|
||||
}
|
||||
|
||||
# Add support for cached tokens
|
||||
if hasattr(chunk.usage, "prompt_tokens_details") and hasattr(
|
||||
chunk.usage.prompt_tokens_details, "cached_tokens"
|
||||
):
|
||||
usage_stats["cache_read_input_tokens"] = (
|
||||
chunk.usage.prompt_tokens_details.cached_tokens
|
||||
)
|
||||
|
||||
if hasattr(chunk.usage, "output_tokens_details") and hasattr(
|
||||
chunk.usage.output_tokens_details, "reasoning_tokens"
|
||||
):
|
||||
usage_stats["reasoning_tokens"] = (
|
||||
chunk.usage.output_tokens_details.reasoning_tokens
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(chunk, "choices")
|
||||
and chunk.choices
|
||||
and len(chunk.choices) > 0
|
||||
):
|
||||
if chunk.choices[0].delta and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
if content:
|
||||
accumulated_content.append(content)
|
||||
|
||||
# Process tool calls
|
||||
tool_calls = getattr(chunk.choices[0].delta, "tool_calls", None)
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
index = tool_call.index
|
||||
if index not in accumulated_tools:
|
||||
accumulated_tools[index] = tool_call
|
||||
else:
|
||||
# Append arguments for existing tool calls
|
||||
if hasattr(tool_call, "function") and hasattr(
|
||||
tool_call.function, "arguments"
|
||||
):
|
||||
accumulated_tools[
|
||||
index
|
||||
].function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
finally:
|
||||
end_time = time.time()
|
||||
latency = end_time - start_time
|
||||
output = "".join(accumulated_content)
|
||||
tools = list(accumulated_tools.values()) if accumulated_tools else None
|
||||
await self._capture_streaming_event(
|
||||
posthog_distinct_id,
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
kwargs,
|
||||
usage_stats,
|
||||
latency,
|
||||
output,
|
||||
tools,
|
||||
)
|
||||
|
||||
return async_generator()
|
||||
|
||||
async def _capture_streaming_event(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str],
|
||||
posthog_trace_id: Optional[str],
|
||||
posthog_properties: Optional[Dict[str, Any]],
|
||||
posthog_privacy_mode: bool,
|
||||
posthog_groups: Optional[Dict[str, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
usage_stats: Dict[str, int],
|
||||
latency: float,
|
||||
output: Any,
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None,
|
||||
):
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
event_properties = {
|
||||
"$ai_provider": "openai",
|
||||
"$ai_model": kwargs.get("model"),
|
||||
"$ai_model_parameters": get_model_params(kwargs),
|
||||
"$ai_input": with_privacy_mode(
|
||||
self._client._ph_client, posthog_privacy_mode, kwargs.get("messages")
|
||||
),
|
||||
"$ai_output_choices": with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
[{"content": output, "role": "assistant"}],
|
||||
),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_input_tokens": usage_stats.get("prompt_tokens", 0),
|
||||
"$ai_output_tokens": usage_stats.get("completion_tokens", 0),
|
||||
"$ai_cache_read_input_tokens": usage_stats.get(
|
||||
"cache_read_input_tokens", 0
|
||||
),
|
||||
"$ai_reasoning_tokens": usage_stats.get("reasoning_tokens", 0),
|
||||
"$ai_latency": latency,
|
||||
"$ai_trace_id": posthog_trace_id,
|
||||
"$ai_base_url": str(self._client.base_url),
|
||||
**(posthog_properties or {}),
|
||||
}
|
||||
|
||||
if tool_calls:
|
||||
event_properties["$ai_tools"] = with_privacy_mode(
|
||||
self._client._ph_client,
|
||||
posthog_privacy_mode,
|
||||
tool_calls,
|
||||
)
|
||||
|
||||
if posthog_distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
if hasattr(self._client._ph_client, "capture"):
|
||||
self._client._ph_client.capture(
|
||||
distinct_id=posthog_distinct_id or posthog_trace_id,
|
||||
event="$ai_generation",
|
||||
properties=event_properties,
|
||||
groups=posthog_groups,
|
||||
)
|
||||
|
||||
|
||||
class WrappedEmbeddings:
|
||||
"""Async wrapper for OpenAI embeddings that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_embeddings):
|
||||
self._client = client
|
||||
self._original = original_embeddings
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original embeddings object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""
|
||||
Create an embedding using OpenAI's 'embeddings.create' method, but also track usage in PostHog.
|
||||
|
||||
Args:
|
||||
posthog_distinct_id: Optional ID to associate with the usage event.
|
||||
posthog_trace_id: Optional trace UUID for linking events.
|
||||
posthog_properties: Optional dictionary of extra properties to include in the event.
|
||||
posthog_privacy_mode: Whether to anonymize the input and output.
|
||||
posthog_groups: Optional dictionary of groups to associate with the event.
|
||||
**kwargs: Any additional parameters for the OpenAI Embeddings API.
|
||||
|
||||
Returns:
|
||||
The response from OpenAI's embeddings.create call.
|
||||
"""
|
||||
if posthog_trace_id is None:
|
||||
posthog_trace_id = str(uuid.uuid4())
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._original.create(**kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
# Extract usage statistics if available
|
||||
usage_stats = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage_stats = {
|
||||
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||
}
|
||||
|
||||
latency = end_time - start_time
|
||||
|
||||
# Build the event properties
|
||||
event_properties = {
|
||||
"$ai_provider": "openai",
|
||||
"$ai_model": kwargs.get("model"),
|
||||
"$ai_input": with_privacy_mode(
|
||||
self._client._ph_client, posthog_privacy_mode, kwargs.get("input")
|
||||
),
|
||||
"$ai_http_status": 200,
|
||||
"$ai_input_tokens": usage_stats.get("prompt_tokens", 0),
|
||||
"$ai_latency": latency,
|
||||
"$ai_trace_id": posthog_trace_id,
|
||||
"$ai_base_url": str(self._client.base_url),
|
||||
**(posthog_properties or {}),
|
||||
}
|
||||
|
||||
if posthog_distinct_id is None:
|
||||
event_properties["$process_person_profile"] = False
|
||||
|
||||
# Send capture event for embeddings
|
||||
if hasattr(self._client._ph_client, "capture"):
|
||||
self._client._ph_client.capture(
|
||||
distinct_id=posthog_distinct_id or posthog_trace_id,
|
||||
event="$ai_embedding",
|
||||
properties=event_properties,
|
||||
groups=posthog_groups,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class WrappedBeta:
|
||||
"""Async wrapper for OpenAI beta features that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_beta):
|
||||
self._client = client
|
||||
self._original = original_beta
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original beta object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
@property
|
||||
def chat(self):
|
||||
return WrappedBetaChat(self._client, self._original.chat)
|
||||
|
||||
|
||||
class WrappedBetaChat:
|
||||
"""Async wrapper for OpenAI beta chat that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_beta_chat):
|
||||
self._client = client
|
||||
self._original = original_beta_chat
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original beta chat object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
@property
|
||||
def completions(self):
|
||||
return WrappedBetaCompletions(self._client, self._original.completions)
|
||||
|
||||
|
||||
class WrappedBetaCompletions:
|
||||
"""Async wrapper for OpenAI beta chat completions that tracks usage in PostHog."""
|
||||
|
||||
def __init__(self, client: AsyncOpenAI, original_beta_completions):
|
||||
self._client = client
|
||||
self._original = original_beta_completions
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Fallback to original beta completions object for any methods we don't explicitly handle."""
|
||||
return getattr(self._original, name)
|
||||
|
||||
async def parse(
|
||||
self,
|
||||
posthog_distinct_id: Optional[str] = None,
|
||||
posthog_trace_id: Optional[str] = None,
|
||||
posthog_properties: Optional[Dict[str, Any]] = None,
|
||||
posthog_privacy_mode: bool = False,
|
||||
posthog_groups: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
return await call_llm_and_track_usage_async(
|
||||
posthog_distinct_id,
|
||||
self._client._ph_client,
|
||||
"openai",
|
||||
posthog_trace_id,
|
||||
posthog_properties,
|
||||
posthog_privacy_mode,
|
||||
posthog_groups,
|
||||
self._client.base_url,
|
||||
self._original.parse,
|
||||
**kwargs,
|
||||
)
|
||||
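A minimal sketch of the async variant defined above; apart from await, it mirrors the sync wrapper (credentials and model name are placeholders):

import asyncio

from posthog import Posthog
from posthog.ai.openai import AsyncOpenAI


async def main() -> None:
    posthog = Posthog("phc_project_api_key", host="https://us.i.posthog.com")  # placeholders
    client = AsyncOpenAI(posthog_client=posthog, api_key="sk-...")
    response = await client.chat.completions.create(
        model="gpt-4o-mini",  # example model name
        messages=[{"role": "user", "content": "Say hello"}],
        posthog_distinct_id="user_123",
    )
    print(response.choices[0].message.content)


asyncio.run(main())
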
@@ -0,0 +1,95 @@
|
||||
try:
|
||||
import openai
|
||||
except ImportError:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install the Open AI SDK to use this feature: 'pip install openai'"
|
||||
)
|
||||
|
||||
from posthog.ai.openai.openai import (
|
||||
WrappedBeta,
|
||||
WrappedChat,
|
||||
WrappedEmbeddings,
|
||||
WrappedResponses,
|
||||
)
|
||||
from posthog.ai.openai.openai_async import WrappedBeta as AsyncWrappedBeta
|
||||
from posthog.ai.openai.openai_async import WrappedChat as AsyncWrappedChat
|
||||
from posthog.ai.openai.openai_async import WrappedEmbeddings as AsyncWrappedEmbeddings
|
||||
from posthog.ai.openai.openai_async import WrappedResponses as AsyncWrappedResponses
|
||||
from posthog.client import Client as PostHogClient
|
||||
|
||||
|
||||
class AzureOpenAI(openai.AzureOpenAI):
|
||||
"""
|
||||
A wrapper around the Azure OpenAI SDK that automatically sends LLM usage events to PostHog.
|
||||
"""
|
||||
|
||||
_ph_client: PostHogClient
|
||||
|
||||
def __init__(self, posthog_client: PostHogClient, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
api_key: Azure OpenAI API key.
|
||||
posthog_client: If provided, events will be captured via this client instead
|
||||
of the global posthog.
|
||||
**openai_config: Any additional keyword args to set on Azure OpenAI (e.g. azure_endpoint="xxx").
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._ph_client = posthog_client
|
||||
|
||||
# Store original objects after parent initialization (only if they exist)
|
||||
self._original_chat = getattr(self, "chat", None)
|
||||
self._original_embeddings = getattr(self, "embeddings", None)
|
||||
self._original_beta = getattr(self, "beta", None)
|
||||
self._original_responses = getattr(self, "responses", None)
|
||||
|
||||
# Replace with wrapped versions (only if originals exist)
|
||||
if self._original_chat is not None:
|
||||
self.chat = WrappedChat(self, self._original_chat)
|
||||
|
||||
if self._original_embeddings is not None:
|
||||
self.embeddings = WrappedEmbeddings(self, self._original_embeddings)
|
||||
|
||||
if self._original_beta is not None:
|
||||
self.beta = WrappedBeta(self, self._original_beta)
|
||||
|
||||
if self._original_responses is not None:
|
||||
self.responses = WrappedResponses(self, self._original_responses)
|
||||
|
||||
|
||||
class AsyncAzureOpenAI(openai.AsyncAzureOpenAI):
|
||||
"""
|
||||
An async wrapper around the Azure OpenAI SDK that automatically sends LLM usage events to PostHog.
|
||||
"""
|
||||
|
||||
_ph_client: PostHogClient
|
||||
|
||||
def __init__(self, posthog_client: PostHogClient, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
api_key: Azure OpenAI API key.
|
||||
posthog_client: If provided, events will be captured via this client instead
|
||||
of the global posthog.
|
||||
**openai_config: Any additional keyword args to set on Azure OpenAI (e.g. azure_endpoint="xxx").
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._ph_client = posthog_client
|
||||
|
||||
# Store original objects after parent initialization (only if they exist)
|
||||
self._original_chat = getattr(self, "chat", None)
|
||||
self._original_embeddings = getattr(self, "embeddings", None)
|
||||
self._original_beta = getattr(self, "beta", None)
|
||||
self._original_responses = getattr(self, "responses", None)
|
||||
|
||||
# Replace with wrapped versions (only if originals exist)
|
||||
if self._original_chat is not None:
|
||||
self.chat = AsyncWrappedChat(self, self._original_chat)
|
||||
|
||||
if self._original_embeddings is not None:
|
||||
self.embeddings = AsyncWrappedEmbeddings(self, self._original_embeddings)
|
||||
|
||||
if self._original_beta is not None:
|
||||
self.beta = AsyncWrappedBeta(self, self._original_beta)
|
||||
|
||||
# Only add responses if available (newer OpenAI versions)
|
||||
if self._original_responses is not None:
|
||||
self.responses = AsyncWrappedResponses(self, self._original_responses)
|
||||
@@ -0,0 +1,544 @@
import time
import uuid
from typing import Any, Callable, Dict, List, Optional

from httpx import URL

from posthog.client import Client as PostHogClient


def get_model_params(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extracts model parameters from the kwargs dictionary.
    """
    model_params = {}
    for param in [
        "temperature",
        "max_tokens",  # Deprecated field
        "max_completion_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "n",
        "stop",
        "stream",  # OpenAI-specific field
        "streaming",  # Anthropic-specific field
    ]:
        if param in kwargs and kwargs[param] is not None:
            model_params[param] = kwargs[param]
    return model_params

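A small illustration of what get_model_params keeps and drops; the kwargs dict is an invented example:

kwargs = {
    "model": "gpt-4o-mini",        # not in the tracked parameter list, so excluded
    "temperature": 0.2,
    "max_completion_tokens": 256,
    "stop": None,                  # None values are skipped
    "stream": True,
    "messages": [{"role": "user", "content": "hi"}],
}
get_model_params(kwargs)
# -> {"temperature": 0.2, "max_completion_tokens": 256, "stream": True}
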
def get_usage(response, provider: str) -> Dict[str, Any]:
    if provider == "anthropic":
        return {
            "input_tokens": response.usage.input_tokens,
            "output_tokens": response.usage.output_tokens,
            "cache_read_input_tokens": response.usage.cache_read_input_tokens,
            "cache_creation_input_tokens": response.usage.cache_creation_input_tokens,
        }
    elif provider == "openai":
        cached_tokens = 0
        input_tokens = 0
        output_tokens = 0
        reasoning_tokens = 0

        # responses api
        if hasattr(response.usage, "input_tokens"):
            input_tokens = response.usage.input_tokens
        if hasattr(response.usage, "output_tokens"):
            output_tokens = response.usage.output_tokens
        if hasattr(response.usage, "input_tokens_details") and hasattr(
            response.usage.input_tokens_details, "cached_tokens"
        ):
            cached_tokens = response.usage.input_tokens_details.cached_tokens
        if hasattr(response.usage, "output_tokens_details") and hasattr(
            response.usage.output_tokens_details, "reasoning_tokens"
        ):
            reasoning_tokens = response.usage.output_tokens_details.reasoning_tokens

        # chat completions
        if hasattr(response.usage, "prompt_tokens"):
            input_tokens = response.usage.prompt_tokens
        if hasattr(response.usage, "completion_tokens"):
            output_tokens = response.usage.completion_tokens
        if hasattr(response.usage, "prompt_tokens_details") and hasattr(
            response.usage.prompt_tokens_details, "cached_tokens"
        ):
            cached_tokens = response.usage.prompt_tokens_details.cached_tokens

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cache_read_input_tokens": cached_tokens,
            "reasoning_tokens": reasoning_tokens,
        }
    elif provider == "gemini":
        input_tokens = 0
        output_tokens = 0

        if hasattr(response, "usage_metadata") and response.usage_metadata:
            input_tokens = getattr(response.usage_metadata, "prompt_token_count", 0)
            output_tokens = getattr(
                response.usage_metadata, "candidates_token_count", 0
            )

        return {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cache_read_input_tokens": 0,
            "cache_creation_input_tokens": 0,
            "reasoning_tokens": 0,
        }
    return {
        "input_tokens": 0,
        "output_tokens": 0,
        "cache_read_input_tokens": 0,
        "cache_creation_input_tokens": 0,
        "reasoning_tokens": 0,
    }

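For reference, the normalized shape get_usage returns for an OpenAI chat completion, assuming a hypothetical response whose usage reports 12 prompt tokens, 34 completion tokens, and no cached or reasoning tokens:

usage = get_usage(response, provider="openai")
# usage == {
#     "input_tokens": 12,
#     "output_tokens": 34,
#     "cache_read_input_tokens": 0,
#     "reasoning_tokens": 0,
# }
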
def format_response(response, provider: str):
|
||||
"""
|
||||
Format a regular (non-streaming) response.
|
||||
"""
|
||||
output = []
|
||||
if response is None:
|
||||
return output
|
||||
if provider == "anthropic":
|
||||
return format_response_anthropic(response)
|
||||
elif provider == "openai":
|
||||
return format_response_openai(response)
|
||||
elif provider == "gemini":
|
||||
return format_response_gemini(response)
|
||||
return output
|
||||
|
||||
|
||||
def format_response_anthropic(response):
|
||||
output = []
|
||||
for choice in response.content:
|
||||
if choice.text:
|
||||
output.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": choice.text,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def format_response_openai(response):
|
||||
output = []
|
||||
if hasattr(response, "choices"):
|
||||
for choice in response.choices:
|
||||
# Handle Chat Completions response format
|
||||
if hasattr(choice, "message") and choice.message and choice.message.content:
|
||||
output.append(
|
||||
{
|
||||
"content": choice.message.content,
|
||||
"role": choice.message.role,
|
||||
}
|
||||
)
|
||||
# Handle Responses API format
|
||||
if hasattr(response, "output"):
|
||||
for item in response.output:
|
||||
if item.type == "message":
|
||||
# Extract text content from the content list
|
||||
if hasattr(item, "content") and isinstance(item.content, list):
|
||||
for content_item in item.content:
|
||||
if (
|
||||
hasattr(content_item, "type")
|
||||
and content_item.type == "output_text"
|
||||
and hasattr(content_item, "text")
|
||||
):
|
||||
output.append(
|
||||
{
|
||||
"content": content_item.text,
|
||||
"role": item.role,
|
||||
}
|
||||
)
|
||||
elif hasattr(content_item, "text"):
|
||||
output.append(
|
||||
{
|
||||
"content": content_item.text,
|
||||
"role": item.role,
|
||||
}
|
||||
)
|
||||
elif (
|
||||
hasattr(content_item, "type")
|
||||
and content_item.type == "input_image"
|
||||
and hasattr(content_item, "image_url")
|
||||
):
|
||||
output.append(
|
||||
{
|
||||
"content": {
|
||||
"type": "image",
|
||||
"image": content_item.image_url,
|
||||
},
|
||||
"role": item.role,
|
||||
}
|
||||
)
|
||||
else:
|
||||
output.append(
|
||||
{
|
||||
"content": item.content,
|
||||
"role": item.role,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def format_response_gemini(response):
|
||||
output = []
|
||||
if hasattr(response, "candidates") and response.candidates:
|
||||
for candidate in response.candidates:
|
||||
if hasattr(candidate, "content") and candidate.content:
|
||||
content_text = ""
|
||||
if hasattr(candidate.content, "parts") and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "text") and part.text:
|
||||
content_text += part.text
|
||||
if content_text:
|
||||
output.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": content_text,
|
||||
}
|
||||
)
|
||||
elif hasattr(candidate, "text") and candidate.text:
|
||||
output.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": candidate.text,
|
||||
}
|
||||
)
|
||||
elif hasattr(response, "text") and response.text:
|
||||
output.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response.text,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def format_tool_calls(response, provider: str):
|
||||
if provider == "anthropic":
|
||||
if hasattr(response, "tools") and response.tools and len(response.tools) > 0:
|
||||
return response.tools
|
||||
elif provider == "openai":
|
||||
# Handle both Chat Completions and Responses API
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
# Check for tool_calls in message (Chat Completions format)
|
||||
if (
|
||||
hasattr(response.choices[0], "message")
|
||||
and hasattr(response.choices[0].message, "tool_calls")
|
||||
and response.choices[0].message.tool_calls
|
||||
):
|
||||
return response.choices[0].message.tool_calls
|
||||
|
||||
# Check for tool_calls directly in response (Responses API format)
|
||||
if (
|
||||
hasattr(response.choices[0], "tool_calls")
|
||||
and response.choices[0].tool_calls
|
||||
):
|
||||
return response.choices[0].tool_calls
|
||||
return None
|
||||
|
||||
|
||||
def merge_system_prompt(kwargs: Dict[str, Any], provider: str):
    messages: List[Dict[str, Any]] = []

    if provider == "anthropic":
        messages = kwargs.get("messages") or []
        if kwargs.get("system") is None:
            return messages
        return [{"role": "system", "content": kwargs.get("system")}] + messages
    elif provider == "gemini":
        contents = kwargs.get("contents", [])
        if isinstance(contents, str):
            return [{"role": "user", "content": contents}]
        elif isinstance(contents, list):
            formatted = []
            for item in contents:
                if isinstance(item, str):
                    formatted.append({"role": "user", "content": item})
                elif hasattr(item, "text"):
                    formatted.append({"role": "user", "content": item.text})
                else:
                    formatted.append({"role": "user", "content": str(item)})
            return formatted
        else:
            return [{"role": "user", "content": str(contents)}]

    # For OpenAI, handle both Chat Completions and Responses API
    if kwargs.get("messages") is not None:
        messages = list(kwargs.get("messages", []))

    if kwargs.get("input") is not None:
        input_data = kwargs.get("input")
        if isinstance(input_data, list):
            messages.extend(input_data)
        else:
            messages.append({"role": "user", "content": input_data})

    # Check if system prompt is provided as a separate parameter
    if kwargs.get("system") is not None:
        has_system = any(msg.get("role") == "system" for msg in messages)
        if not has_system:
            messages = [{"role": "system", "content": kwargs.get("system")}] + messages

    # For Responses API, add instructions to the system prompt if provided
    if kwargs.get("instructions") is not None:
        # Find the system message if it exists
        system_idx = next(
            (i for i, msg in enumerate(messages) if msg.get("role") == "system"), None
        )

        if system_idx is not None:
            # Append instructions to existing system message
            system_content = messages[system_idx].get("content", "")
            messages[system_idx]["content"] = (
                f"{system_content}\n\n{kwargs.get('instructions')}"
            )
        else:
            # Create a new system message with instructions
            messages = [
                {"role": "system", "content": kwargs.get("instructions")}
            ] + messages

    return messages
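# Minimal usage sketch (added comment, kwargs are hypothetical): merge_system_prompt
# normalizes provider-specific kwargs into one OpenAI-style message list:
#
#   merge_system_prompt(
#       {"system": "Be terse", "messages": [{"role": "user", "content": "Hi"}]},
#       "anthropic",
#   )
#   # -> [{"role": "system", "content": "Be terse"}, {"role": "user", "content": "Hi"}]
#
#   merge_system_prompt({"input": "Hi", "instructions": "Be terse"}, "openai")
#   # -> [{"role": "system", "content": "Be terse"}, {"role": "user", "content": "Hi"}]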
def call_llm_and_track_usage(
    posthog_distinct_id: Optional[str],
    ph_client: PostHogClient,
    provider: str,
    posthog_trace_id: Optional[str],
    posthog_properties: Optional[Dict[str, Any]],
    posthog_privacy_mode: bool,
    posthog_groups: Optional[Dict[str, Any]],
    base_url: URL,
    call_method: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    """
    Common usage-tracking logic for both sync and async calls.
    call_method: the LLM call method (e.g. openai.chat.completions.create)
    """
    start_time = time.time()
    response = None
    error = None
    http_status = 200
    usage: Dict[str, Any] = {}
    error_params: Dict[str, Any] = {}

    try:
        response = call_method(**kwargs)
    except Exception as exc:
        error = exc
        http_status = getattr(
            exc, "status_code", 0
        )  # default to 0 because it's likely an SDK error
        error_params = {
            "$ai_is_error": True,
            "$ai_error": exc.__str__(),
        }
    finally:
        end_time = time.time()
        latency = end_time - start_time

        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        if response and (
            hasattr(response, "usage")
            or (provider == "gemini" and hasattr(response, "usage_metadata"))
        ):
            usage = get_usage(response, provider)

        messages = merge_system_prompt(kwargs, provider)

        event_properties = {
            "$ai_provider": provider,
            "$ai_model": kwargs.get("model"),
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(ph_client, posthog_privacy_mode, messages),
            "$ai_output_choices": with_privacy_mode(
                ph_client, posthog_privacy_mode, format_response(response, provider)
            ),
            "$ai_http_status": http_status,
            "$ai_input_tokens": usage.get("input_tokens", 0),
            "$ai_output_tokens": usage.get("output_tokens", 0),
            "$ai_latency": latency,
            "$ai_trace_id": posthog_trace_id,
            "$ai_base_url": str(base_url),
            **(posthog_properties or {}),
            **(error_params or {}),
        }

        tool_calls = format_tool_calls(response, provider)
        if tool_calls:
            event_properties["$ai_tools"] = with_privacy_mode(
                ph_client, posthog_privacy_mode, tool_calls
            )

        if (
            usage.get("cache_read_input_tokens") is not None
            and usage.get("cache_read_input_tokens", 0) > 0
        ):
            event_properties["$ai_cache_read_input_tokens"] = usage.get(
                "cache_read_input_tokens", 0
            )

        if (
            usage.get("cache_creation_input_tokens") is not None
            and usage.get("cache_creation_input_tokens", 0) > 0
        ):
            event_properties["$ai_cache_creation_input_tokens"] = usage.get(
                "cache_creation_input_tokens", 0
            )

        if (
            usage.get("reasoning_tokens") is not None
            and usage.get("reasoning_tokens", 0) > 0
        ):
            event_properties["$ai_reasoning_tokens"] = usage.get("reasoning_tokens", 0)

        if posthog_distinct_id is None:
            event_properties["$process_person_profile"] = False

        # Process instructions for Responses API
        if provider == "openai" and kwargs.get("instructions") is not None:
            event_properties["$ai_instructions"] = with_privacy_mode(
                ph_client, posthog_privacy_mode, kwargs.get("instructions")
            )

        # send the event to posthog
        if hasattr(ph_client, "capture") and callable(ph_client.capture):
            ph_client.capture(
                distinct_id=posthog_distinct_id or posthog_trace_id,
                event="$ai_generation",
                properties=event_properties,
                groups=posthog_groups,
            )

    if error:
        raise error

    return response
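# Minimal usage sketch (added comment; `openai_client`, `posthog_client`, and the
# model/message values are hypothetical). Any callable can be wrapped as call_method.
# On success a "$ai_generation" event is captured with token usage and latency; on
# failure the original exception is re-raised after the error event is sent.
#
#   response = call_llm_and_track_usage(
#       posthog_distinct_id="user-123",
#       ph_client=posthog_client,
#       provider="openai",
#       posthog_trace_id=None,  # a UUID is generated when omitted
#       posthog_properties={"team": "ml"},
#       posthog_privacy_mode=False,
#       posthog_groups=None,
#       base_url=openai_client.base_url,
#       call_method=openai_client.chat.completions.create,
#       model="gpt-4o-mini",
#       messages=[{"role": "user", "content": "Hello"}],
#   )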
async def call_llm_and_track_usage_async(
    posthog_distinct_id: Optional[str],
    ph_client: PostHogClient,
    provider: str,
    posthog_trace_id: Optional[str],
    posthog_properties: Optional[Dict[str, Any]],
    posthog_privacy_mode: bool,
    posthog_groups: Optional[Dict[str, Any]],
    base_url: URL,
    call_async_method: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    start_time = time.time()
    response = None
    error = None
    http_status = 200
    usage: Dict[str, Any] = {}
    error_params: Dict[str, Any] = {}

    try:
        response = await call_async_method(**kwargs)
    except Exception as exc:
        error = exc
        http_status = getattr(
            exc, "status_code", 0
        )  # default to 0 because it's likely an SDK error
        error_params = {
            "$ai_is_error": True,
            "$ai_error": exc.__str__(),
        }
    finally:
        end_time = time.time()
        latency = end_time - start_time

        if posthog_trace_id is None:
            posthog_trace_id = str(uuid.uuid4())

        if response and (
            hasattr(response, "usage")
            or (provider == "gemini" and hasattr(response, "usage_metadata"))
        ):
            usage = get_usage(response, provider)

        messages = merge_system_prompt(kwargs, provider)

        event_properties = {
            "$ai_provider": provider,
            "$ai_model": kwargs.get("model"),
            "$ai_model_parameters": get_model_params(kwargs),
            "$ai_input": with_privacy_mode(ph_client, posthog_privacy_mode, messages),
            "$ai_output_choices": with_privacy_mode(
                ph_client, posthog_privacy_mode, format_response(response, provider)
            ),
            "$ai_http_status": http_status,
            "$ai_input_tokens": usage.get("input_tokens", 0),
            "$ai_output_tokens": usage.get("output_tokens", 0),
            "$ai_latency": latency,
            "$ai_trace_id": posthog_trace_id,
            "$ai_base_url": str(base_url),
            **(posthog_properties or {}),
            **(error_params or {}),
        }

        tool_calls = format_tool_calls(response, provider)
        if tool_calls:
            event_properties["$ai_tools"] = with_privacy_mode(
                ph_client, posthog_privacy_mode, tool_calls
            )

        if (
            usage.get("cache_read_input_tokens") is not None
            and usage.get("cache_read_input_tokens", 0) > 0
        ):
            event_properties["$ai_cache_read_input_tokens"] = usage.get(
                "cache_read_input_tokens", 0
            )

        if (
            usage.get("cache_creation_input_tokens") is not None
            and usage.get("cache_creation_input_tokens", 0) > 0
        ):
            event_properties["$ai_cache_creation_input_tokens"] = usage.get(
                "cache_creation_input_tokens", 0
            )

        if posthog_distinct_id is None:
            event_properties["$process_person_profile"] = False

        # Process instructions for Responses API
        if provider == "openai" and kwargs.get("instructions") is not None:
            event_properties["$ai_instructions"] = with_privacy_mode(
                ph_client, posthog_privacy_mode, kwargs.get("instructions")
            )

        # send the event to posthog
        if hasattr(ph_client, "capture") and callable(ph_client.capture):
            ph_client.capture(
                distinct_id=posthog_distinct_id or posthog_trace_id,
                event="$ai_generation",
                properties=event_properties,
                groups=posthog_groups,
            )

    if error:
        raise error

    return response
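# Added comment: this mirrors call_llm_and_track_usage above but awaits the provider
# call, so it can be used from async code, e.g. (hypothetical names):
#   response = await call_llm_and_track_usage_async(
#       ..., call_async_method=async_client.chat.completions.create, ...
#   )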
def with_privacy_mode(ph_client: PostHogClient, privacy_mode: bool, value: Any):
    if ph_client.privacy_mode or privacy_mode:
        return None
    return value
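# Added comment: redaction helper; the payload is dropped when either the per-call
# flag or the client-wide flag is set, e.g.
#   with_privacy_mode(ph_client, True, [{"role": "user", "content": "secret"}])  # -> None
#   with_privacy_mode(ph_client, False, payload)  # -> payload, unless ph_client.privacy_mode is True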