8000 adapt application integration toolset to new toolset interface · codefromthecrypt/adk-python@57d1315 · GitHub
[go: up one dir, main page]

Skip to content

Commit 57d1315

Browse files
seanzhougooglecopybara-github
authored andcommitted
adapt application integration toolset to new toolset interface
PiperOrigin-RevId: 757960706
1 parent d19927b commit 57d1315

File tree

1 file changed

+31
-54
lines changed

1 file changed

+31
-54
lines changed

src/google/adk/tools/application_integration_tool/application_integration_toolset.py

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,18 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Dict, List, Optional
15+
from typing import List
16+
from typing import Optional
17+
from typing import override
18+
from typing import Union
1619

1720
from fastapi.openapi.models import HTTPBearer
1821

1922
from ...auth.auth_credential import AuthCredential
2023
from ...auth.auth_credential import AuthCredentialTypes
2124
from ...auth.auth_credential import ServiceAccount
2225
from ...auth.auth_credential import ServiceAccountCredential
26+
from ..base_toolset import ToolPredicate
2327
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
2428
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
2529
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
@@ -63,10 +67,10 @@ class ApplicationIntegrationToolset:
6367
service_account_credentials={...},
6468
)
6569
66-
# Get all available tools
70+
# Feed the toolset to agent
6771
agent = LlmAgent(tools=[
68-
...
69-
*application_integration_toolset.get_tools(),
72+
...,
73+
application_integration_toolset,
7074
])
7175
```
7276
"""
@@ -87,46 +91,10 @@ def __init__(
8791
# tool/python function description.
8892
tool_instructions: Optional[str] = "",
8993
service_account_json: Optional[str] = None,
94+
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
9095
):
91-
"""Initializes the ApplicationIntegrationToolset.
92-
93-
Example Usage:
94-
```
95-
# Get all available tools for an integration with api trigger
96-
application_integration_toolset = ApplicationIntegrationToolset(
97-
98-
project="test-project",
99-
location="us-central1"
100-
integration="test-integration",
101-
trigger="api_trigger/test_trigger",
102-
service_account_credentials={...},
103-
)
104-
105-
# Get all available tools for a connection using entity operations and
106-
# actions
107-
# Note: Find the list of supported entity operations and actions for a
108-
connection
109-
# using integration connector apis:
110-
#
111-
https://cloud.google.com/integration-connectors/docs/reference/rest/v1/projects.locations.connections.connectionSchemaMetadata
112-
application_integration_toolset = ApplicationIntegrationToolset(
113-
project="test-project",
114-
location="us-central1"
< 10000 /td>
115-
connection="test-connection",
116-
entity_operations=["EntityId1": ["LIST","CREATE"], "EntityId2": []],
117-
#empty list for actions means all operations on the entity are supported
118-
actions=["action1"],
119-
service_account_credentials={...},
120-
)
96+
"""Args:
12197
122-
# Get all available tools
123-
agent = LlmAgent(tools=[
124-
...
125-
*application_integration_toolset.get_tools(),
126-
])
127-
```
128-
129-
Args:
13098
project: The GCP project ID.
13199
location: The GCP location.
132100
integration: The integration name.
@@ -139,6 +107,9 @@ def __init__(
139107
service_account_json: The service account configuration as a dictionary.
140108
Required if not using default service credential. Used for fetching
141109
the Application Integration or Integration Connector resource.
110+
tool_filter: The filter used to filter the tools in the toolset. It can
111+
be either a tool predicate or a list of tool names of the tools to
112+
expose.
142113
143114
Raises:
144115
ValueError: If neither integration and trigger nor connection and
@@ -156,7 +127,7 @@ def __init__(
156127
self.tool_name = tool_name
157128
self.tool_instructions = tool_instructions
158129
self.service_account_json = service_account_json
159-
self.generated_tools: Dict[str, RestApiTool] = {}
130+
self.tool_filter = tool_filter
160131

161132
integration_client = IntegrationClient(
162133
project,
@@ -185,10 +156,12 @@ def __init__(
185156
"Either (integration and trigger) or (connection and"
186157
" (entity_operations or actions)) should be provided."
187158
)
188-
self._parse_spec_to_tools(spec, connection_details)
159+
self.openapi_toolset = None
160+
self.tool = None
161+
self._parse_spec_to_toolset(spec, connection_details)
189162

190-
def _parse_spec_to_tools(self, spec_dict, connection_details):
191-
"""Parses the spec dict to a list of RestApiTool."""
163+
def _parse_spec_to_toolset(self, spec_dict, connection_details):
164+
"""Parses the spec dict to OpenAPI toolset."""
192165
if self.service_account_json:
193166
sa_credential = ServiceAccountCredential.model_validate_json(
194167
self.service_account_json
@@ -211,13 +184,12 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
211184
auth_scheme = HTTPBearer(bearerFormat="JWT")
212185

213186
if self.integration and self.trigger:
214-
tools = OpenAPIToolset(
187+
self.openapi_toolset = OpenAPIToolset(
215188
spec_dict=spec_dict,
216189
auth_credential=auth_credential,
217190
auth_scheme=auth_scheme,
218-
).get_tools()
219-
for tool in tools:
220-
self.generated_tools[tool.name] = tool
191+
tool_filter=self.tool_filter,
192+
)
221193
return
222194

223195
operations = OpenApiSpecParser().parse(spec_dict)
@@ -235,7 +207,7 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
235207
rest_api_tool.configure_auth_scheme(auth_scheme)
236208
if auth_credential:
237209
rest_api_tool.configure_auth_credential(auth_credential)
238-
tool = IntegrationConnectorTool(
210+
self.tool = IntegrationConnectorTool(
239211
name=rest_api_tool.name,
240212
description=rest_api_tool.description,
241213
connection_name=connection_details["name"],
@@ -246,7 +218,12 @@ def _parse_spec_to_tools(self, spec_dict, connection_details):
246218
operation=operation,
247219
rest_api_tool=rest_api_tool,
248220
)
249-
self.generated_tools[tool.name] = tool
250221

251-
def get_tools(self) -> List[RestApiTool]:
252-
return list(self.generated_tools.values())
222+
@override
223+
async def get_tools(self) -> List[RestApiTool]:
224+
return [self.tool] if self.tool else await self.openapi_toolset.get_tools()
225+
226+
@override
227+
async def close(self) -> None:
228+
if self.openapi_toolset:
229+
await self.openapi_toolset.close()

0 commit comments

Comments
 (0)
0