13
13
# limitations under the License.
14
14
15
15
16
- from typing import Dict , List , Optional
16
+ from typing import List
17
+ from typing import Optional
18
+ from typing import override
19
+ from typing import Union
17
20
18
21
import yaml
19
22
23
+ from ...agents .readonly_context import ReadonlyContext
20
24
from ...auth .auth_credential import AuthCredential
21
25
from ...auth .auth_schemes import AuthScheme
26
+ from ..base_toolset import BaseToolset
27
+ from ..base_toolset import ToolPredicate
22
28
from ..openapi_tool .common .common import to_snake_case
23
29
from ..openapi_tool .openapi_spec_parser .openapi_toolset import OpenAPIToolset
24
30
from ..openapi_tool .openapi_spec_parser .rest_api_tool import RestApiTool
25
31
from .clients .apihub_client import APIHubClient
26
32
27
33
28
- class APIHubToolset :
34
+ class APIHubToolset ( BaseToolset ) :
29
35
"""APIHubTool generates tools from a given API Hub resource.
30
36
31
37
Examples:
@@ -34,16 +40,13 @@ class APIHubToolset:
34
40
apihub_toolset = APIHubToolset(
35
41
apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
36
42
service_account_json="...",
43
+ tool_filter=lambda tool, ctx=None: tool.name in ('my_tool',
44
+ 'my_other_tool')
37
45
)
38
46
39
47
# Get all available tools
40
- agent = LlmAgent(tools=apihub_toolset.get_tools() )
48
+ agent = LlmAgent(tools=apihub_toolset)
41
49
42
- # Get a specific tool
43
- agent = LlmAgent(tools=[
44
- ...
45
- apihub_toolset.get_tool('my_tool'),
46
- ])
47
50
```
48
51
49
52
**apihub_resource_name** is the resource name from API Hub. It must include
@@ -70,6 +73,7 @@ def __init__(
70
73
auth_credential : Optional [AuthCredential ] = None ,
71
74
# Optionally, you can provide a custom API Hub client
72
75
apihub_client : Optional [APIHubClient ] = None ,
76
+ tool_filter : Optional [Union [ToolPredicate , List [str ]]] = None ,
73
77
):
74
78
"""Initializes the APIHubTool with the given parameters.
75
79
@@ -81,12 +85,17 @@ def __init__(
81
85
)
82
86
83
87
# Get all available tools
84
- agent = LlmAgent(tools=apihub_toolset.get_tools() )
88
+ agent = LlmAgent(tools=[ apihub_toolset] )
85
89
90
+ apihub_toolset = APIHubToolset(
91
+ apihub_resource_name="projects/test-project/locations/us-central1/apis/test-api",
92
+ service_account_json="...",
93
+ tool_filter = ['my_tool']
94
+ )
86
95
# Get a specific tool
87
96
agent = LlmAgent(tools=[
88
- ...
89
- apihub_toolset.get_tool('my_tool') ,
97
+ ...,
98
+ apihub_toolset,
90
99
])
91
100
```
92
101
@@ -118,6 +127,9 @@ def __init__(
118
127
lazy_load_spec: If True, the spec will be loaded lazily when needed.
119
128
Otherwise, the spec will be loaded immediately and the tools will be
120
129
generated during initialization.
130
+ tool_filter: The filter used to filter the tools in the toolset. It can
131
+ be either a tool predicate or a list of tool names of the tools to
132
+ expose.
121
133
"""
122
134
self .name = name
123
135
self .description = description
@@ -128,82 +140,51 @@ def __init__(
128
140
service_account_json = service_account_json ,
129
141
)
130
142
131
- self .generated_tools : Dict [ str , RestApiTool ] = {}
143
+ self .openapi_toolset = None
132
144
self .auth_scheme = auth_scheme
133
145
self .auth_credential = auth_credential
146
+ self .tool_filter = tool_filter
134
147
135
148
if not self .lazy_load_spec :
136
- self ._prepare_tools ()
137
-
138
- def get_tool (self , name : str ) -> Optional [RestApiTool ]:
139
- """Retrieves a specific tool by its name.
140
-
141
- Example:
142
- ```
143
- apihub_tool = apihub_toolset.get_tool('my_tool')
144
- ```
145
-
146
- Args:
147
- name: The name of the tool to retrieve.
148
-
149
- Returns:
150
- The tool with the given name, or None if no such tool exists.
151
- """
152
- if not self ._are_tools_ready ():
153
- self ._prepare_tools ()
149
+ self ._prepare_toolset ()
154
150
155
- return self .generated_tools [name ] if name in self .generated_tools else None
156
-
157
- def get_tools (self ) -> List [RestApiTool ]:
151
+ @override
152
+ async def get_tools (
153
+ self , readonly_context : Optional [ReadonlyContext ] = None
154
+ ) -> List [RestApiTool ]:
158
155
"""Retrieves all available tools.
159
156
160
157
Returns:
161
158
A list of all available RestApiTool objects.
162
159
"""
163
- if not self ._are_tools_ready ():
164
- self ._prepare_tools ()
165
-
166
- return list (self .generated_tools .values ())
167
-
168
- def _are_tools_ready (self ) -> bool :
169
- return not self .lazy_load_spec or self .generated_tools
170
-
171
- def _prepare_tools (self ) -> str :
172
- """Fetches the spec from API Hub and generates the tools.
160
+ if not self .openapi_toolset :
161
+ self ._prepare_toolset ()
162
+ if not self .openapi_toolset :
163
+ return []
164
+ return await self .openapi_toolset .get_tools (readonly_context )
173
165
174
- Returns:
175
- True if the tools are ready, False otherwise.
176
- """
166
+ def _prepare_toolset (self ) -> None :
167
+ """Fetches the spec from API Hub and generates the toolset."""
177
168
# For each API, get the first version and the first spec of that version.
178
- spec = self .apihub_client .get_spec_content (self .apihub_resource_name )
179
- self .generated_tools : Dict [str , RestApiTool ] = {}
180
-
181
- tools = self ._parse_spec_to_tools (spec )
182
- for tool in tools :
183
- self .generated_tools [tool .name ] = tool
184
-
185
- def _parse_spec_to_tools (self , spec_str : str ) -> List [RestApiTool ]:
186
- """Parses the spec string to a list of RestApiTool.
187
-
188
- Args:
189
- spec_str: The spec string to parse.
190
-
191
- Returns:
192
- A list of RestApiTool objects.
193
- """
169
+ spec_str = self .apihub_client .get_spec_content (self .apihub_resource_name )
194
170
spec_dict = yaml .safe_load (spec_str )
195
171
if not spec_dict :
196
- return []
172
+ return
197
173
198
174
self .name = self .name or to_snake_case (
199
175
spec_dict .get ('info' , {}).get ('title' , 'unnamed' )
200
176
)
201
177
self .description = self .description or spec_dict .get ('info' , {}).get (
202
178
'description' , ''
203
179
)
204
- tools = OpenAPIToolset (
180
+ self . openapi_toolset = OpenAPIToolset (
205
181
spec_dict = spec_dict ,
206
182
auth_credential = self .auth_credential ,
207
183
auth_scheme = self .auth_scheme ,
208
- ).get_tools ()
209
- return tools
184
+ tool_filter = self .tool_filter ,
185
+ )
186
+
187
+ @override
188
+ async def close (self ):
189
+ if self .openapi_toolset :
190
+ await self .openapi_toolset .close ()
0 commit comments