8000 feat: ✨ redis session class by BloodBoy21 · Pull Request #789 · google/adk-python · GitHub
[go: up one dir, main page]

Skip to content

feat: ✨ redis session class #789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
d481e06
feat: :construction: catch user transcription
BloodBoy21 Apr 30, 2025
bba436b
feat: :sparkles: send user transcription event as llm_response
BloodBoy21 Apr 30, 2025
1de70be
Merge branch 'main' into main
BloodBoy21 May 1, 2025
5480f0d
Merge branch 'main' into main
BloodBoy21 May 1, 2025
ad2abf5
style: :lipstick: update lint problems
BloodBoy21 May 2, 2025
98b1829
Merge branch 'main' into main
BloodBoy21 May 3, 2025
5757063
Merge branch 'main' into main
BloodBoy21 May 3, 2025
405187f
Merge branch 'main' into main
BloodBoy21 May 3, 2025
744703c
fix: set right order for input transcription
hangfei May 3, 2025
31a5d42
remove print
hangfei May 3, 2025
913a492
Merge branch 'main' into main
hangfei May 3, 2025
59e5d9c
remove api version
hangfei May 3, 2025
56dbb93
Merge branch 'main' into main
hangfei May 4, 2025
b5f28fc
Merge branch 'main' into main
BloodBoy21 May 4, 2025
4b782f8
feat: :sparkles: set version to gemini using vertex api
BloodBoy21 May 4, 2025
6e74b74
Merge branch 'main' into main
BloodBoy21 May 5, 2025
ea29015
Merge branch 'main' into main
BloodBoy21 May 6, 2025
578eeb0
Merge branch 'main' into main
hangfei May 6, 2025
fb0b464
Merge branch 'feat/api-version-vertex'
BloodBoy21 May 6, 2025
237e35e
feat: :construction: add redis as memory service
BloodBoy21 May 7, 2025
105192d
fix: :bug: save all data from events in session
BloodBoy21 May 12, 2025
66dfccf
Merge branch 'google:main' into feat-redis-session
BloodBoy21 May 12, 2025
36cd268
Merge branch 'main' of github.com:BloodBoy21/nerds-adk-python
BloodBoy21 May 19, 2025
aa46a76
Merge branch 'main' into feat-redis-session
BloodBoy21 May 19, 2025
30fe6b2
chore: :lipstick: apply code style
BloodBoy21 May 19, 2025
04a66bb
Merge branch 'main' into feat-redis-session
BloodBoy21 May 19, 2025
11498eb
Merge branch 'main' into feat-redis-session
BloodBoy21 May 20, 2025
fb7c0af
Merge branch 'main' into feat-redis-session
seanzhou1023 May 21, 2025
9957c2c
Merge branch 'main' into feat-redis-session
BloodBoy21 May 22, 2025
67153a5
Merge branch 'google:main' into feat-redis-session
BloodBoy21 May 28, 2025
737fd82
fix: :coffin: remove list_events from redis session service
BloodBoy21 May 28, 2025
1fe2ef9
fix: :bug: fix await errors in runner
BloodBoy21 May 29, 2025
3a98976
fix: :bug: await append event in runner
BloodBoy21 May 29, 2025
0368ab4
fix: :bug: await append event if not partial
BloodBoy21 May 29, 2025
530ae04
Merge branch 'main' into feat-redis-session
BloodBoy21 Jun 16, 2025
2a55584
refactor: tool call accept event response
BloodBoy21 Jun 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ dependencies = [
"python-dateutil>=2.9.0.post0", # For Vertex AI Session Service
"python-dotenv>=1.0.0", # To manage environment variables
"PyYAML>=6.0.2", # For APIHubToolset.
"redis>=6.0.0", # For RedisToolset
"requests>=2.32.4",
"sqlalchemy>=2.0", # SQL database ORM
"starlette>=0.46.2", # For FastAPI CLI
Expand Down Expand Up @@ -78,7 +79,7 @@ dev = [

a2a = [
# go/keep-sorted start
"a2a-sdk>=0.2.7;python_version>='3.10'"
"a2a-sdk>=0.2.7;python_version>='3.10'",
# go/keep-sorted end
]

Expand Down
20 changes: 14 additions & 6 deletions src/google/adk/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from datetime import datetime
import random
import string
from typing import Dict
from typing import Optional

from google.genai import types
Expand Down Expand Up @@ -47,16 +48,16 @@ class Event(LlmResponse):
"""

model_config = ConfigDict(
extra='forbid',
ser_json_bytes='base64',
val_json_bytes='base64',
extra="forbid",
ser_json_bytes="base64",
val_json_bytes="base64",
alias_generator=alias_generators.to_camel,
populate_by_name=True,
)
"""The pydantic model config."""

# TODO: revert to be required after spark migration
invocation_id: str = ''
invocation_id: str = ""
"""The invocation ID of the event."""
author: str
"""'user' or the name of the agent, indicating who appended the event to the
Expand All @@ -81,7 +82,7 @@ class Event(LlmResponse):

# The following are computed fields.
# Do not assign the ID. It will be assigned by the session.
id: str = ''
id: str = ""
"""The unique identifier of the event."""
timestamp: float = Field(default_factory=lambda: datetime.now().timestamp())
"""The timestamp of the event."""
Expand Down Expand Up @@ -133,4 +134,11 @@ def has_trailing_code_execution_result(
@staticmethod
def new_id():
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(8))
return "".join(random.choice(characters) for _ in range(8))

def to_dict(self, exclude: Optional[Dict] = None) -> dict:
  """Serializes this event to a JSON-compatible dict.

  Args:
    exclude: Optional pydantic ``exclude`` specification of fields to omit
      from the output. Defaults to None, meaning no fields are excluded.

  Returns:
    A plain dict produced by ``model_dump`` in JSON mode (bytes are
    base64-encoded per the model config).
  """
  # The previous default was a mutable `{}` — a shared-mutable-default
  # pitfall. `None` is the safe, behavior-equivalent default: pydantic
  # treats both None and an empty dict as "exclude nothing".
  return self.model_dump(exclude=exclude, mode="json")

@staticmethod
def from_dict(data: dict) -> "Event":
  """Deserializes *data* back into an ``Event`` instance.

  Args:
    data: A mapping of field values, e.g. the output of ``to_dict``.

  Returns:
    A validated ``Event`` built via pydantic's ``model_validate``.
  """
  event = Event.model_validate(data)
  return event
65 changes: 32 additions & 33 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext

AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
AF_FUNCTION_CALL_ID_PREFIX = "adk-"
REQUEST_EUC_FUNCTION_CALL_NAME = "adk_request_credential"

logger = logging.getLogger('google_adk.' + __name__)
logger = logging.getLogger("google_adk." + __name__)


def generate_client_function_call_id() -> str:
return f'{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}'
return f"{AF_FUNCTION_CALL_ID_PREFIX}{uuid.uuid4()}"


def populate_client_function_call_id(model_response_event: Event) -> None:
Expand Down Expand Up @@ -100,7 +100,6 @@ def generate_auth_event(
function_call_id,
auth_config,
) in function_response_event.actions.requested_auth_configs.items():

request_euc_function_call = types.FunctionCall(
name=REQUEST_EUC_FUNCTION_CALL_NAME,
args=AuthToolArguments(
Expand Down Expand Up @@ -149,7 +148,7 @@ async def handle_function_calls_async(
tools_dict,
)

with tracer.start_as_current_span(f'execute_tool {tool.name}'):
with tracer.start_as_current_span(f"execute_tool {tool.name}"):
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
Expand Down Expand Up @@ -208,7 +207,7 @@ async def handle_function_calls_async(
# this is needed for debug traces of parallel calls
# individual response with tool.name is traced in __build_response_event
# (we drop tool.name from span name here as this is merged event)
with tracer.start_as_current_span('execute_tool (merged)'):
with tracer.start_as_current_span("execute_tool (merged)"):
trace_merged_tool_calls(
response_event_id=merged_event.id,
function_response_event=merged_event,
Expand All @@ -232,7 +231,7 @@ async def handle_function_calls_live(
tool, tool_context = _get_tool_and_context(
invocation_context, function_call_event, function_call, tools_dict
)
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
with tracer.start_as_current_span(f"execute_tool {tool.name}"):
# do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
Expand Down Expand Up @@ -289,7 +288,7 @@ async def handle_function_calls_live(
tool=tool,
args=function_args,
response_event_id=function_response_event.id,
function_response=function_response,
function_response_event=function_response_event,
)
function_response_events.append(function_response_event)

Expand All @@ -302,7 +301,7 @@ async def handle_function_calls_live(
# this is needed for debug traces of parallel calls
# individual response with tool.name is traced in __build_response_event
# (we drop tool.name from span name here as this is merged event)
with tracer.start_as_current_span('execute_tool (merged)'):
with tracer.start_as_current_span("execute_tool (merged)"):
trace_merged_tool_calls(
response_event_id=merged_event.id,
function_response_event=merged_event,
Expand All @@ -316,10 +315,10 @@ async def _process_function_live_helper(
function_response = None
# Check if this is a stop_streaming function call
if (
function_call.name == 'stop_streaming'
and 'function_name' in function_args
function_call.name == "stop_streaming"
and "function_name" in function_args
):
function_name = function_args['function_name']
function_name = function_args["function_name"]
active_tasks = invocation_context.active_streaming_tools
if (
function_name in active_tasks
Expand All @@ -334,29 +333,29 @@ async def _process_function_live_helper(
except (asyncio.CancelledError, asyncio.TimeoutError):
# Log the specific condition
if task.cancelled():
logging.info(f'Task {function_name} was cancelled successfully')
logging.info(f"Task {function_name} was cancelled successfully")
elif task.done():
logging.info(f'Task {function_name} completed during cancellation')
logging.info(f"Task {function_name} completed during cancellation")
else:
logging.warning(
f'Task {function_name} might still be running after'
' cancellation timeout'
f"Task {function_name} might still be running after"
" cancellation timeout"
)
function_response = {
'status': f'The task is not cancelled yet for {function_name}.'
"status": f"The task is not cancelled yet for {function_name}."
}
if not function_response:
# Clean up the reference
active_tasks[function_name].task = None

function_response = {
'status': f'Successfully stopped streaming function {function_name}'
"status": f"Successfully stopped streaming function {function_name}"
}
else:
function_response = {
'status': f'No active streaming function named {function_name} found'
"status": f"No active streaming function named {function_name} found"
}
elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
# for streaming tool use case
# we require the function to be a async generator function
async def run_tool_and_update_queue(tool, function_args, tool_context):
Expand All @@ -368,10 +367,10 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
invocation_context=invocation_context,
):
updated_content = types.Content(
role='user',
role="user",
parts=[
types.Part.from_text(
text=f'Function {tool.name} returned: {result}'
text=f"Function {tool.name} returned: {result}"
)
],
)
Expand All @@ -393,9 +392,9 @@ async def run_tool_and_update_queue(tool, function_args, tool_context):
# Immediately return a pending response.
# This is required by current live model.
function_response = {
'status': (
'The function is running asynchronously and the results are'
' pending.'
"status": (
"The function is running asynchronously and the results are"
" pending."
)
}
else:
Expand All @@ -413,7 +412,7 @@ def _get_tool_and_context(
):
if function_call.name not in tools_dict:
raise ValueError(
f'Function {function_call.name} is not found in the tools_dict.'
f"Function {function_call.name} is not found in the tools_dict."
)

tool_context = ToolContext(
Expand Down Expand Up @@ -458,15 +457,15 @@ def __build_response_event(
) -> Event:
# Specs requires the result to be a dict.
if not isinstance(function_result, dict):
function_result = {'result': function_result}
function_result = {"result": function_result}

part_function_response = types.Part.from_function_response(
name=tool.name, response=function_result
)
part_function_response.function_response.id = tool_context.function_call_id

content = types.Content(
role='user',
role="user",
parts=[part_function_response],
)

Expand All @@ -482,10 +481,10 @@ def __build_response_event(


def merge_parallel_function_response_events(
function_response_events: list['Event'],
) -> 'Event':
function_response_events: list["Event"],
) -> "Event":
if not function_response_events:
raise ValueError('No function response events provided.')
raise ValueError("No function response events provided.")

if len(function_response_events) == 1:
return function_response_events[0]
Expand Down Expand Up @@ -513,7 +512,7 @@ def merge_parallel_function_response_events(
invocation_id=Event.new_id(),
author=base_event.author,
branch=base_event.branch,
content=types.Content(role='user', parts=merged_parts),
content=types.Content(role="user", parts=merged_parts),
actions=merged_actions, # Optionally merge actions if required
)

Expand Down
Loading
0