8000 feat: Extract content encode/decode logic to a shared util and resolv… · tsayan/adk-python@14933ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 14933ba

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Extract content encode/decode logic to a shared util and resolve issues with JSON serialization.
feat: Update key length for DB table to avoid key too long issue in mysql PiperOrigin-RevId: 753614879
1 parent b691904 commit 14933ba

File tree

3 files changed

+56
-46
lines changed

3 files changed

+56
-46
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Utility functions for session service."""
2+
3+
import base64
4+
from typing import Any, Optional
5+
6+
from google.genai import types
7+
8+
9+
def encode_content(content: types.Content):
10+
"""Encodes a content object to a JSON dictionary."""
11+
encoded_content = content.model_dump(exclude_none=True)
12+
for p in encoded_content["parts"]:
13+
if "inline_data" in p:
14+
p["inline_data"]["data"] = base64.b64encode(
15+
p["inline_data"]["data"]
16+
).decode("utf-8")
17+
return encoded_content
18+
19+
20+
def decode_content(
21+
content: Optional[dict[str, Any]],
22+
) -> Optional[types.Content]:
23+
"""Decodes a content object from a JSON dictionary."""
24+
if not content:
25+
return None
26+
for p in content["parts"]:
27+
if "inline_data" in p:
28+
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
29+
return types.Content.model_validate(content)

src/google/adk/sessions/database_session_service.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff 10000 line numberDiff line change
@@ -11,22 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
import base64
1614
import copy
1715
from datetime import datetime
1816
import json
1917
import logging
2018
from typing import Any, Optional
2119
import uuid
2220

23-
from google.genai import types
2421
from sqlalchemy import Boolean
2522
from sqlalchemy import delete
2623
from sqlalchemy import Dialect
2724
from sqlalchemy import ForeignKeyConstraint
2825
from sqlalchemy import func
2926
from sqlalchemy import Text
27+
from sqlalchemy.dialects import mysql
3028
from sqlalchemy.dialects import postgresql
3129
from sqlalchemy.engine import create_engine
3230
from sqlalchemy.engine import Engine
@@ -48,6 +46,7 @@
4846
from tzlocal import get_localzone
4947

5048
from ..events.event import Event
49+
from . import _session_util
5150
from .base_session_service import BaseSessionService
5251
from .base_session_service import GetSessionConfig
5352
from .base_session_service import ListEventsResponse
@@ -58,6 +57,7 @@
5857

5958
logger = logging.getLogger(__name__)
6059

60+
DEFAULT_MAX_KEY_LENGTH = 128
6161
DEFAULT_MAX_VARCHAR_LENGTH = 256
6262

6363

@@ -72,15 +72,16 @@ class DynamicJSON(TypeDecorator):
7272
def load_dialect_impl(self, dialect: Dialect):
7373
if dialect.name == "postgresql":
7474
return dialect.type_descriptor(postgresql.JSONB)
75-
else:
76-
return dialect.type_descriptor(Text) # Default to Text for other dialects
75+
if dialect.name == "mysql":
76+
# Use LONGTEXT for MySQL to address the data too long issue
77+
return dialect.type_descriptor(mysql.LONGTEXT)
78+
return dialect.type_descriptor(Text) # Default to Text for other dialects
7779

7880
def process_bind_param(self, value, dialect: Dialect):
7981
if value is not None:
8082
if dialect.name == "postgresql":
8183
return value # JSONB handles dict directly
82-
else:
83-
return json.dumps(value) # Serialize to JSON string for TEXT
84+
return json.dumps(value) # Serialize to JSON string for TEXT
8485
return value
8586

8687
def process_result_value(self, value, dialect: Dialect):
@@ -104,13 +105,13 @@ class StorageSession(Base):
104105
__tablename__ = "sessions"
105106

106107
app_name: Mapped[str] = mapped_column(
107-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
108+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
108109
)
109110
user_id: Mapped[str] = mapped_column(
110-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
111+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
111112
)
112113
id: Mapped[str] = mapped_column(
113-
String(DEFAULT_MAX_VARCHAR_LENGTH),
114+
String(DEFAULT_MAX_KEY_LENGTH),
114115
primary_key=True,
115116
default=lambda: str(uuid.uuid4()),
116117
)
@@ -139,16 +140,16 @@ class StorageEvent(Base):
139140
__tablename__ = "events"
140141

141142
id: Mapped[str] = mapped_column(
142-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
143+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
143144
)
144145
app_name: Mapped[str] = mapped_column(
145-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
146+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
146147
)
147148
user_id: Mapped[str] = mapped_column(
148-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
149+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
149150
)
150151
session_id: Mapped[str] = mapped_column(
151-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
152+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
152153
)
153154

154155
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
@@ -209,7 +210,7 @@ class StorageAppState(Base):
209210
__tablename__ = "app_states"
210211

211212
app_name: Mapped[str] = mapped_column(
212-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
213+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
213214
)
214215
state: Mapped[MutableDict[str, Any]] = mapped_column(
215216
MutableDict.as_mutable(DynamicJSON), default={}
@@ -224,13 +225,10 @@ class StorageUserState(Base):
224225
__tablename__ = "user_states"
225226

226227
app_name: Mapped[str] = mapped_column(
227-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
228+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
228229
)
229230
user_id: Mapped[str] = mapped_column(
230-
String(DEFAULT_MAX_VARCHAR_LENGTH), primary_key=True
231-
)
232-
state: Mapped[MutableDict[str, Any]] = mapped_column(
233-
MutableDict.as_mutable(DynamicJSON), default={}
231+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
234232
)
235233
state: Mapped[MutableDict[str, Any]] = mapped_column(
236234
MutableDict.as_mutable(DynamicJSON), default={}
@@ -417,7 +415,7 @@ def get_session(
417415
author=e.author,
418416
branch=e.branch,
419417
invocation_id=e.invocation_id,
420-
content=_decode_content(e.content),
418+
content=_session_util.decode_content(e.content),
421419
actions=e.actions,
422420
timestamp=e.timestamp.timestamp(),
423421
long_running_tool_ids=e.long_running_tool_ids,
@@ -540,15 +538,7 @@ def append_event(self, session: Session, event: Event) -> Event:
540538
interrupted=event.interrupted,
541539
)
542540
if event.content:
543-
encoded_content = event.content.model_dump(exclude_none=True)
544-
# Workaround for multimodal Content throwing JSON not serializable
545-
# error with SQLAlchemy.
546-
for p in encoded_content["parts"]:
547-
if "inline_data" in p:
548-
p["inline_data"]["data"] = (
549-
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
550-
)
551-
storage_event.content = encoded_content
541+
storage_event.content = _session_util.encode_content(event.content)
552542

553543
sessionFactory.add(storage_event)
554544

@@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state):
608598
for key in user_state.keys():
609599
merged_state[State.USER_PREFIX + key] = user_state[key]
610600
return merged_state
611-
612-
613-
def _decode_content(
614-
content: Optional[dict[str, Any]],
615-
) -> Optional[types.Content]:
616-
if not content:
617-
return None
618-
for p in content["parts"]:
619-
if "inline_data" in p:
620-
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
621-
return types.Content.model_validate(content)

src/google/adk/sessions/vertex_ai_session_service.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,23 @@
1414
import logging
1515
import re
1616
import time
17-
from typing import Any
18-
from typing import Optional
17+
from typing import Any, Optional
1918

20-
from dateutil.parser import isoparse
19+
from dateutil import parser
2120
from google import genai
2221
from typing_extensions import override
2322

2423
from ..events.event import Event
2524
from ..events.event_actions import EventActions
25+
from . import _session_util
2626
from .base_session_service import BaseSessionService
2727
from .base_session_service import GetSessionConfig
2828
from .base_session_service import ListEventsResponse
2929
from .base_session_service import ListSessionsResponse
3030
from .session import Session
3131

32+
33+
isoparse = parser.isoparse
3234
logger = logging.getLogger(__name__)
3335

3436

@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
289291
}
290292
event_json['actions'] = actions_json
291293
if event.content:
292-
event_json['content'] = event.content.model_dump(exclude_none=True)
294+
event_json['content'] = _session_util.encode_content(event.content)
293295
if event.error_code:
294296
event_json['error_code'] = event.error_code
295297
if event.error_message:
@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
316318
invocation_id=api_event['invocationId'],
317319
author=api_event['author'],
318320
actions=event_actions,
319-
content=api_event.get('content', None),
321+
content=_session_util.decode_content(api_event.get('content', None)),
320322
timestamp=isoparse(api_event['timestamp']).timestamp(),
321323
error_code=api_event.get('errorCode', None),
322324
error_message=api_event.get('errorMessage', None),

0 commit comments

Comments
 (0)
0