8000 Add debug trace endpoint in api server · Syntax404-coder/adk-python@80813a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 80813a7

Browse files
wyf7107copybara-github
authored and committed
Add debug trace endpoint in api server
Details: - Add an in-memory SpanExporter to capture all trace information. - Add a /debug/trace/session/{session_id} endpoint to retrieve traces from the in-memory exporter. - Add the session ID to telemetry spans. PiperOrigin-RevId: 757984565
1 parent d35b99e commit 80813a7

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

src/google/adk/cli/fast_api.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from opentelemetry.sdk.trace import export
4949
from opentelemetry.sdk.trace import ReadableSpan
5050
from opentelemetry.sdk.trace import TracerProvider
51+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
5152
from pydantic import alias_generators
5253
from pydantic import BaseModel
5354
from pydantic import ConfigDict
@@ -112,6 +113,42 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:
112113
return True
113114

114115

116+
class InMemoryExporter(export.SpanExporter):
    """Span exporter that keeps every exported span in process memory.

    Spans named ``call_llm`` that carry a ``gcp.vertex.agent.session_id``
    attribute are used to index trace IDs by session ID in ``trace_dict``,
    so that all stored spans belonging to a session's traces can be
    retrieved later with ``get_finished_spans``.
    """

    def __init__(self, trace_dict: dict[str, typing.Any]):
        """Initialize the exporter.

        Args:
            trace_dict: Mutable mapping from session ID to a list of trace
                IDs. It is owned by the caller and updated in place by
                ``export``.
        """
        super().__init__()
        self._spans = []
        self.trace_dict = trace_dict

    def export(
        self, spans: typing.Sequence[ReadableSpan]
    ) -> export.SpanExportResult:
        """Store ``spans`` and index the trace IDs of ``call_llm`` spans.

        Args:
            spans: The finished spans handed over by the span processor.

        Returns:
            Always ``SpanExportResult.SUCCESS``.
        """
        for span in spans:
            if span.name != "call_llm":
                continue
            # Fall back to an empty mapping so a span without attributes
            # does not raise TypeError from dict(None).
            attributes = dict(span.attributes or {})
            session_id = attributes.get("gcp.vertex.agent.session_id")
            if not session_id:
                continue
            trace_ids = self.trace_dict.setdefault(session_id, [])
            trace_id = span.context.trace_id
            # Several call_llm spans can share one trace; record each trace
            # ID once so the per-session list does not grow with duplicates.
            if trace_id not in trace_ids:
                trace_ids.append(trace_id)
        self._spans.extend(spans)
        return export.SpanExportResult.SUCCESS

    def get_finished_spans(self, session_id: str) -> list[ReadableSpan]:
        """Return all stored spans whose trace is indexed under ``session_id``.

        Args:
            session_id: The session whose spans are requested.

        Returns:
            The matching spans in the order they were exported, or an empty
            list when the session is unknown or has no recorded trace IDs.
        """
        trace_ids = self.trace_dict.get(session_id)
        if not trace_ids:
            return []
        # Build the lookup set once instead of scanning the list per span.
        wanted = set(trace_ids)
        return [span for span in self._spans if span.context.trace_id in wanted]

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Report success immediately; ``export`` stores spans synchronously."""
        return True

    def clear(self):
        """Drop all stored spans; entries in ``trace_dict`` are left as-is."""
        self._spans.clear()
150+
151+
115152
class AgentRunRequest(BaseModel):
116153
app_name: str
117154
user_id: str
@@ -152,12 +189,15 @@ def get_fast_api_app(
152189
) -> FastAPI:
153190
# InMemory tracing dict.
154191
trace_dict: dict[str, Any] = {}
192+
session_trace_dict: dict[str, Any] = {}
155193

156194
# Set up tracing in the FastAPI server.
157195
provider = TracerProvider()
158196
provider.add_span_processor(
159197
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
160198
)
199+
memory_exporter = InMemoryExporter(session_trace_dict)
200+
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
161201
if trace_to_cloud:
162202
envs.load_dotenv_for_agent("", agent_dir)
163203
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
@@ -254,6 +294,24 @@ def get_trace_dict(event_id: str) -> Any:
254294
raise HTTPException(status_code=404, detail="Trace not found")
255295
return event_dict
256296

297+
@app.get("/debug/trace/session/{session_id}")
def get_session_trace(session_id: str) -> Any:
    """Return the spans recorded for ``session_id`` as plain dictionaries.

    Each entry exposes the span's name, span/trace IDs, start and end
    times, a copy of its attributes, and the parent span ID (``None`` when
    the span has no parent). An empty list is returned when the in-memory
    exporter has nothing for the session.
    """
    finished = memory_exporter.get_finished_spans(session_id)
    payload = []
    # `or ()` keeps the original "falsy result -> []" behaviour.
    for span in finished or ():
        parent = span.parent
        payload.append({
            "name": span.name,
            "span_id": span.context.span_id,
            "trace_id": span.context.trace_id,
            "start_time": span.start_time,
            "end_time": span.end_time,
            "attributes": dict(span.attributes),
            "parent_span_id": parent.span_id if parent else None,
        })
    return payload
314+
257315
@app.get(
258316
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
259317
response_model_exclude_none=True,
@@ -306,7 +364,6 @@ def create_session_with_id(
306364
raise HTTPException(
307365
status_code=400, detail=f"Session already exists: {session_id}"
308366
)
309-
310367
logger.info("New session created: %s", session_id)
311368
return session_service.create_session(
312369
app_name=app_name, user_id=user_id, state=state, session_id=session_id
@@ -323,7 +380,6 @@ def create_session(
323380
) -> Session:
324381
# Connect to managed session if agent_engine_id is set.
325382
app_name = agent_engine_id if agent_engine_id else app_name
326-
327383
logger.info("New session created")
328384
return session_service.create_session(
329385
app_name=app_name, user_id=user_id, state=state

src/google/adk/telemetry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ def trace_call_llm(
111111
span.set_attribute(
112112
'gcp.vertex.agent.invocation_id', invocation_context.invocation_id
113113
)
114+
span.set_attribute(
115+
'gcp.vertex.agent.session_id', invocation_context.session.id
116+
)
114117
span.set_attribute('gcp.vertex.agent.event_id', event_id)
115118
# Consider removing once GenAI SDK provides a way to record this info.
116119
span.set_attribute(

0 commit comments

Comments
 (0)
0