10000 chore: Enhance a2a event converter · google/adk-python@87da892 · GitHub
[go: up one dir, main page]

Skip to content

Commit 87da892

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Enhance a2a event converter
a. fix function call long running id matching logic b. fix error code conversion logic c. add input required and auth required status conversion logic PiperOrigin-RevId: 774238964
1 parent 7c670f6 commit 87da892

File tree

6 files changed

+385
-34
lines changed

6 files changed

+385
-34
lines changed

src/google/adk/a2a/converters/event_converter.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from __future__ import annotations
1616

17-
import datetime
17+
from datetime import datetime
18+
from datetime import timezone
1819
import logging
1920
from typing import Any
2021
from typing import Dict
@@ -35,6 +36,7 @@
3536

3637
from ...agents.invocation_context import InvocationContext
3738
from ...events.event import Event
39< 10000 span class="diff-text-marker">+
from ...flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
3840
from ...utils.feature_decorator import working_in_progress
3941
from .part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
4042
from .part_converter import A2A_DATA_PART_METADATA_TYPE_KEY
@@ -224,7 +226,7 @@ def _process_long_running_tool(a2a_part, event: Event) -> None:
224226
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
225227
)
226228
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
227-
and a2a_part.root.metadata.get("id") in event.long_running_tool_ids
229+
and a2a_part.root.data.get("id") in event.long_running_tool_ids
228230
):
229231
a2a_part.root.metadata[_get_adk_metadata_key("is_long_running")] = True
230232

