|
1 | 1 | import json
|
2 | 2 | import os
|
| 3 | +from datetime import date, datetime, timezone |
3 | 4 | from unittest import mock
|
4 | 5 |
|
5 | 6 | import pytest
|
6 | 7 | from opentelemetry.trace import StatusCode # type: ignore
|
7 | 8 |
|
8 |
| -from strands.telemetry.tracer import Tracer, get_tracer |
| 9 | +from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer |
9 | 10 | from strands.types.streaming import Usage
|
10 | 11 |
|
11 | 12 |
|
@@ -268,6 +269,9 @@ def test_start_tool_call_span(mock_tracer):
|
268 | 269 |
|
269 | 270 | mock_tracer.start_span.assert_called_once()
|
270 | 271 | assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool"
|
| 272 | + mock_span.set_attribute.assert_any_call( |
| 273 | + "gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}}) |
| 274 | + ) |
271 | 275 | mock_span.set_attribute.assert_any_call("tool.name", "test-tool")
|
272 | 276 | mock_span.set_attribute.assert_any_call("tool.id", "123")
|
273 | 277 | mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"}))
|
@@ -369,7 +373,7 @@ def test_end_agent_span(mock_span):
|
369 | 373 |
|
370 | 374 | tracer.end_agent_span(mock_span, mock_response)
|
371 | 375 |
|
372 |
| - mock_span.set_attribute.assert_any_call("gen_ai.completion", '"<replaced>"') |
| 376 | + mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response") |
373 | 377 | mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50)
|
374 | 378 | mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100)
|
375 | 379 | mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150)
|
@@ -497,3 +501,137 @@ def test_start_model_invoke_span_with_parent(mock_tracer):
|
497 | 501 |
|
498 | 502 | # Verify span was returned
|
499 | 503 | assert span is mock_span
|
| 504 | + |
| 505 | + |
| 506 | +@pytest.mark.parametrize( |
| 507 | + "input_data, expected_result", |
| 508 | + [ |
| 509 | + ("test string", '"test string"'), |
| 510 | + (1234, "1234"), |
| 511 | + (13.37, "13.37"), |
| 512 | + (False, "false"), |
| 513 | + (None, "null"), |
| 514 | + ], |
| 515 | +) |
| 516 | +def test_json_encoder_serializable(input_data, expected_result): |
| 517 | + """Test encoding of serializable values.""" |
| 518 | + encoder = JSONEncoder() |
| 519 | + |
| 520 | + result = encoder.encode(input_data) |
| 521 | + assert result == expected_result |
| 522 | + |
| 523 | + |
| 524 | +def test_json_encoder_datetime(): |
| 525 | + """Test encoding datetime and date objects.""" |
| 526 | + encoder = JSONEncoder() |
| 527 | + |
| 528 | + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) |
| 529 | + result = encoder.encode(dt) |
| 530 | + assert result == f'"{dt.isoformat()}"' |
| 531 | + |
| 532 | + d = date(2025, 1, 1) |
| 533 | + result = encoder.encode(d) |
| 534 | + assert result == f'"{d.isoformat()}"' |
| 535 | + |
| 536 | + |
| 537 | +def test_json_encoder_list(): |
| 538 | + """Test encoding a list with mixed content.""" |
| 539 | + encoder = JSONEncoder() |
| 540 | + |
| 541 | + non_serializable = lambda x: x # noqa: E731 |
| 542 | + |
| 543 | + data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]] |
| 544 | + |
| 545 | + result = json.loads(encoder.encode(data)) |
| 546 | + assert result == ["value", 42, 13.37, "<replaced>", None, {"key": True}, ["value here"]] |
| 547 | + |
| 548 | + |
| 549 | +def test_json_encoder_dict(): |
| 550 | + """Test encoding a dict with mixed content.""" |
| 551 | + encoder = JSONEncoder() |
| 552 | + |
| 553 | + class UnserializableClass: |
| 554 | + def __str__(self): |
| 555 | + return "Unserializable Object" |
| 556 | + |
| 557 | + non_serializable = lambda x: x # noqa: E731 |
| 558 | + |
| 559 | + now = datetime.now(timezone.utc) |
| 560 | + |
| 561 | + data = { |
| 562 | + "metadata": { |
| 563 | + "timestamp": now, |
| 564 | + "version": "1.0", |
| 565 | + "debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731 |
| 566 | + }, |
| 567 | + "content": [ |
| 568 | + {"type": "text", "value": "Hello world"}, |
| 569 | + {"type": "binary", "value": non_serializable}, |
| 570 | + {"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]}, |
| 571 | + ], |
| 572 | + "statistics": { |
| 573 | + "processed": 100, |
| 574 | + "failed": 5, |
| 575 | + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}], |
| 576 | + }, |
| 577 | + "list": [ |
| 578 | + non_serializable, |
| 579 | + 1234, |
| 580 | + 13.37, |
| 581 | + True, |
| 582 | + None, |
| 583 | + "string here", |
| 584 | + ], |
| 585 | + } |
| 586 | + |
| 587 | + expected = { |
| 588 | + "metadata": { |
| 589 | + "timestamp": now.isoformat(), |
| 590 | + "version": "1.0", |
| 591 | + "debug_info": {"object": "<replaced>", "callable": "<replaced>"}, |
| 592 | + }, |
| 593 | + "content": [ |
| 594 | + {"type": "text", "value": "Hello world"}, |
| 595 | + {"type": "binary", "value": "<replaced>"}, |
| 596 | + {"type": "mixed", "values": [1, "text", "<replaced>", {"nested": "<replaced>"}]}, |
| 597 | + ], |
| 598 | + "statistics": { |
| 599 | + "processed": 100, |
| 600 | + "failed": 5, |
| 601 | + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": "<replaced>"}], |
| 602 | + }, |
| 603 | + "list": [ |
| 604 | + "<replaced>", |
| 605 | + 1234, |
| 606 | + 13.37, |
| 607 | + True, |
| 608 | + None, |
| 609 | + "string here", |
| 610 | + ], |
| 611 | + } |
| 612 | + |
| 613 | + result = json.loads(encoder.encode(data)) |
| 614 | + |
| 615 | + assert result == expected |
<
B061
/code> | 616 | + |
| 617 | + |
| 618 | +def test_json_encoder_value_error(): |
| 619 | + """Test encoding values that cause ValueError.""" |
| 620 | + encoder = JSONEncoder() |
| 621 | + |
| 622 | + # A very large integer that exceeds JSON limits and throws ValueError |
| 623 | + huge_number = 2**100000 |
| 624 | + |
| 625 | + # Test in a dictionary |
| 626 | + dict_data = {"normal": 42, "huge": huge_number} |
| 627 | + result = json.loads(encoder.encode(dict_data)) |
| 628 | + assert result == {"normal": 42, "huge": "<replaced>"} |
| 629 | + |
| 630 | + # Test in a list |
| 631 | + list_data = [42, huge_number] |
| 632 | + result = json.loads(encoder.encode(list_data)) |
| 633 | + assert result == [42, "<replaced>"] |
| 634 | + |
| 635 | + # Test just the value |
| 636 | + result = json.loads(encoder.encode(huge_number)) |
| 637 | + assert result == "<replaced>" |
0 commit comments