PydanticAI streaming example using agent.iter()

This example logs every possible step. You likely will not need all of them; delete the ones you don't. Uses pydantic-ai[logfire]==0.1.2.

For logging, I use a custom wrapper around logfire (imported as `logger`); you can replace those calls with direct logfire calls without any issues.

from common import utils, logger, stores
from typing import AsyncGenerator
from . import prompts
from pydantic_ai import Agent
from pydantic_ai.messages import (PartDeltaEvent, TextPartDelta, PartStartEvent, ToolCallPartDelta,
                                  FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent,
                                  TextPart, ToolCallPart)
from .pydantic_ai_specific import agent_factory


class AgentService:
    """Streams agent responses using pydantic-ai's ``agent.iter()`` API.

    Wraps a pydantic-ai ``Agent`` built from a prompt definition and exposes a
    framework-agnostic async streaming interface, persisting conversation
    history and tool-call context through a ``StoreManager``.
    """

    def __init__(
            self,
            store_manager: stores.StoreManager,
            prompt_version: prompts.PromptDefinition
    ):
        """Create the service and build the underlying agent once.

        Args:
            store_manager: Backend for conversation history and tool context.
            prompt_version: Prompt definition used to configure the agent.
        """
        self.store_manager = store_manager
        self.prompt_version = prompt_version

        # Build the pydantic-ai agent once per service instance.
        self.agent = agent_factory.create_agent_from_prompt_version(self.prompt_version)

    async def stream_agent_response(self, user_prompt: str, conversation_id: str) -> AsyncGenerator[str, None]:
        """Yield response text chunks for ``user_prompt``.

        This is the public contract that should not change even on a potential
        framework change.

        Fix: the original ``return``ed the inner generator from an ``async def``
        with no ``yield``, making this a coroutine that had to be awaited to
        obtain the generator — contradicting the annotation. Delegating with
        ``async for``/``yield`` makes it a true async generator, so callers can
        iterate it directly: ``async for chunk in svc.stream_agent_response(...)``.
        """
        async for chunk in self._stream_agent_response_pydantic(user_prompt, conversation_id):
            yield chunk

    async def _stream_agent_response_pydantic(self, user_prompt: str, conversation_id: str) -> AsyncGenerator[str, None]:
        """pydantic-ai-specific streaming implementation.

        Renders the prompt template, loads prior history, iterates the agent
        graph node by node, yields text deltas as they arrive, and persists
        tool context plus the new messages when the run finishes.
        """
        prompt_template = self.prompt_version.prompt_template
        final_prompt = utils.replace_placeholders(prompt_template, {"user_prompt": user_prompt})
        message_history = await self.store_manager.get_conversation_history(conversation_id)

        async with self.agent.iter(user_prompt=final_prompt, message_history=message_history) as run:
            async for node in run:
                if Agent.is_user_prompt_node(node):
                    logger.trace('user prompt node')
                elif Agent.is_model_request_node(node):
                    # Model is producing output: stream text parts/deltas to the caller.
                    logger.trace('model request node')
                    async with node.stream(run.ctx) as request_stream:
                        async for event in request_stream:
                            if isinstance(event, PartStartEvent):
                                if isinstance(event.part, TextPart):
                                    # First chunk of a text part arrives on the start event.
                                    yield event.part.content
                                elif isinstance(event.part, ToolCallPart):
                                    logger.info(f'Tool call started: {event.part.tool_name}')
                            elif isinstance(event, PartDeltaEvent):
                                if isinstance(event.delta, TextPartDelta):
                                    logger.trace('TextPartDelta')
                                    yield event.delta.content_delta
                                elif isinstance(event.delta, ToolCallPartDelta):
                                    logger.trace('ToolCallPartDelta')
                            elif isinstance(event, FinalResultEvent):
                                logger.trace('FinalResultEvent')
                elif Agent.is_call_tools_node(node):
                    # Fix: args_as_json_str() returns a JSON *string*; initialize
                    # with '' (not {}) so the fallback type is consistent if a
                    # result event ever arrives without a preceding call event.
                    tool_args_json = ''
                    logger.trace('is_call_tools_node')
                    async with node.stream(run.ctx) as handle_stream:
                        async for event in handle_stream:
                            if isinstance(event, FunctionToolCallEvent):
                                tool_args_json = event.part.args_as_json_str()
                                logger.trace('FunctionToolCallEvent')
                            elif isinstance(event, FunctionToolResultEvent):
                                logger.trace('FunctionToolResultEvent')
                                # Persist the call args together with the result so the
                                # store can reconstruct what the tool was asked and returned.
                                current_context = {
                                    'tool_args': tool_args_json,
                                    'tool_result': event.result.content
                                }
                                await self.store_manager.append_conversation_context(conversation_id, current_context)
                elif Agent.is_end_node(node):
                    logger.trace('is_end_node')
            # Run finished: persist the newly generated messages as history.
            new_messages = run.result.new_messages()
            await self.store_manager.extend_conversation_history(conversation_id, new_messages)
            logger.info(f'Answer generation for conversation_id {conversation_id} finished.')