12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from typing import Dict , List , Optional
15
+ from typing import List
16
+ from typing import Optional
17
+ from typing import override
18
+ from typing import Union
16
19
17
20
from fastapi .openapi .models import HTTPBearer
18
21
19
22
from ...auth .auth_credential import AuthCredential
20
23
from ...auth .auth_credential import AuthCredentialTypes
21
24
from ...auth .auth_credential import ServiceAccount
22
25
from ...auth .auth_credential import ServiceAccountCredential
26
+ from ..base_toolset import ToolPredicate
23
27
from ..openapi_tool .auth .auth_helpers import service_account_scheme_credential
24
28
from ..openapi_tool .openapi_spec_parser .openapi_spec_parser import OpenApiSpecParser
25
29
from ..openapi_tool .openapi_spec_parser .openapi_toolset import OpenAPIToolset
@@ -63,10 +67,10 @@ class ApplicationIntegrationToolset:
63
67
service_account_credentials={...},
64
68
)
65
69
66
- # Get all available tools
70
+ # Feed the toolset to agent
67
71
agent = LlmAgent(tools=[
68
- ...
69
- * application_integration_toolset.get_tools() ,
72
+ ...,
73
+ application_integration_toolset,
70
74
])
71
75
```
72
76
"""
@@ -87,46 +91,10 @@ def __init__(
87
91
# tool/python function description.
88
92
tool_instructions : Optional [str ] = "" ,
89
93
service_account_json : Optional [str ] = None ,
94
+ tool_filter : Optional [Union [ToolPredicate , List [str ]]] = None ,
90
95
):
91
- """Initializes the ApplicationIntegrationToolset.
92
-
93
- Example Usage:
94
- ```
95
- # Get all available tools for an integration with api trigger
96
- application_integration_toolset = ApplicationIntegrationToolset(
97
-
98
- project="test-project",
99
- location="us-central1"
100
- integration="test-integration",
101
- trigger="api_trigger/test_trigger",
102
- service_account_credentials={...},
103
- )
104
-
105
- # Get all available tools for a connection using entity operations and
106
- # actions
107
- # Note: Find the list of supported entity operations and actions for a
108
- connection
109
- # using integration connector apis:
110
- #
111
- https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
112
- application_integration_toolset = ApplicationIntegrationToolset(
113
- project="test-project",
114
- location="us-central1"
<
10000
/td>115
- connection="test-connection",
116
- entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
117
- #empty list for actions means all operations on the entity are supported
118
- actions=["action1"],
119
- service_account_credentials={...},
120
- )
96
+ """Args:
121
97
122
- # Get all available tools
123
- agent = LlmAgent(tools=[
124
- ...
125
- *application_integration_toolset.get_tools(),
126
- ])
127
- ```
128
-
129
- Args:
130
98
project: The GCP project ID.
131
99
location: The GCP location.
132
100
integration: The integration name.
@@ -139,6 +107,9 @@ def __init__(
139
107
service_account_json: The service account configuration as a dictionary.
140
108
Required if not using default service credential. Used for fetching
141
109
the Application Integration or Integration Connector resource.
110
+ tool_filter: The filter used to filter the tools in the toolset. It can
111
+ be either a tool predicate or a list of tool names of the tools to
112
+ expose.
142
113
143
114
Raises:
144
115
ValueError: If neither integration and trigger nor connection and
@@ -156,7 +127,7 @@ def __init__(
156
127
self .tool_name = tool_name
157
128
self .tool_instructions = tool_instructions
158
129
self .service_account_json = service_account_json
159
- self .generated_tools : Dict [ str , RestApiTool ] = {}
130
+ self .tool_filter = tool_filter
160
131
161
132
integration_client = IntegrationClient (
162
133
project ,
@@ -185,10 +156,12 @@ def __init__(
185
156
"Either (integration and trigger) or (connection and"
186
157
" (entity_operations or actions)) should be provided."
187
158
)
188
- self ._parse_spec_to_tools (spec , connection_details )
159
+ self .openapi_toolset = None
160
+ self .tool = None
161
+ self ._parse_spec_to_toolset (spec , connection_details )
189
162
190
- def _parse_spec_to_tools (self , spec_dict , connection_details ):
191
- """Parses the spec dict to a list of RestApiTool ."""
163
+ def _parse_spec_to_toolset (self , spec_dict , connection_details ):
164
+ """Parses the spec dict to OpenAPI toolset ."""
192
165
if self .service_account_json :
193
166
sa_credential = ServiceAccountCredential .model_validate_json (
194
167
self .service_account_json
@@ -211,13 +184,12 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
211
184
auth_scheme = HTTPBearer (bearerFormat = "JWT" )
212
185
213
186
if self .integration and self .trigger :
214
- tools = OpenAPIToolset (
187
+ self . openapi_toolset = OpenAPIToolset (
215
188
spec_dict = spec_dict ,
216
189
auth_credential = auth_credential ,
217
190
auth_scheme = auth_scheme ,
218
- ).get_tools ()
219
- for tool in tools :
220
- self .generated_tools [tool .name ] = tool
191
+ tool_filter = self .tool_filter ,
192
+ )
221
193
return
222
194
223
195
operations = OpenApiSpecParser ().parse (spec_dict )
@@ -235,7 +207,7 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
235
207
rest_api_tool .configure_auth_scheme (auth_scheme )
236
208
if auth_credential :
237
209
rest_api_tool .configure_auth_credential (auth_credential )
238
- tool = IntegrationConnectorTool (
210
+ self . tool = IntegrationConnectorTool (
239
211
name = rest_api_tool .name ,
240
212
description = rest_api_tool .description ,
241
213
connection_name = connection_details ["name" ],
@@ -246,7 +218,12 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
246
218
operation = operation ,
247
219
rest_api_tool = rest_api_tool ,
248
220
)
249
- self .generated_tools [tool .name ] = tool
250
221
251
- def get_tools (self ) -> List [RestApiTool ]:
252
- return list (self .generated_tools .values ())
222
+ @override
223
+ async def get_tools (self ) -> List [RestApiTool ]:
224
+ return [self .tool ] if self .tool else await self .openapi_toolset .get_tools ()
225
+
226
+ @override
227
+ async def close (self ) -> None :
228
+ if self .openapi_toolset :
229
+ await self .openapi_toolset .close ()
0 commit comments