11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
-
15
- import base64
16
14
import copy
17
15
from datetime import datetime
18
16
import json
19
17
import logging
20
18
from typing import Any , Optional
21
19
import uuid
22
20
23
- from google .genai import types
24
21
from sqlalchemy import Boolean
25
22
from sqlalchemy import delete
26
23
from sqlalchemy import Dialect
27
24
from sqlalchemy import ForeignKeyConstraint
28
25
from sqlalchemy import func
29
26
from sqlalchemy import Text
27
+ from sqlalchemy .dialects import mysql
30
28
from sqlalchemy .dialects import postgresql
31
29
from sqlalchemy .engine import create_engine
32
30
from sqlalchemy .engine import Engine
48
46
from tzlocal import get_localzone
49
47
50
48
from ..events .event import Event
49
+ from . import _session_util
51
50
from .base_session_service import BaseSessionService
52
51
from .base_session_service import GetSessionConfig
53
52
from .base_session_service import ListEventsResponse
58
57
59
58
logger = logging .getLogger (__name__ )
60
59
60
+ DEFAULT_MAX_KEY_LENGTH = 128
61
61
DEFAULT_MAX_VARCHAR_LENGTH = 256
62
62
63
63
@@ -72,15 +72,16 @@ class DynamicJSON(TypeDecorator):
72
72
def load_dialect_impl (self , dialect : Dialect ):
73
73
if dialect .name == "postgresql" :
74
74
return dialect .type_descriptor (postgresql .JSONB )
75
- else :
76
- return dialect .type_descriptor (Text ) # Default to Text for other dialects
75
+ if dialect .name == "mysql" :
76
+ # Use LONGTEXT for MySQL to address the data too long issue
77
+ return dialect .type_descriptor (mysql .LONGTEXT )
78
+ return dialect .type_descriptor (Text ) # Default to Text for other dialects
77
79
78
80
def process_bind_param (self , value , dialect : Dialect ):
79
81
if value is not None :
80
82
if dialect .name == "postgresql" :
81
83
return value # JSONB handles dict directly
82
- else :
83
- return json .dumps (value ) # Serialize to JSON string for TEXT
84
+ return json .dumps (value ) # Serialize to JSON string for TEXT
84
85
return value
85
86
86
87
def process_result_value (self , value , dialect : Dialect ):
@@ -104,13 +105,13 @@ class StorageSession(Base):
104
105
__tablename__ = "sessions"
105
106
106
107
app_name : Mapped [str ] = mapped_column (
107
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
108
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
108
109
)
109
110
user_id : Mapped [str ] = mapped_column (
110
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
111
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
111
112
)
112
113
id : Mapped [str ] = mapped_column (
113
- String (DEFAULT_MAX_VARCHAR_LENGTH ),
114
+ String (DEFAULT_MAX_KEY_LENGTH ),
114
115
primary_key = True ,
115
116
default = lambda : str (uuid .uuid4 ()),
116
117
)
@@ -139,16 +140,16 @@ class StorageEvent(Base):
139
140
__tablename__ = "events"
140
141
141
142
id : Mapped [str ] = mapped_column (
142
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
143
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
143
144
)
144
145
app_name : Mapped [str ] = mapped_column (
145
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
146
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
146
147
)
147
148
user_id : Mapped [str ] = mapped_column (
148
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
149
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
149
150
)
150
151
session_id : Mapped [str ] = mapped_column (
151
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
152
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
152
153
)
153
154
154
155
invocation_id : Mapped [str ] = mapped_column (String (DEFAULT_MAX_VARCHAR_LENGTH ))
@@ -209,7 +210,7 @@ class StorageAppState(Base):
209
210
__tablename__ = "app_states"
210
211
211
212
app_name : Mapped [str ] = mapped_column (
212
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
213
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
213
214
)
214
215
state : Mapped [MutableDict [str , Any ]] = mapped_column (
215
216
MutableDict .as_mutable (DynamicJSON ), default = {}
@@ -224,13 +225,10 @@ class StorageUserState(Base):
224
225
__tablename__ = "user_states"
225
226
226
227
app_name : Mapped [str ] = mapped_column (
227
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
228
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
228
229
)
229
230
user_id : Mapped [str ] = mapped_column (
230
- String (DEFAULT_MAX_VARCHAR_LENGTH ), primary_key = True
231
- )
232
- state : Mapped [MutableDict [str , Any ]] = mapped_column (
233
- MutableDict .as_mutable (DynamicJSON ), default = {}
231
+ String (DEFAULT_MAX_KEY_LENGTH ), primary_key = True
234
232
)
235
233
state : Mapped [MutableDict [str , Any ]] = mapped_column (
236
234
MutableDict .as_mutable (DynamicJSON ), default = {}
@@ -417,7 +415,7 @@ def get_session(
417
415
author = e .author ,
418
416
branch = e .branch ,
419
417
invocation_id = e .invocation_id ,
420
- content = _decode_content (e .content ),
418
+ content = _session_util . decode_content (e .content ),
421
419
actions = e .actions ,
422
420
timestamp = e .timestamp .timestamp (),
423
421
long_running_tool_ids = e .long_running_tool_ids ,
@@ -540,15 +538,7 @@ def append_event(self, session: Session, event: Event) -> Event:
540
538
interrupted = event .interrupted ,
541
539
)
542
540
if event .content :
543
- encoded_content = event .content .model_dump (exclude_none = True )
544
- # Workaround for multimodal Content throwing JSON not serializable
545
- # error with SQLAlchemy.
546
- for p in encoded_content ["parts" ]:
547
- if "inline_data" in p :
548
- p ["inline_data" ]["data" ] = (
549
- base64 .b64encode (p ["inline_data" ]["data" ]).decode ("utf-8" ),
550
- )
551
- storage_event .content = encoded_content
541
+ storage_event .content = _session_util .encode_content (event .content )
552
542
553
543
sessionFactory .add (storage_event )
554
544
@@ -608,14 +598,3 @@ def _merge_state(app_state, user_state, session_state):
608
598
for key in user_state .keys ():
609
599
merged_state [State .USER_PREFIX + key ] = user_state [key ]
610
600
return merged_state
611
-
612
-
613
- def _decode_content (
614
- content : Optional [dict [str , Any ]],
615
- ) -> Optional [types .Content ]:
616
- if not content :
617
- return None
618
- for p in content ["parts" ]:
619
- if "inline_data" in p :
620
- p ["inline_data" ]["data" ] = base64 .b64decode (p ["inline_data" ]["data" ][0 ])
621
- return types .Content .model_validate (content )
0 commit comments