@@ -287,24 +289,34 @@ def _create_error_status_event(
287289
"""
288290
error_message = getattr(event, "error_message", None) or DEFAULT_ERROR_MESSAGE
289291

292+
# Get context metadata and add error code
293+
event_metadata = _get_context_metadata(event, invocation_context)
294+
if event.error_code:
295+
event_metadata[_get_adk_metadata_key("error_code")] = str(event.error_code)
296+
290297
return TaskStatusUpdateEvent(
291298
taskId=str(uuid.uuid4()),
292299
contextId=invocation_context.session.id,
293300
final=False,
294-
metadata=_get_context_metadata(event, invocation_context),
301+
metadata=event_metadata,
295302
status=TaskStatus(
296303
state=TaskState.failed,
297304
message=Message(
298305
messageId=str(uuid.uuid4()),
299306
role=Role.agent,
300307
parts=[TextPart(text=error_message)],
308+
metadata={
309+
_get_adk_metadata_key("error_code"): str(event.error_code)
310+
}
311+
if event.error_code
312+
else {},
301313
),
302-
timestamp=datetime.datetime.now().isoformat(),
314+
timestamp=datetime.now(timezone.utc).isoformat(),
303315
),
304316
)
305317

306318

307-
def _create_running_status_event(
319+
def _create_status_update_event(
308320
message: Message, invocation_context: InvocationContext, event: Event
309321
) -> TaskStatusUpdateEvent:
310322
"""Creates a TaskStatusUpdateEvent for running scenarios.
@@ -317,15 +329,39 @@ def _create_running_status_event(
317329
Returns:
318330
A TaskStatusUpdateEvent with RUNNING state.
319331
"""
332+
status = TaskStatus(
333+
state=TaskState.working,
334+
message=message,
335+
timestamp=datetime.now(timezone.utc).isoformat(),
336+
)
337+
338+
if any(
339+
part.root.metadata.get(
340+
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
341+
)
342+
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
343+
and part.root.metadata.get(_get_adk_metadata_key("is_long_running"))
344+
is True
345+
and part.root.data.get("name") == REQUEST_EUC_FUNCTION_CALL_NAME
346+
for part in message.parts
347+
):
348+
status.state = TaskState.auth_required
349+
elif any(
350+
part.root.metadata.get(
351+
_get_adk_metadata_key(A2A_DATA_PART_METADATA_TYPE_KEY)
352+
)
353+
== A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL
354+
and part.root.metadata.get(_get_adk_metadata_key("is_long_running"))
355+
is True
356+
for part in message.parts
357+
):
358+
status.state = TaskState.input_required
359+
320360
return TaskStatusUpdateEvent(
321361
taskId=str(uuid.uuid4()),
322362
contextId=invocation_context.session.id,
323363
final=False,
324-
status=TaskStatus(
325-
state=TaskState.working,
326-
message 8000 =message,
327-
timestamp=datetime.datetime.now().isoformat(),
328-
),
364+
status=status,
329365
metadata=_get_context_metadata(event, invocation_context),
330366
)
331367

@@ -370,7 +406,7 @@ def convert_event_to_a2a_events(
370406
# Handle regular message content
371407
message = convert_event_to_a2a_status_message(event, invocation_context)
372408
if message:
373-
running_event = _create_running_status_event(
409+
running_event = _create_status_update_event(
374410
message, invocation_context, event
375411
)
376412
a2a_events.append(running_event)

src/google/adk/a2a/converters/part_converter.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import base64
2122
import json
2223
import logging
2324
import sys
@@ -45,6 +46,8 @@
4546
A2A_DATA_PART_METADATA_TYPE_KEY = 'type'
4647
A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = 'function_call'
4748
A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE = 'function_response'
49+
A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT = 'code_execution_result'
50+
A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = 'executable_code'
4851

4952

5053
@working_in_progress
@@ -67,7 +70,8 @@ def convert_a2a_part_to_genai_part(
6770
elif isinstance(part.file, a2a_types.FileWithBytes):
6871
return genai_types.Part(
6972
inline_data=genai_types.Blob(
70-
data=part.file.bytes.encode('utf-8'), mime_type=part.file.mimeType
73+
data=base64.b64decode(part.file.bytes),
74+
mime_type=part.file.mimeType,
7175
)
7276
)
7377
else:
@@ -118,8 +122,12 @@ def convert_genai_part_to_a2a_part(
118122
part: genai_types.Part,
119123
) -> Optional[a2a_types.Part]:
120124
"""Convert a Google GenAI Part to an A2A Part."""
125+
121126
if part.text:
122-
return a2a_types.TextPart(text=part.text)
127+
a2a_part = a2a_types.TextPart(text=part.text)
128+
if part.thought is not None:
129+
a2a_part.metadata = {_get_adk_metadata_key('thought'): part.thought}
130+
return a2a_part
123131

124132
if part.file_data:
125133
return a2a_types.FilePart(
@@ -130,14 +138,22 @@ def convert_genai_part_to_a2a_part(
130138
)
131139

132140
if part.inline_data:
133-
return a2a_types.Part(
141+
a2a_part = a2a_types.Part(
134142
root=a2a_types.FilePart(
135143
file=a2a_types.FileWithBytes(
136-
bytes=part.inline_data.data,
144+
bytes=base64.b64encode(part.inline_data.data).decode('utf-8'),
137145
mimeType=part.inline_data.mime_type,
138146
)
139147
)
140148
)
149+
if part.video_metadata:
150+
a2a_part.metadata = {
151+
_get_adk_metadata_key(
152+
'video_metadata'
153+
): part.video_metadata.model_dump(by_alias=True, exclude_none=True)
154+
}
155+
156+
return a2a_part
141157

142158
# Conver the funcall and function reponse to A2A DataPart.
143159
# This is mainly for converting human in the loop and auth request and
@@ -172,6 +188,34 @@ def convert_genai_part_to_a2a_part(
172188
)
173189
)
174190

191+
if part.code_execution_result:
192+
return a2a_types.Part(
193+
root=a2a_types.DataPart(
194+
data=part.code_execution_result.model_dump(
195+
by_alias=True, exclude_none=True
196+
),
197+
metadata={
198+
A2A_DATA_PART_METADATA_TYPE_KEY: (
199+
A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT
200+
)
201+
},
202+
)
203+
)
204+
205+
if part.executable_code:
206+
return a2a_types.Part(
207+
root=a2a_types.DataPart(
208+
data=part.executable_code.model_dump(
209+
by_alias=True, exclude_none=True
210+
),
211+
metadata={
212+
A2A_DATA_PART_METADATA_TYPE_KEY: (
213+
A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE
214+
)
215+
},
216+
)
217+
)
218+
175219
logger.warning(
176220
'Cannot convert unsupported part for Google GenAI part: %s',
177221
part,

src/google/adk/a2a/converters/utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,15 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str:
4545
4646
Returns:
4747
The A2A context id.
48+
49+
Raises:
50+
ValueError: If any of the input parameters are empty or None.
4851
"""
49-
return [ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id].join("$")
52+
if not all([app_name, user_id, session_id]):
53+
raise ValueError(
54+
"All parameters (app_name, user_id, session_id) must be non-empty"
55+
)
56+
return "$".join([ADK_CONTEXT_ID_PREFIX, app_name, user_id, session_id])
5057

5158

5259
def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
@@ -64,8 +71,16 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
6471
if not context_id:
6572
return None, None, None
6673

67-
prefix, app_name, user_id, session_id = context_id.split("$")
68-
if prefix == "ADK" and app_name and user_id and session_id:
69-
return app_name, user_id, session_id
74+
try:
75+
parts = context_id.split("$")
76+
if len(parts) != 4:
77+
return None, None, None
78+
79+
prefix, app_name, user_id, session_id = parts
80+
if prefix == ADK_CONTEXT_ID_PREFIX and app_name and user_id and session_id:
81+
return app_name, user_id, session_id
82+
except ValueError:
83+
# Handle any split errors gracefully
84+
pass
7085

7186
return None, None, None

tests/unittests/a2a/converters/test_event_converter.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
# Skip all tests in this module if Python version is less than 3.10
2222
pytestmark = pytest.mark.skipif(
23-
sys.version_info < (3, 10), reason="A2A tool requires Python 3.10+"
23+
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
2424
)
2525

2626
# Import dependencies with version checking
@@ -34,7 +34,7 @@
3434
from google.adk.a2a.converters.event_converter import _convert_artifact_to_a2a_events
3535
from google.adk.a2a.converters.event_converter import _create_artifact_id
3636
from google.adk.a2a.converters.event_converter import _create_error_status_event
37-
from google.adk.a2a.converters.event_converter import _create_running_status_event
37+
from google.adk.a2a.converters.event_converter import _create_status_update_event
3838
from google.adk.a2a.converters.event_converter import _get_adk_metadata_key
3939
from google.adk.a2a.converters.event_converter import _get_context_metadata
4040
from google.adk.a2a.converters.event_converter import _process_long_running_tool
@@ -63,7 +63,7 @@ class DummyTypes:
6363
_convert_artifact_to_a2a_events = lambda *args: None
6464
_create_artifact_id = lambda *args: None
6565
_create_error_status_event = lambda *args: None
66-
_create_running_status_event = lambda *args: None
66+
_create_status_update_event = lambda *args: None
6767
_get_adk_metadata_key = lambda *args: None
6868
_get_context_metadata = lambda *args: None
6969
_process_long_running_tool = lambda *args: None
@@ -302,6 +302,8 @@ def test_process_long_running_tool_marks_tool(self):
302302
mock_a2a_part = Mock()
303303
mock_data_part = Mock(spec=DataPart)
304304
mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-123"}
305+
mock_data_part.data = Mock()
306+
mock_data_part.data.get = Mock(return_value="tool-123")
305307
mock_a2a_part.root = mock_data_part
306308

307309
self.mock_event.long_running_tool_ids = {"tool-123"}
@@ -315,7 +317,11 @@ def test_process_long_running_tool_marks_tool(self):
315317
"google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL",
316318
"function_call",
317319
),
320+
patch(
321+
"google.adk.a2a.converters.event_converter._get_adk_metadata_key"
322+
) as mock_get_key,
318323
):
324+
mock_get_key.side_effect = lambda key: f"adk_{key}"
319325

