"""PydanticAI streaming example using ``agent.iter()``.

Logs every possible step type; you will likely not need all of them —
delete the ones you don't. Targets ``pydantic-ai[logfire]==0.1.2``.

``logger`` is a thin custom wrapper around logfire; you can replace the
calls with plain ``logfire`` calls without any issues.
"""
from common import utils, logger, stores
from typing import AsyncGenerator
from . import prompts
from pydantic_ai import Agent
from pydantic_ai.messages import (PartDeltaEvent, TextPartDelta, PartStartEvent, ToolCallPartDelta,
FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent,
TextPart, ToolCallPart)
from .pydantic_ai_specific import agent_factory
class AgentService:
    """Streams agent answers chunk-by-chunk using pydantic-ai's ``agent.iter()``.

    Persists tool-call context during a run and the full message history at
    the end of a run through the injected ``StoreManager``, so conversations
    survive across requests.
    """

    def __init__(
        self,
        store_manager: stores.StoreManager,
        prompt_version: prompts.PromptDefinition
    ):
        self.store_manager = store_manager
        self.prompt_version = prompt_version
        # Build the framework-specific agent once; it is reused for every call.
        self.agent = agent_factory.create_agent_from_prompt_version(self.prompt_version)

    async def stream_agent_response(self, user_prompt: str, conversation_id: str) -> AsyncGenerator[str, None]:
        """Public streaming contract; must stay stable even across a framework change.

        Fix: implemented as a true async generator (``yield`` delegation).
        The previous ``return <async generator>`` inside an ``async def``
        forced callers to ``await`` the call first to obtain the generator,
        which contradicted the declared ``AsyncGenerator[str, None]`` return
        type and the natural ``async for chunk in service.stream_agent_response(...)``
        usage this contract implies.
        """
        async for chunk in self._stream_agent_response_pydantic(user_prompt, conversation_id):
            yield chunk

    async def _stream_agent_response_pydantic(self, user_prompt: str, conversation_id: str) -> AsyncGenerator[str, None]:
        """Drive one pydantic-ai agent run, yielding text chunks as they arrive.

        Side effects: appends tool-call context to the store while tools run,
        and extends the stored conversation history once the run ends.
        """
        prompt_template = self.prompt_version.prompt_template
        final_prompt = utils.replace_placeholders(prompt_template, {"user_prompt": user_prompt})
        message_history = await self.store_manager.get_conversation_history(conversation_id)
        async with self.agent.iter(user_prompt=final_prompt, message_history=message_history) as run:
            async for node in run:
                if Agent.is_user_prompt_node(node):
                    logger.trace('user prompt node')
                elif Agent.is_model_request_node(node):
                    logger.trace('model request node')
                    async with node.stream(run.ctx) as request_stream:
                        async for event in request_stream:
                            if isinstance(event, PartStartEvent):
                                # A new part begins: the first text chunk arrives on
                                # the start event itself, not as a delta.
                                if isinstance(event.part, TextPart):
                                    yield event.part.content
                                elif isinstance(event.part, ToolCallPart):
                                    logger.info(f'Tool call started: {event.part.tool_name}')
                            elif isinstance(event, PartDeltaEvent):
                                if isinstance(event.delta, TextPartDelta):
                                    logger.trace('TextPartDelta')
                                    yield event.delta.content_delta
                                elif isinstance(event.delta, ToolCallPartDelta):
                                    logger.trace('ToolCallPartDelta')
                            elif isinstance(event, FinalResultEvent):
                                logger.trace('FinalResultEvent')
                elif Agent.is_call_tools_node(node):
                    # Fix: was initialised to a dict ({}) although it always holds
                    # the JSON *string* returned by args_as_json_str(). Use the
                    # empty-object JSON string so the type stays consistent even
                    # if a result event ever arrives before its call event.
                    tool_args_json: str = '{}'
                    logger.trace('is_call_tools_node')
                    async with node.stream(run.ctx) as handle_stream:
                        async for event in handle_stream:
                            if isinstance(event, FunctionToolCallEvent):
                                tool_args_json = event.part.args_as_json_str()
                                logger.trace('FunctionToolCallEvent')
                            elif isinstance(event, FunctionToolResultEvent):
                                logger.trace('FunctionToolResultEvent')
                                current_context = {
                                    'tool_args': tool_args_json,
                                    'tool_result': event.result.content
                                }
                                await self.store_manager.append_conversation_context(conversation_id, current_context)
                elif Agent.is_end_node(node):
                    logger.trace('is_end_node')
                    # Persist everything this run produced so the next request
                    # sees the updated history.
                    new_messages = run.result.new_messages()
                    await self.store_manager.extend_conversation_history(conversation_id, new_messages)
                    logger.info(f'Answer generation for conversation_id {conversation_id} finished.')