Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class Event(LlmResponse):

invocation_id: str = ''
"""The invocation ID of the event. Should be non-empty before appending to a session."""
turn_id: Optional[str] = None
"""Stable identifier for a single LLM response turn."""
author: str
"""'user' or the name of the agent, indicating who appended the event to the
session."""
Expand Down
31 changes: 21 additions & 10 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,9 @@ async def run_async(
break

async def _run_one_step_async(
self,
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
self,
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""One step means one LLM call."""
llm_request = LlmRequest()

Expand All @@ -833,16 +833,10 @@ async def _run_one_step_async(
)

# Long running tool calls should have been handled before this point.
# If there are still long running tool calls, it means the agent is paused
# before, and its branch hasn't been resumed yet.
if (
invocation_context.is_resumable
and events
and len(events) > 1
# TODO: here we are using the last 2 events to decide whether to pause
# the invocation. But this is just being optimistic, we should find a
# way to pause when the long running tool call is followed by more than
# one text responses.
and (
invocation_context.should_pause_invocation(events[-1])
or invocation_context.should_pause_invocation(events[-2])
Expand All @@ -856,20 +850,27 @@ async def _run_one_step_async(
and events[-1].get_function_calls()
):
model_response_event = events[-1]
if not model_response_event.turn_id:
model_response_event.turn_id = Event.new_id()

async with Aclosing(
self._postprocess_handle_function_calls_async(
invocation_context, model_response_event, llm_request
)
) as agen:
async for event in agen:
event.id = Event.new_id()
if not event.turn_id:
event.turn_id = model_response_event.turn_id
yield event
return
return

# Calls the LLM.
turn_id = Event.new_id()
model_response_event = Event(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
turn_id=turn_id,
author=invocation_context.agent.name,
branch=invocation_context.branch,
)
Expand All @@ -886,6 +887,7 @@ async def _run_one_step_async(
llm_request,
llm_response,
model_response_event,
turn_id=turn_id,
)
) as agen:
async for event in agen:
Expand Down Expand Up @@ -932,6 +934,7 @@ async def _postprocess_async(
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
turn_id: str,
) -> AsyncGenerator[Event, None]:
"""Postprocess after calling the LLM.

Expand Down Expand Up @@ -965,6 +968,7 @@ async def _postprocess_async(
model_response_event = self._finalize_model_response_event(
llm_request, llm_response, model_response_event
)
model_response_event.turn_id = turn_id
yield model_response_event

# Handles function calls.
Expand Down Expand Up @@ -1077,6 +1081,7 @@ async def _postprocess_live(
function_response_event = await functions.handle_function_calls_live(
invocation_context, model_response_event, llm_request.tools_dict
)
function_response_event.turn_id = model_response_event.turn_id
# Always yield the function response event first
yield function_response_event

Expand All @@ -1090,6 +1095,7 @@ async def _postprocess_live(
invocation_context, json_response
)
)
final_event.turn_id = model_response_event.turn_id
yield final_event

async def _postprocess_run_processors_async(
Expand All @@ -1111,16 +1117,20 @@ async def _postprocess_handle_function_calls_async(
if function_response_event := await functions.handle_function_calls_async(
invocation_context, function_call_event, llm_request.tools_dict
):
function_response_event.turn_id = function_call_event.turn_id

auth_event = functions.generate_auth_event(
invocation_context, function_response_event
)
if auth_event:
auth_event.turn_id = function_call_event.turn_id
yield auth_event

tool_confirmation_event = functions.generate_request_confirmation_event(
invocation_context, function_call_event, function_response_event
)
if tool_confirmation_event:
tool_confirmation_event.turn_id = function_call_event.turn_id
yield tool_confirmation_event

# Always yield the function response event first
Expand All @@ -1136,6 +1146,7 @@ async def _postprocess_handle_function_calls_async(
invocation_context, json_response
)
)
final_event.turn_id = function_call_event.turn_id
yield final_event
transfer_to_agent = function_response_event.actions.transfer_to_agent
if transfer_to_agent:
Expand Down
Loading