320326
_process_long_running_tool(mock_a2a_part, self.mock_event)
321327

@@ -327,6 +333,8 @@ def test_process_long_running_tool_no_marking(self):
327333
mock_a2a_part = Mock()
328334
mock_data_part = Mock(spec=DataPart)
329335
mock_data_part.metadata = {"adk_type": "function_call", "id": "tool-456"}
336+
mock_data_part.data = Mock()
337+
mock_data_part.data.get = Mock(return_value="tool-456")
330338
mock_a2a_part.root = mock_data_part
331339

332340
self.mock_event.long_running_tool_ids = {"tool-123"} # Different ID
@@ -340,7 +348,11 @@ def test_process_long_running_tool_no_marking(self):
340348
"google.adk.a2a.converters.event_converter.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL",
341349
"function_call",
342350
),
351+
patch(
352+
"google.adk.a2a.converters.event_converter._get_adk_metadata_key"
353+
) as mock_get_key,
343354
):
355+
mock_get_key.side_effect = lambda key: f"adk_{key}"
344356

345357
_process_long_running_tool(mock_a2a_part, self.mock_event)
346358

@@ -413,7 +425,7 @@ def test_convert_event_to_message_none_context(self):
413425
assert "Invocation context cannot be None" in str(exc_info.value)
414426

