feat: add anthropic support and prompt truncation #2114

Draft · wants to merge 2 commits into base: main
232 changes: 168 additions & 64 deletions src/seer/automation/agent/client.py
@@ -300,7 +300,7 @@ def _prep_message_and_tools(
        system_prompt: str | None = None,
        tools: list[FunctionTool] | None = None,
        reasoning_effort: str | None = None,
-    ):
+    ) -> tuple[list[MessageParam], list[ToolParam] | None, list[TextBlockParam] | None]:
        message_dicts = [cls.to_message_dict(message) for message in messages] if messages else []
        if system_prompt:
            message_dicts.insert(
@@ -321,7 +321,7 @@

        return message_dicts, tool_dicts

@observe(as_type="generation", name="OpenAI Stream")
@observe(as_type="generation", name="Generate Text Stream")
def generate_text_stream(
self,
*,
@@ -333,75 +333,103 @@ def generate_text_stream(
        max_tokens: int | None = None,
        timeout: float | None = None,
        reasoning_effort: str | None = None,
+        first_token_timeout: float = 40.0,  # Time to first token timeout
+        inactivity_timeout: float = 20.0,  # Timeout for inactivity after first token
    ) -> Iterator[Tuple[str, str] | ToolCall | Usage]:
-        message_dicts, tool_dicts = self._prep_message_and_tools(
-            messages=messages,
-            prompt=prompt,
-            system_prompt=system_prompt,
-            tools=tools,
-            reasoning_effort=reasoning_effort,
-        )
-
-        openai_client = self.get_client()
-
-        stream = openai_client.chat.completions.create(
-            model=self.model_name,
-            messages=cast(Iterable[ChatCompletionMessageParam], message_dicts),
-            temperature=temperature,
-            tools=(
-                cast(Iterable[ChatCompletionToolParam], tool_dicts)
-                if tool_dicts
-                else openai.NotGiven()
-            ),
-            max_tokens=max_tokens or openai.NotGiven(),
-            timeout=timeout or openai.NotGiven(),
-            stream=True,
-            stream_options={"include_usage": True},
-            reasoning_effort=reasoning_effort if reasoning_effort else openai.NotGiven(),
-        )
-
-        try:
-            current_tool_call: dict[str, Any] | None = None
-            current_tool_call_index = 0
-
-            for chunk in stream:
-                if not chunk.choices and chunk.usage:
-                    usage = Usage(
-                        completion_tokens=chunk.usage.completion_tokens,
-                        prompt_tokens=chunk.usage.prompt_tokens,
-                        total_tokens=chunk.usage.total_tokens,
-                    )
-                    yield usage
-                    langfuse_context.update_current_observation(model=self.model_name, usage=usage)
-                    break
-
-                delta = chunk.choices[0].delta
-                if delta.tool_calls:
-                    tool_call = delta.tool_calls[0]
-
-                    if (
-                        not current_tool_call or current_tool_call_index != tool_call.index
-                    ):  # Start of new tool call
-                        current_tool_call_index = tool_call.index
-                        if current_tool_call:
-                            yield ToolCall(**current_tool_call)
-                            current_tool_call = None
-                        current_tool_call = {
-                            "id": tool_call.id,
-                            "function": tool_call.function.name if tool_call.function.name else "",
-                            "args": (
-                                tool_call.function.arguments if tool_call.function.arguments else ""
-                            ),
-                        }
-                    else:
-                        if tool_call.function.arguments:
-                            current_tool_call["args"] += tool_call.function.arguments
-                if chunk.choices[0].finish_reason == "tool_calls" and current_tool_call:
-                    yield ToolCall(**current_tool_call)
-                if delta.content:
-                    yield "content", delta.content
-        finally:
-            stream.response.close()
+        defaults = self.defaults
+        default_temperature = defaults.temperature if defaults else None
+
+        messages = LlmClient.clean_message_content(messages if messages else [])
+        if not tools:
+            messages = LlmClient.clean_tool_call_assistant_messages(messages)
+
+        # Get the appropriate stream generator based on provider
+        try:
+            if self.provider_name == LlmProviderType.OPENAI:
+                stream_generator = self.generate_text_stream(
+                    max_tokens=max_tokens,
+                    messages=messages,
+                    prompt=prompt,
+                    system_prompt=system_prompt,
+                    temperature=temperature or default_temperature,
+                    tools=tools,
+                    timeout=timeout,
+                    reasoning_effort=reasoning_effort,
+                )
+            elif self.provider_name == LlmProviderType.ANTHROPIC:
+                stream_generator = self.generate_text_stream(
+                    max_tokens=max_tokens,
+                    messages=messages,
+                    prompt=prompt,
+                    system_prompt=system_prompt,
+                    temperature=temperature or default_temperature,
+                    tools=tools,
+                    timeout=timeout,
+                    reasoning_effort=reasoning_effort,
+                )
+            elif self.provider_name == LlmProviderType.GEMINI:
+                stream_generator = self.generate_text_stream(
+                    max_tokens=max_tokens,
+                    messages=messages,
+                    prompt=prompt,
+                    system_prompt=system_prompt,
+                    temperature=temperature or default_temperature,
+                    tools=tools,
+                )
+            else:
+                raise ValueError(f"Invalid provider: {self.provider_name}")
+
+        except anthropic.APIStatusError as e:
+            if getattr(e, "status_code", None) == 413:
+                logger.warning("Prompt too long for Anthropic, falling back to OpenAI")
+                # Fall back to OpenAI for oversized prompts
+                fallback_model = OpenAiProvider.model("gpt-4-turbo")
+                stream_generator = fallback_model.generate_text_stream(
+                    max_tokens=max_tokens,
+                    messages=messages,
+                    prompt=prompt,
+                    system_prompt=system_prompt,
+                    temperature=temperature or 0.0,
+                    tools=tools,
+                    timeout=timeout,
+                    reasoning_effort=reasoning_effort,
+                )
+            else:
+                raise
+
+        # Add timeout check to the stream with differentiated timeouts
+        # for first token and subsequent tokens
+        last_yield_time = time.time()
+        first_token_received = False
+
+        try:
+            for item in stream_generator:
+                current_time = time.time()
+                # Use first_token_timeout for the first token, inactivity_timeout for subsequent tokens
+                timeout_to_use = (
+                    first_token_timeout if not first_token_received else inactivity_timeout
+                )
+
+                if current_time - last_yield_time > timeout_to_use:
+                    if first_token_received:
+                        raise LlmStreamInactivityTimeoutError(
+                            f"Stream inactivity timeout after {timeout_to_use} seconds"
+                        )
+                    else:
+                        raise LlmStreamFirstTokenTimeoutError(
+                            f"Stream time to first token timeout after {timeout_to_use} seconds"
+                        )
+
+                # Mark that we've received at least one token
+                first_token_received = True
+                last_yield_time = current_time
+                yield item
+
+        except Exception as e:
+            logger.exception(
+                f"Text stream generation failed with provider {self.provider_name}: {e}"
+            )
+            raise e
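
One caveat about the timeout design: the elapsed-time check runs only when the generator yields an item, so it flags a stream that resumed too slowly after the fact; a stream that blocks forever inside next() is never interrupted by this loop alone. A minimal, self-contained sketch of the intended semantics (the exception classes here are stand-ins for the ones this PR defines):

import time
from typing import Iterator

class FirstTokenTimeout(Exception):
    pass

class InactivityTimeout(Exception):
    pass

def enforce_stream_timeouts(
    stream: Iterator[str],
    first_token_timeout: float = 40.0,
    inactivity_timeout: float = 20.0,
) -> Iterator[str]:
    # Mirrors the loop above: one budget for the first token, a tighter one after.
    last_yield = time.time()
    seen_first = False
    for item in stream:
        waited = time.time() - last_yield
        limit = inactivity_timeout if seen_first else first_token_timeout
        if waited > limit:
            if seen_first:
                raise InactivityTimeout(f"no output for {waited:.1f}s (limit {limit}s)")
            raise FirstTokenTimeout(f"first token took {waited:.1f}s (limit {limit}s)")
        seen_first = True
        last_yield = time.time()
        yield item

Catching a fully hung stream would need an out-of-band mechanism, for example a watchdog thread or an asyncio timeout wrapped around each next() call.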

    def construct_message_from_stream(
        self, content_chunks: list[str], tool_calls: list[ToolCall]
@@ -456,8 +484,79 @@ def is_completion_exception_retryable(exception: Exception) -> bool:
        return (
            isinstance(exception, anthropic.AnthropicError)
            and ("overloaded_error" in str(exception))
            and not (
                isinstance(exception, anthropic.APIStatusError)
                and getattr(exception, "status_code", None) == 413
            )
        ) or isinstance(exception, LlmStreamTimeoutError)
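
The 413 carve-out above complements the streaming fallback: an oversized prompt fails deterministically, so retrying the identical request cannot succeed, whereas overloaded_error and stream timeouts are transient. A hypothetical caller-side routing sketch (classify_stream_error and its return labels are illustrative, and calling is_completion_exception_retryable unqualified assumes it is importable from this module):

import anthropic

def classify_stream_error(exc: Exception) -> str:
    # 413 = payload too large: route to the OpenAI fallback instead of retrying.
    if isinstance(exc, anthropic.APIStatusError) and exc.status_code == 413:
        return "fallback"
    # Transient failures (overload, stream stalls) are safe to retry.
    if is_completion_exception_retryable(exc):
        return "retry"
    return "raise"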

    # Rough token estimate used to decide whether truncation is needed
    def estimate_tokens(self, messages: list[MessageParam], system_prompt_block=None) -> int:
        # Anthropic uses ~4 chars per token as a rough approximation
        chars_per_token = 4.0
        total_chars = 0

        # Count system prompt chars
        if system_prompt_block:
            for block in system_prompt_block:
                if block.get("text"):
                    total_chars += len(block["text"])

        # Count message chars (MessageParam is a TypedDict, so use key access)
        for message in messages:
            content = message.get("content")
            if isinstance(content, str):
                total_chars += len(content)
            elif isinstance(content, list):
                for content_block in content:
                    if content_block.get("text"):
                        total_chars += len(content_block["text"])
                    elif content_block.get("thinking"):
                        total_chars += len(content_block["thinking"])

        return int(total_chars / chars_per_token)
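
As a sanity check on the heuristic: at roughly 4 characters per token, the 100,000-token budget used by the truncation below corresponds to about 400,000 characters of prompt text. A tiny worked example under that assumption:

chars_per_token = 4.0

message_chars = 360_000  # large pasted stack traces and diffs
system_chars = 2_000     # system prompt

estimated = int((message_chars + system_chars) / chars_per_token)
print(estimated)  # 90500, still under the 100,000-token cap, so no truncation

Character-based estimates can be off by a large factor on code-heavy text, so treating 100,000 as a soft budget rather than an exact limit is the safe reading.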

    # Truncation applied before sending an oversized prompt to Anthropic
    def truncate_messages_to_fit(
        self, messages: list[MessageParam], system_prompt_block=None
    ) -> tuple[list[MessageParam], list[TextBlockParam] | None]:
        MAX_TOKENS = 100000  # Anthropic's token limit

        # Shallow copy: the list can be re-sliced safely, but the block
        # filtering below still mutates the underlying message dicts in place.
        messages_copy = list(messages)

        # First estimate tokens
        estimated_tokens = self.estimate_tokens(messages_copy, system_prompt_block)

        # If under limit, return as is
        if estimated_tokens <= MAX_TOKENS:
            return messages_copy, system_prompt_block

        # Strategy 1: Remove thinking blocks, which are the largest parts
        for message in messages_copy:
            content = message.get("content")
            if isinstance(content, list):
                message["content"] = [
                    block for block in content if block.get("type") != "thinking"
                ]

        # Check if that was enough
        estimated_tokens = self.estimate_tokens(messages_copy, system_prompt_block)
        if estimated_tokens <= MAX_TOKENS:
            return messages_copy, system_prompt_block

        # Strategy 2: Keep only the most recent messages
        truncated_messages = messages_copy[-5:] if len(messages_copy) > 5 else messages_copy

        # Check if that was enough
        estimated_tokens = self.estimate_tokens(truncated_messages, system_prompt_block)
        if estimated_tokens <= MAX_TOKENS:
            return truncated_messages, system_prompt_block

        # Strategy 3: If still too large, shorten the system prompt
        if system_prompt_block:
            system_text = system_prompt_block[0]["text"]
            if len(system_text) > 1000:
                system_prompt_block[0]["text"] = system_text[:1000] + "..."

        return truncated_messages, system_prompt_block
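
A small usage sketch of the three-stage strategy, assuming the dict-shaped MessageParam content used above; the oversized thinking block alone blows the budget, so strategy 1 suffices and strategies 2 and 3 never run (provider stands for an instance of the class carrying these helpers):

messages = [
    {
        "role": "assistant",
        "content": [
            {"type": "thinking", "thinking": "x" * 500_000},  # ~125k tokens alone
            {"type": "text", "text": "Here is the fix."},
        ],
    },
    {"role": "user", "content": [{"type": "text", "text": "Apply it."}]},
]

truncated, system_block = provider.truncate_messages_to_fit(messages, None)

# Strategy 1 dropped the thinking block; both text blocks survive.
assert all(
    block["type"] != "thinking"
    for message in truncated
    for block in message["content"]
)

One sharp edge of strategy 2: slicing to the last five messages can separate a tool_use block from its matching tool_result, which the Anthropic API rejects, so a production version may need to cut on tool-call boundaries.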

@observe(as_type="generation", name="Anthropic Generation")
@inject
def generate_text(
@@ -646,6 +745,11 @@ def _prep_message_and_tools(
            else None
        )

        # Apply truncation to ensure prompt fits within token limits
        message_dicts, system_prompt_block = cls().truncate_messages_to_fit(
            message_dicts, system_prompt_block
        )

        return message_dicts, tool_dicts, system_prompt_block
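
With truncation wired into _prep_message_and_tools, every Anthropic call path that prepares messages picks up the cap automatically. An illustrative call site under the new three-tuple contract (the class name AnthropicProvider follows the OpenAiProvider naming visible above and is an assumption here):

# Illustrative only: shows the order of operations after this change.
message_dicts, tool_dicts, system_prompt_block = AnthropicProvider._prep_message_and_tools(
    messages=messages,
    system_prompt="You are a concise code-review assistant.",
    tools=None,
)
# message_dicts and system_prompt_block are already truncated at this point,
# so the downstream messages.create(...) request stays within the budget.

Since truncate_messages_to_fit is an instance method, the classmethod builds a throwaway instance via cls(); making the helper a @classmethod or a module-level function would avoid that.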

@observe(as_type="generation", name="Anthropic Stream")