|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import inspect |
| 18 | +import logging |
| 19 | +from typing import Any |
| 20 | +from typing import Callable |
| 21 | +from typing import Dict |
| 22 | +from typing import Optional |
| 23 | +from typing import Union |
| 24 | + |
| 25 | +from typing_extensions import override |
| 26 | + |
| 27 | +from ..auth.auth_credential import AuthCredential |
| 28 | +from ..auth.auth_tool import AuthConfig |
| 29 | +from ..auth.credential_manager import CredentialManager |
| 30 | +from ..utils.feature_decorator import experimental |
| 31 | +from .function_tool import FunctionTool |
| 32 | +from .tool_context import ToolContext |
| 33 | + |
| 34 | +logger = logging.getLogger("google_adk." + __name__) |
| 35 | + |
| 36 | + |
| 37 | +@experimental |
| 38 | +class AuthenticatedFunctionTool(FunctionTool): |
| 39 | + """A FunctionTool that handles authentication before the actual tool logic |
| 40 | + gets called. Functions can accept a special `credential` argument which is the |
| 41 | + credential ready for use.(Experimental) |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__( |
| 45 | + self, |
| 46 | + *, |
| 47 | + func: Callable[..., Any], |
| 48 | + auth_config: AuthConfig = None, |
| 49 | + response_for_auth_required: Optional[Union[dict[str, Any], str]] = None, |
| 50 | + ): |
| 51 | + """Initializes the AuthenticatedFunctionTool. |
| 52 | +
|
| 53 | + Args: |
| 54 | + func: The function to be called. |
| 55 | + auth_config: The authentication configuration. |
| 56 | + response_for_auth_required: The response to return when the tool is |
| 57 | + requesting auth credential from the client. There could be two case, |
| 58 | + the tool doesn't configure any credentials |
| 59 | + (auth_config.raw_auth_credential is missing) or the credentials |
| 60 | + configured is not enough to authenticate the tool (e.g. an OAuth |
| 61 | + client id and client secrect is configured.) and needs client input |
| 62 | + (e.g. client need to involve the end user in an oauth flow and get |
| 63 | + back the oauth response.) |
| 64 | + """ |
| 65 | + super().__init__(func=func) |
| 66 | + self._ignore_params.append("credential") |
| 67 | + |
| 68 | + if auth_config and auth_config.auth_scheme: |
| 69 | + self._credentials_manager = CredentialManager(auth_config=auth_config) |
| 70 | + else: |
| 71 | + logger.warning( |
| 72 | + "auth_config or auth_config.auth_scheme is missing. Will skip" |
| 73 | + " authentication.Using FunctionTool instead if authentication is not" |
| 74 | + " required." |
| 75 | + ) |
| 76 | + self._credentials_manager = None |
| 77 | + self._response_for_auth_required = response_for_auth_required |
| 78 | + |
| 79 | + @override |
| 80 | + async def run_async( |
| 81 | + self, *, args: dict[str, Any], tool_context: ToolContext |
| 82 | + ) -> Any: |
| 83 | + credential = None |
| 84 | + if self._credentials_manager: |
| 85 | + credential = await self._credentials_manager.get_auth_credential( |
| 86 | + tool_context |
| 87 | + ) |
| 88 | + if not credential: |
| 89 | + await self._credentials_manager.request_credential(tool_context) |
| 90 | + return self._response_for_auth_required or "Pending User Authorization." |
| 91 | + |
| 92 | + return await self._run_async_impl( |
| 93 | + args=args, tool_context=tool_context, credential=credential |
| 94 | + ) |
| 95 | + |
| 96 | + async def _run_async_impl( |
| 97 | + self, |
| 98 | + *, |
| 99 | + args: dict[str, Any], |
| 100 | + tool_context: ToolContext, |
| 101 | + credential: AuthCredential, |
| 102 | + ) -> Any: |
| 103 | + args_to_call = args.copy() |
| 104 | + signature = inspect.signature(self.func) |
| 105 | + if "credential" in signature.parameters: |
| 106 | + args_to_call["credential"] = credential |
| 107 | + return await super().run_async(args=args_to_call, tool_context=tool_context) |
0 commit comments