8000 adapt google api toolset and api hub toolset to new toolset interface · Syntax404-coder/adk-python@6a04ff8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6a04ff8

Browse files
seanzhougooglecopybara-github
authored andcommitted
adapt google api toolset and api hub toolset to new toolset interface
PiperOrigin-RevId: 757946541
1 parent 27b2297 commit 6a04ff8

File tree

7 files changed

+168
-122
lines changed

7 files changed

+168
-122
lines changed

src/google/adk/tools/apihub_tool/apihub_toolset.py

Lines changed: 47 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,25 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Dict, List, Optional
16+
from typing import List
17+
from typing import Optional
18+
from typing import override
19+
from typing import Union
1720

1821
import yaml
1922

23+
from ...agents.readonly_context import ReadonlyContext
2024
from ...auth.auth_credential import AuthCredential
2125
from ...auth.auth_schemes import AuthScheme
26+
from ..base_toolset import BaseToolset
27+
from ..base_toolset import ToolPredicate
2228
from ..openapi_tool.common.common import to_snake_case
2329
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
2430
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
2531
from .clients.apihub_client import APIHubClient
2632

2733

28-
class APIHubToolset:
34+
class APIHubToolset(BaseToolset):
2935
"""APIHubTool generates tools from a given API Hub resource.
3036
3137
Examples:
@@ -34,16 +40,13 @@ class APIHubToolset:
3440
apihub_toolset = APIHubToolset(
3541
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
3642
service_account_json="...",
43+
tool_filter=lambda tool, ctx=None: tool.name in ('my_tool',
44+
'my_other_tool')
3745
)
3846
3947
# Get all available tools
40-
agent = LlmAgent(tools=apihub_toolset.get_tools())
48+
agent = LlmAgent(tools=apihub_toolset)
4149
42-
# Get a specific tool
43-
agent = LlmAgent(tools=[
44-
...
45-
apihub_toolset.get_tool('my_tool'),
46-
])
4750
```
4851
4952
**apihub_resource_name** is the resource name from API Hub. It must include
@@ -70,6 +73,7 @@ def __init__(
7073
auth_credential: Optional[AuthCredential] = None,
7174
# Optionally, you can provide a custom API Hub client
7275
apihub_client: Optional[APIHubClient] = None,
76+
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
7377
):
7478
"""Initializes the APIHubTool with the given parameters.
7579
@@ -81,12 +85,17 @@ def __init__(
8185
)
8286
8387
# Get all available tools
84-
agent = LlmAgent(tools=apihub_toolset.get_tools())
88+
agent = LlmAgent(tools=[apihub_toolset])
8589
90+
apihub_toolset = APIHubToolset(
91+
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
92+
service_account_json="...",
93+
tool_filter = ['my_tool']
94+
)
8695
# Get a specific tool
8796
agent = LlmAgent(tools=[
88-
...
89-
apihub_toolset.get_tool('my_tool'),
97+
...,
98+
apihub_toolset,
9099
])
91100
```
92101
@@ -118,6 +127,9 @@ def __init__(
118127
lazy_load_spec: If True, the spec will be loaded lazily when needed.
119128
Otherwise, the spec will be loaded immediately and the tools will be
120129
generated during initialization.
130+
tool_filter: The filter used to filter the tools in the toolset. It can
131+
be either a tool predicate or a list of tool names of the tools to
132+
expose.
121133
"""
122134
self.name = name
123135
self.description = description
@@ -128,82 +140,51 @@ def __init__(
128140
service_account_json=service_account_json,
129141
)
130142

131-
self.generated_tools: Dict[str, RestApiTool] = {}
143+
self.openapi_toolset = None
132144
self.auth_scheme = auth_scheme
133145
self.auth_credential = auth_credential
146+
self.tool_filter = tool_filter
134147

135148
if not self.lazy_load_spec:
136-
self._prepare_tools()
137-
138-
def get_tool(self, name: str) -> Optional[RestApiTool]:
139-
"""Retrieves a specific tool by its name.
140-
141-
Example:
142-
```
143-
apihub_tool = apihub_toolset.get_tool('my_tool')
144-
```
145-
146-
Args:
147-
name: The name of the tool to retrieve.
148-
149-
Returns:
150-
The tool with the given name, or None if no such tool exists.
151-
"""
152-
if not self._are_tools_ready():
153-
self._prepare_tools()
149+
self._prepare_toolset()
154150

155-
return self.generated_tools[name] if name in self.generated_tools else None
156-
157-
def get_tools(self) -> List[RestApiTool]:
151+
@override
152+
async def get_tools(
153+
self, readonly_context: Optional[ReadonlyContext] = None
154+
) -> List[RestApiTool]:
158155
"""Retrieves all available tools.
159156
160157
Returns:
161158
A list of all available RestApiTool objects.
162159
"""
163-
if not self._are_tools_ready():
164-
self._prepare_tools()
165-
166-
return list(self.generated_tools.values())
167-
168-
def _are_tools_ready(self) -> bool:
169-
return not self.lazy_load_spec or self.generated_tools
170-
171-
def _prepare_tools(self) -> str:
172-
"""Fetches the spec from API Hub and generates the tools.
160+
if not self.openapi_toolset:
161+
self._prepare_toolset()
162+
if not self.openapi_toolset:
163+
return []
164+
return await self.openapi_toolset.get_tools(readonly_context)
173165

174-
Returns:
175-
True if the tools are ready, False otherwise.
176-
"""
166+
def _prepare_toolset(self) -> None:
167+
"""Fetches the spec from API Hub and generates the toolset."""
177168
# For each API, get the first version and the first spec of that version.
178-
spec = self.apihub_client.get_spec_content(self.apihub_resource_name)
179-
self.generated_tools: Dict[str, RestApiTool] = {}
180-
181-
tools = self._parse_spec_to_tools(spec)
182-
for tool in tools:
183-
self.generated_tools[tool.name] = tool
184-
185-
def _parse_spec_to_tools(self, spec_str: str) -> List[RestApiTool]:
186-
"""Parses the spec string to a list of RestApiTool.
187-
188-
Args:
189-
spec_str: The spec string to parse.
190-
191-
Returns:
192-
A list of RestApiTool objects.
193-
"""
169+
spec_str = self.apihub_client.get_spec_content(self.apihub_resource_name)
194170
spec_dict = yaml.safe_load(spec_str)
195171
if not spec_dict:
196-
return []
172+
return
197173

198174
self.name = self.name or to_snake_case(
199175
spec_dict.get('info', {}).get('title', 'unnamed')
200176
)
201177
self.description = self.description or spec_dict.get('info', {}).get(
202178
'description', ''
203179
)
204-
tools = OpenAPIToolset(
180+
self.openapi_toolset = OpenAPIToolset(
205181
spec_dict=spec_dict,
206182
auth_credential=self.auth_credential,
207183
auth_scheme=self.auth_scheme,
208-
).get_tools()
209-
return tools
184+
tool_filter=self.tool_filter,
185+
)
186+
187+
@override
188+
async def close(self):
189+
if self.openapi_toolset:
190+
await self.openapi_toolset.close()

src/google/adk/tools/base_toolset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC
22
from abc import abstractmethod
3+
from typing import Optional
34
from typing import Protocol
45

56
from google.adk.agents.readonly_context import ReadonlyContext
@@ -33,7 +34,7 @@ class BaseToolset(ABC):
3334

3435
@abstractmethod
3536
async def get_tools(
36-
self, readony_context: ReadonlyContext = None
37+
self, readonly_context: Optional[ReadonlyContext] = None
3738
) -> list[BaseTool]:
3839
"""Return all tools in the toolset based on the provided context.
3940

src/google/adk/tools/google_api_tool/google_api_tool_set.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,67 @@
1717
import inspect
1818
import os
1919
from typing import Any
20-
from typing import Final
2120
from typing import List
2221
from typing import Optional
22+
from typing import override
2323
from typing import Type
24+
from typing import Union
2425

26+
from ...agents.readonly_context import ReadonlyContext
2527
from ...auth import OpenIdConnectWithConfig
28+
from ...tools.base_toolset import BaseToolset
29+
from ...tools.base_toolset import ToolPredicate
2630
from ..openapi_tool import OpenAPIToolset
27-
from ..openapi_tool import RestApiTool
2831
from .google_api_tool import GoogleApiTool
2932
from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter
3033

3134

32-
class GoogleApiToolSet:
33-
"""Google API Tool Set."""
35+
class GoogleApiToolset(BaseToolset):
36+
"""Google API Toolset contains tools for interacting with Google APIs.
3437
35-
def __init__(self, tools: List[RestApiTool]):
36-
self.tools: Final[List[GoogleApiTool]] = [
37-
GoogleApiTool(tool) for tool in tools
38-
]
38+
Usually one toolsets will contains tools only replated to one Google API, e.g.
39+
Google Bigquery API toolset will contains tools only related to Google
40+
Bigquery API, like list dataset tool, list table tool etc.
41+
"""
3942

40-
def get_tools(self) -> List[GoogleApiTool]:
43+
def __init__(
44+
self,
45+
openapi_toolset: OpenAPIToolset,
46+
client_id: Optional[str] = None,
47+
client_secret: Optional[str] = None,
48+
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
49+
):
50+
self.openapi_toolset = openapi_toolset
51+
self.tool_filter = tool_filter
52+
self.client_id = client_id
53+
self.client_secret = client_secret
54+
55+
@override
56+
async def get_tools(
57+
self, readonly_context: Optional[ReadonlyContext] = None
58+
) -> List[GoogleApiTool]:
4159
"""Get all tools in the toolset."""
42-
return self.tools
60+
tools = []
61+
62+
for tool in await self.openapi_toolset.get_tools(readonly_context):
63+
if self.tool_filter and (
64+
isinstance(self.tool_filter, ToolPredicate)
65+
and not self.tool_filter(tool, readonly_context)
66+
or isinstance(self.tool_filter, list)
67+
and tool.name not in self.tool_filter
68+
):
69+
continue
70+
google_api_tool = GoogleApiTool(tool)
71+
google_api_tool.configure_auth(self.client_id, self.client_secret)
72+
tools.append(google_api_tool)
4373

44-
def get_tool(self, tool_name: str) -> Optional[GoogleApiTool]:
45-
"""Get a tool by name."""
46-
matching_tool = filter(lambda t: t.name == tool_name, self.tools)
47-
return next(matching_tool, None)
74+
return tools
75+
76+
def set_tool_filter(self, tool_filter: Union[ToolPredicate, List[str]]):
77+
self.tool_filter = tool_filter
4878

4979
@staticmethod
50-
def _load_tool_set_with_oidc_auth(
80+
def _load_toolset_with_oidc_auth(
5181
spec_file: Optional[str] = None,
5282
spec_dict: Optional[dict[str, Any]] = None,
5383
scopes: Optional[list[str]] = None,
@@ -64,7 +94,7 @@ def _load_tool_set_with_oidc_auth(
6494
yaml_path = os.path.join(caller_dir, spec_file)
6595
with open(yaml_path, 'r', encoding='utf-8') as file:
6696
spec_str = file.read()
67-
tool_set = OpenAPIToolset(
97+
toolset = OpenAPIToolset(
6898
spec_dict=spec_dict,
6999
spec_str=spec_str,
70100
spec_str_type='yaml',
@@ -85,26 +115,29 @@ def _load_tool_set_with_oidc_auth(
85115
scopes=scopes,
86116
),
87117
)
88-
return tool_set
118+
return toolset
89119

90120
def configure_auth(self, client_id: str, client_secret: str):
91-
for tool in self.tools:
92-
tool.configure_auth(client_id, client_secret)
121+
self.client_id = client_id
122+
self.client_secret = client_secret
93123

94124
@classmethod
95-
def load_tool_set(
96-
cls: Type[GoogleApiToolSet],
125+
def load_toolset(
126+
cls: Type[GoogleApiToolset],
97127
api_name: str,
98128
api_version: str,
99-
) -> GoogleApiToolSet:
129+
) -> GoogleApiToolset:
100130
spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
101131
scope = list(
102132
spec_dict['components']['securitySchemes']['oauth2']['flows'][
103133
'authorizationCode'
104134
]['scopes'].keys()
105135
)[0]
106136
return cls(
107-
cls._load_tool_set_with_oidc_auth(
108-
spec_dict=spec_dict, scopes=[scope]
109-
).get_tools()
137+
cls._load_toolset_with_oidc_auth(spec_dict=spec_dict, scopes=[scope])
110138
)
139+
140+
@override
141+
async def close(self):
142+
if self.openapi_toolset:
143+
await self.openapi_toolset.close()

0 commit comments

Comments
 (0)
0