fix: tracing of non-serializable values, e.g. bytes (#34) · yonib05/sdk-python@2056888 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2056888

Browse files
authored
fix: tracing of non-serializable values, e.g. bytes (strands-agents#34)
1 parent 6088173 commit 2056888

File tree

2 files changed

+184
-16
lines changed

2 files changed

+184
-16
lines changed

src/strands/telemetry/tracer.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import json
88
import logging
99
import os
10-
from datetime import datetime, timezone
10+
from datetime import date, datetime, timezone
1111
from importlib.metadata import version
1212
from typing import Any, Dict, Mapping, Optional
1313

@@ -30,21 +30,49 @@
3030
class JSONEncoder(json.JSONEncoder):
3131
"""Custom JSON encoder that handles non-serializable types."""
3232

33-
def default(self, obj: Any) -> Any:
34-
"""Handle non-serializable types.
33+
def encode(self, obj: Any) -> str:
34+
"""Recursively encode objects, preserving structure and only replacing unserializable values.
3535
3636
Args:
37-
obj: The object to serialize
37+
obj: The object to encode
3838
3939
Returns:
40-
A JSON serializable version of the object
40+
JSON string representation of the object
4141
"""
42-
value = ""
43-
try:
44-
value = super().default(obj)
45-
except TypeError:
46-
value = "<replaced>"
47-
return value
42+
# Process the object to handle non-serializable values
43+
processed_obj = self._process_value(obj)
44+
# Use the parent class to encode the processed object
45+
return super().encode(processed_obj)
46+
47+
def _process_value(self, value: Any) -> Any:
48+
"""Process any value, handling containers recursively.
49+
50+
Args:
51+
value: The value to process
52+
53+
Returns:
54+
Processed value with unserializable parts replaced
55+
"""
56+
# Handle datetime objects directly
57+
if isinstance(value, (datetime, date)):
58+
return value.isoformat()
59+
60+
# Handle dictionaries
61+
elif isinstance(value, dict):
62+
return {k: self._process_value(v) for k, v in value.items()}
63+
64+
# Handle lists
65+
elif isinstance(value, list):
66+
return [self._process_value(item) for item in value]
67+
68+
# Handle all other values
69+
else:
70+
try:
71+
# Test if the value is JSON serializable
72+
json.dumps(value)
73+
return value
74+
except (TypeError, OverflowError, ValueError):
75+
return "<replaced>"
4876

4977

5078
class Tracer:
@@ -332,6 +360,7 @@ def start_tool_call_span(
332360
The created span, or None if tracing is not enabled.
333361
"""
334362
attributes: Dict[str, AttributeValue] = {
363+
"gen_ai.prompt": json.dumps(tool, cls=JSONEncoder),
335364
"tool.name": tool["name"],
336365
"tool.id": tool["toolUseId"],
337366
"tool.parameters": json.dumps(tool["input"], cls=JSONEncoder),
@@ -358,10 +387,11 @@ def end_tool_call_span(
358387
status = tool_result.get("status")
359388
status_str = str(status) if status is not None else ""
360389

390+
tool_result_content_json = json.dumps(tool_result.get("content"), cls=JSONEncoder)
361391
attributes.update(
362392
{
363-
"tool.result": json.dumps(tool_result.get("content"), cls=JSONEncoder),
364-
"gen_ai.completion": json.dumps(tool_result.get("content"), cls=JSONEncoder),
393+
"tool.result": tool_result_content_json,
394+
"gen_ai.completion": tool_result_content_json,
365395
"tool.status": status_str,
366396
}
367397
)
@@ -492,7 +522,7 @@ def end_agent_span(
492522
if response:
493523
attributes.update(
494524
{
495-
"gen_ai.completion": json.dumps(response, cls=JSONEncoder),
525+
"gen_ai.completion": str(response),
496526
}
497527
)
498528

tests/strands/telemetry/test_tracer.py

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import os
3+
from datetime import date, datetime, timezone
34
from unittest import mock
45

56
import pytest
67
from opentelemetry.trace import StatusCode # type: ignore
78

8-
from strands.telemetry.tracer import Tracer, get_tracer
9+
from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer
910
from strands.types.streaming import Usage
1011

1112

@@ -268,6 +269,9 @@ def test_start_tool_call_span(mock_tracer):
268269

269270
mock_tracer.start_span.assert_called_once()
270271
assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool"
272+
mock_span.set_attribute.assert_any_call(
273+
"gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}})
274+
)
271275
mock_span.set_attribute.assert_any_call("tool.name", "test-tool")
272276
mock_span.set_attribute.assert_any_call("tool.id", "123")
273277
mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"}))
@@ -369,7 +373,7 @@ def test_end_agent_span(mock_span):
369373

370374
tracer.end_agent_span(mock_span, mock_response)
371375

372-
mock_span.set_attribute.assert_any_call("gen_ai.completion", '"<replaced>"')
376+
mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response")
373377
mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50)
374378
mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100)
375379
mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150)
@@ -497,3 +501,137 @@ def test_start_model_invoke_span_with_parent(mock_tracer):
497501

498502
# Verify span was returned
499503
assert span is mock_span
504+
505+
506+
@pytest.mark.parametrize(
507+
"input_data, expected_result",
508+
[
509+
("test string", '"test string"'),
510+
(1234, "1234"),
511+
(13.37, "13.37"),
512+
(False, "false"),
513+
(None, "null"),
514+
],
515+
)
516+
def test_json_encoder_serializable(input_data, expected_result):
517+
"""Test encoding of serializable values."""
518+
encoder = JSONEncoder()
519+
520+
result = encoder.encode(input_data)
521+
assert result == expected_result
522+
523+
524+
def test_json_encoder_datetime():
525+
"""Test encoding datetime and date objects."""
526+
encoder = JSONEncoder()
527+
528+
dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
529+
result = encoder.encode(dt)
530+
assert result == f'"{dt.isoformat()}"'
531+
532+
d = date(2025, 1, 1)
533+
result = encoder.encode(d)
534+
assert result == f'"{d.isoformat()}"'
535+
536+
537+
def test_json_encoder_list():
538+
"""Test encoding a list with mixed content."""
539+
encoder = JSONEncoder()
540+
541+
non_serializable = lambda x: x # noqa: E731
542+
543+
data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]]
544+
545+
result = json.loads(encoder.encode(data))
546+
assert result == ["value", 42, 13.37, "<replaced>", None, {"key": True}, ["value here"]]
547+
548+
549+
def test_json_encoder_dict():
550+
"""Test encoding a dict with mixed content."""
551+
encoder = JSONEncoder()
552+
553+
class UnserializableClass:
554+
def __str__(self):
555+
return "Unserializable Object"
556+
557+
non_serializable = lambda x: x # noqa: E731
558+
559+
now = datetime.now(timezone.utc)
560+
561+
data = {
562+
"metadata": {
563+
"timestamp": now,
564+
"version": "1.0",
565+
"debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731
566+
},
567+
"content": [
568+
{"type": "text", "value": "Hello world"},
569+
{"type": "binary", "value": non_serializable},
570+
{"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]},
571+
],
572+
"statistics": {
573+
"processed": 100,
574+
"failed": 5,
575+
"details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}],
576+
},
577+
"list": [
578+
non_serializable,
579+
1234,
580+
13.37,
581+
True,
582+
None,
583+
"string here",
584+
],
585+
}
586+
587+
expected = {
588+
"metadata": {
589+
"timestamp": now.isoformat(),
590+
"version": "1.0",
591+
"debug_info": {"object": "<replaced>", "callable": "<replaced>"},
592+
},
593+
"content": [
594+
{"type": "text", "value": "Hello world"},
595+
{"type": "binary", "value": "<replaced>"},
596+
{"type": "mixed", "values": [1, "text", "<replaced>", {"nested": "<replaced>"}]},
597+
],
598+
"statistics": {
599+
"processed": 100,
600+
"failed": 5,
601+
"details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": "<replaced>"}],
602+
},
603+
"list": [
604+
"<replaced>",
605+
1234,
606+
13.37,
607+
True,
608+
None,
609+
"string here",
610+
],
611+
}
612+
613+
result = json.loads(encoder.encode(data))
614+
615+
assert result == expected
616+
617+
618+
def test_json_encoder_value_error():
619+
"""Test encoding values that cause ValueError."""
620+
encoder = JSONEncoder()
621+
622+
# A very large integer that exceeds JSON limits and throws ValueError
623+
huge_number = 2**100000
624+
625+
# Test in a dictionary
626+
dict_data = {"normal": 42, "huge": huge_number}
627+
result = json.loads(encoder.encode(dict_data))
628+
assert result == {"normal": 42, "huge": "<replaced>"}
629+
630+
# Test in a list
631+
list_data = [42, huge_number]
632+
result = json.loads(encoder.encode(list_data))
633+
assert result == [42, "<replaced>"]
634+
635+
# Test just the value
636+
result = json.loads(encoder.encode(huge_number))
637+
assert result == "<replaced>"

0 commit comments

Comments
 (0)
0