diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 2ccd60083..ec09b1473 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -15,6 +15,7 @@ import copy from datetime import datetime +from datetime import timezone import json import logging from typing import Any @@ -27,6 +28,7 @@ from sqlalchemy import Dialect from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import Text from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql @@ -131,9 +133,11 @@ class StorageSession(Base): MutableDict.as_mutable(DynamicJSON), default={} ) - create_time: Mapped[DateTime] = mapped_column(DateTime(), default=func.now()) + create_time: Mapped[DateTime] = mapped_column( + DateTime(timezone=True), default=func.now() + ) update_time: Mapped[DateTime] = mapped_column( - DateTime(), default=func.now(), onupdate=func.now() + DateTime(timezone=True), default=func.now(), onupdate=func.now() ) storage_events: Mapped[list["StorageEvent"]] = relationship( @@ -144,6 +148,19 @@ class StorageSession(Base): def __repr__(self): return f"" + @property + def _dialect_name(self) -> str: + session = inspect(self).session + return session.bind.dialect.name if session else "" + + @property + def create_timestamp_utc(self) -> datetime: + return _convert_datetime_to_utc(self.create_time, self._dialect_name) + + @property + def update_timestamp_utc(self) -> datetime: + return _convert_datetime_to_utc(self.update_time, self._dialect_name) + class StorageEvent(Base): """Represents an event stored in the database.""" @@ -278,7 +295,7 @@ class StorageAppState(Base): MutableDict.as_mutable(DynamicJSON), default={} ) update_time: Mapped[DateTime] = mapped_column( - DateTime(), default=func.now(), onupdate=func.now() + DateTime(timezone=True), default=func.now(), onupdate=func.now() ) @@ -297,7 +314,7 @@ class StorageUserState(Base): MutableDict.as_mutable(DynamicJSON), default={} ) update_time: Mapped[DateTime] = mapped_column( - DateTime(), default=func.now(), onupdate=func.now() + DateTime(timezone=True), default=func.now(), onupdate=func.now() ) @@ -412,7 +429,7 @@ async def create_session( user_id=str(storage_session.user_id), id=str(storage_session.id), state=merged_state, - last_update_time=storage_session.update_time.timestamp(), + last_update_time=storage_session.update_timestamp_utc, ) return session @@ -473,7 +490,7 @@ async def get_session( user_id=user_id, id=session_id, state=merged_state, - last_update_time=storage_session.update_time.timestamp(), + last_update_time=storage_session.update_timestamp_utc, ) session.events = [e.to_event() for e in reversed(storage_events)] return session @@ -496,7 +513,7 @@ async def list_sessions( user_id=user_id, id=storage_session.id, state={}, - last_update_time=storage_session.update_time.timestamp(), + last_update_time=storage_session.update_timestamp_utc, ) sessions.append(session) return ListSessionsResponse(sessions=sessions) @@ -529,13 +546,13 @@ async def append_event(self, session: Session, event: Event) -> Event: StorageSession, (session.app_name, session.user_id, session.id) ) - if storage_session.update_time.timestamp() > session.last_update_time: + if storage_session.update_timestamp_utc > session.last_update_time: raise ValueError( "The last_update_time provided in the session object" f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is" " earlier than the update_time in the storage_session" - f" {storage_session.update_time:'%Y-%m-%d %H:%M:%S'}. Please check" - " if it is a stale session." + f" {datetime.fromtimestamp(storage_session.update_timestamp_utc):'%Y-%m-%d %H:%M:%S'}." + " Please check if it is a stale session." ) # Fetch states from storage @@ -577,7 +594,7 @@ async def append_event(self, session: Session, event: Event) -> Event: session_factory.refresh(storage_session) # Update timestamp with commit time - session.last_update_time = storage_session.update_time.timestamp() + session.last_update_time = storage_session.update_timestamp_utc # Also update the in-memory session await super().append_event(session=session, event=event) @@ -607,3 +624,15 @@ def _merge_state(app_state, user_state, session_state): for key in user_state.keys(): merged_state[State.USER_PREFIX + key] = user_state[key] return merged_state + + +def _convert_datetime_to_utc(dt: DateTime, dialect_name: str) -> datetime: + if dialect_name == "sqlite": + # SQLite does not support timezone. SQLAlchemy returns a naive datetime + # object without timezone information. We need to convert it to UTC + # manually. + return dt.replace(tzinfo=timezone.utc).timestamp() + else: + # For other dialects, SQLAlchemy returns a datetime object with timezone + # information. We can safely convert it to UTC. + return dt.astimezone(timezone.utc).timestamp()