415427
@patch("google.adk.a2a.converters.event_converter.uuid.uuid4")
416-
@patch("google.adk.a2a.converters.event_converter.datetime.datetime")
428+
@patch("google.adk.a2a.converters.event_converter.datetime")
417429
def test_create_error_status_event(self, mock_datetime, mock_uuid):
418430
"""Test creation of error status event."""
419431
mock_uuid.return_value = "test-uuid"
@@ -433,7 +445,7 @@ def test_create_error_status_event(self, mock_datetime, mock_uuid):
433445
assert result.status.message.parts[0].root.text == "Test error message"
434446

435447
@patch("google.adk.a2a.converters.event_converter.uuid.uuid4")
436-
@patch("google.adk.a2a.converters.event_converter.datetime.datetime")
448+
@patch("google.adk.a2a.converters.event_converter.datetime")
437449
def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid):
438450
"""Test creation of error status event without error message."""
439451
mock_uuid.return_value = "test-uuid"
@@ -447,16 +459,17 @@ def test_create_error_status_event_no_message(self, mock_datetime, mock_uuid):
447459

448460
assert result.status.message.parts[0].root.text == DEFAULT_ERROR_MESSAGE
449461

450-
@patch("google.adk.a2a.converters.event_converter.datetime.datetime")
462+
@patch("google.adk.a2a.converters.event_converter.datetime")
451463
def test_create_running_status_event(self, mock_datetime):
452464
"""Test creation of running status event."""
453465
mock_datetime.now.return_value.isoformat.return_value = (
454466
"2023-01-01T00:00:00"
455467
)
456468

457469
mock_message = Mock(spec=Message)
470+
mock_message.parts = []
458471

459-
result = _create_running_status_event(
472+
result = _create_status_update_event(
460473
mock_message, self.mock_invocation_context, self.mock_event
461474
)
462475

@@ -473,7 +486,7 @@ def test_create_running_status_event(self, mock_datetime):
473486
)
474487
@patch("google.adk.a2a.converters.event_converter._create_error_status_event")
475488
@patch(
476-
"google.adk.a2a.converters.event_converter._create_running_status_event"
489+
"google.adk.a2a.converters.event_converter._create_status_update_event"
477490
)
478491
def test_convert_event_to_a2a_events_full_scenario(
479492
self,
@@ -560,7 +573,7 @@ def test_convert_event_to_a2a_events_message_only(self, mock_convert_message):
560573
mock_convert_message.return_value = mock_message
561574

562575
with patch(
563-
"google.adk.a2a.converters.event_converter._create_running_status_event"
576+
"google.adk.a2a.converters.event_converter._create_status_update_event"
564577
) as mock_create_running:
565578
mock_running_event = Mock()
566579
mock_create_running.return_value = mock_running_event

0 commit comments

Comments
 (0)
0