48
48
from opentelemetry .sdk .trace import export
49
49
from opentelemetry .sdk .trace import ReadableSpan
50
50
from opentelemetry .sdk .trace import TracerProvider
51
+ from opentelemetry .sdk .trace .export .in_memory_span_exporter import InMemorySpanExporter
51
52
from pydantic import alias_generators
52
53
from pydantic import BaseModel
53
54
from pydantic import ConfigDict
@@ -112,6 +113,42 @@ def force_flush(self, timeout_millis: int = 30000) -> bool:
112
113
return True
113
114
114
115
116
+ class InMemoryExporter (export .SpanExporter ):
117
+
118
+ def __init__ (self , trace_dict ):
119
+ super ().__init__ ()
120
+ self ._spans = []
121
+ self .trace_dict = trace_dict
122
+
123
+ def export (
124
+ self , spans : typing .Sequence [ReadableSpan ]
125
+ ) -> export .SpanExportResult :
126
+ for span in spans :
127
+ trace_id = span .context .trace_id
128
+ if span .name == "call_llm" :
129
+ attributes = dict (span .attributes )
130
+ session_id = attributes .get ("gcp.vertex.agent.session_id" , None )
131
+ if session_id :
132
+ if session_id not in self .trace_dict :
133
+ self .trace_dict [session_id ] = [trace_id ]
134
+ else :
135
+ self .trace_dict [session_id ] += [trace_id ]
136
+ self ._spans .extend (spans )
137
+ return export .SpanExportResult .SUCCESS
138
+
139
+ def get_finished_spans (self , session_id : str ):
140
+ trace_ids = self .trace_dict .get (session_id , None )
141
+ if trace_ids is None or not trace_ids :
142
+ return []
143
+ return [x for x in self ._spans if x .context .trace_id in trace_ids ]
144
+
145
+ def force_flush (self , timeout_millis : int = 30000 ) -> bool :
146
+ return True
147
+
148
+ def clear (self ):
149
+ self ._spans .clear ()
150
+
151
+
115
152
class AgentRunRequest (BaseModel ):
116
153
app_name : str
117
154
user_id : str
@@ -152,12 +189,15 @@ def get_fast_api_app(
152
189
) -> FastAPI :
153
190
# InMemory tracing dict.
154
191
trace_dict : dict [str , Any ] = {}
192
+ session_trace_dict : dict [str , Any ] = {}
155
193
156
194
# Set up tracing in the FastAPI server.
157
195
provider = TracerProvider ()
158
196
provider .add_span_processor (
159
197
export .SimpleSpanProcessor (ApiServerSpanExporter (trace_dict ))
160
198
)
199
+ memory_exporter = InMemoryExporter (session_trace_dict )
200
+ provider .add_span_processor (export .SimpleSpanProcessor (memory_exporter ))
161
201
if trace_to_cloud :
162
202
envs .load_dotenv_for_agent ("" , agent_dir )
163
203
if project_id := os .environ .get ("GOOGLE_CLOUD_PROJECT" , None ):
@@ -254,6 +294,24 @@ def get_trace_dict(event_id: str) -> Any:
254
294
raise HTTPException (status_code = 404 , detail = "Trace not found" )
255
295
return event_dict
256
296
297
+ @app .get ("/debug/trace/session/{session_id}" )
298
+ def get_session_trace (session_id : str ) -> Any :
299
+ spans = memory_exporter .get_finished_spans (session_id )
300
+ if not spans :
301
+ return []
302
+ return [
303
+ {
304
+ "name" : s .name ,
305
+ "span_id" : s .context .span_id ,
306
+ "trace_id" : s .context .trace_id ,
307
+ "start_time" : s .start_time ,
308
+ "end_time" : s .end_time ,
309
+ "attributes" : dict (s .attributes ),
310
+ "parent_span_id" : s .parent .span_id if s .parent else None ,
311
+ }
312
+ for s in spans
313
+ ]
314
+
257
315
@app .get (
258
316
"/apps/{app_name}/users/{user_id}/sessions/{session_id}" ,
259
317
response_model_exclude_none = True ,
@@ -306,7 +364,6 @@ def create_session_with_id(
306
364
raise HTTPException (
307
365
status_code = 400 , detail = f"Session already exists: { session_id } "
308
366
)
309
-
310
367
logger .info ("New session created: %s" , session_id )
311
368
return session_service .create_session (
312
369
app_name = app_name , user_id = user_id , state = state , session_id = session_id
@@ -323,7 +380,6 @@ def create_session(
323
380
) -> Session :
324
381
# Connect to managed session if agent_engine_id is set.
325
382
app_name = agent_engine_id if agent_engine_id else app_name
326
-
327
383
logger .info ("New session created" )
328
384
return session_service .create_session (
329
385
app_name = app_name , user_id = user_id , state = state
0 commit comments