8000 refactor: refactor and refine LangChainTool · google/adk-python@7445417 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7445417

Browse files
seanzhougooglecopybara-github
authored andcommitted
refactor: refactor and refine LangChainTool
PiperOrigin-RevId: 760726719
1 parent ae7d19a commit 7445417

File tree

1 file changed

+96
-49
lines changed

1 file changed

+96
-49
lines changed

src/google/adk/tools/langchain_tool.py

Lines changed: 96 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,74 +13,121 @@
1313
# limitations under the License.
1414

1515
from typing import Any
16-
from typing import Callable
16+
from typing import Optional
17+
from typing import Union
1718

1819
from google.genai import types
19-
from pydantic import model_validator
20+
from langchain.agents import Tool
21+
from langchain_core.tools import BaseTool
2022
from typing_extensions import override
2123

2224
from . import _automatic_function_calling_util
2325
from .function_tool import FunctionTool
2426

2527

2628
class LangchainTool(FunctionTool):
27-
"""Use this class to wrap a langchain tool.
29+
"""Adapter class that wraps a Langchain tool for use with ADK.
2830
29-
If the original tool name and description are not suitable, you can override
30-
them in the constructor.
31+
This adapter converts Langchain tools into a format compatible with Google's
32+
generative AI function calling interface. It preserves the tool's name,
33+
description, and functionality while adapting its schema.
34+
35+
The original tool's name and description can be overridden if needed.
36+
37+
Args:
38+
tool: A Langchain tool to wrap (BaseTool or a tool with a .run method)
39+
name: Optional override for the tool's name
40+
description: Optional override for the tool's description
41+
42+
Examples:
43+
```python
44+
from langchain.tools import DuckDuckGoSearchTool
45+
from google.genai.tools import LangchainTool
46+
47+
search_tool = DuckDuckGoSearchTool()
48+
wrapped_tool = LangchainTool(search_tool)
49+
```
3150
"""
3251

33-
tool: Any
52+
_langchain_tool: Union[BaseTool, object]
3453
"""The wrapped langchain tool."""
3554

36-
def __init__(self, tool: Any):
37-
super().__init__(tool._run)
38-
self.tool = tool
39-
if tool.name:
55+
def __init__(
56+
self,
57+
tool: Union[BaseTool, object],
58+
name: Optional[str] = None,
59+
description: Optional[str] = None,
60+
):
61+
# Check if the tool has a 'run' method
62+
if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
63+
raise ValueError("Langchain tool must have a 'run' or '_run' method")
64+
65+
# Determine which function to use
66+
func = tool._run if hasattr(tool, '_run') else tool.run
67+
super().__init__(func)
68+
69+
self._langchain_tool = tool
70+
71+
# Set name: priority is 1) explicitly provided name, 2) tool's name, 3) default
72+
if name is not None:
73+
self.name = name
74+
elif hasattr(tool, 'name') and tool.name:
4075
self.name = tool.name
41-
if tool.description:
42-
self.description = tool.description
76+
# else: keep default from FunctionTool
4377

44-
@model_validator(mode='before')
45-
@classmethod
46-
def p 8000 opulate_name(cls, data: Any) -> Any:
47-
# Override this to not use function's signature name as it's
48-
# mostly "run" or "invoke" for thir-party tools.
49-
return data
78+
# Set description: similar priority
79+
if description is not None:
80+
self.description = description
81+
elif hasattr(tool, 'description') and tool.description:
82+
self.description = tool.description
83+
# else: keep default from FunctionTool
5084

5185
@override
5286
def _get_declaration(self) -> types.FunctionDeclaration:
53-
"""Build the function declaration for the tool."""
54-
from langchain.agents import Tool
55-
from langchain_core.tools import BaseTool
56-
57-
# There are two types of tools:
58-
# 1. BaseTool: the tool is defined in langchain.tools.
59-
# 2. Other tools: the tool doesn't inherit any class but follow some
60-
# conventions, like having a "run" method.
61-
if isinstance(self.tool, BaseTool):
62-
tool_wrapper = Tool(
63-
name=self.name,
64-
func=self.func,
65-
description=self.description,
66-
)
67-
if self.tool.args_schema:
68-
tool_wrapper.args_schema = self.tool.args_schema
69-
function_declaration = _automatic_function_calling_util.build_function_declaration_for_langchain(
70-
False,
71-
self.name,
72-
self.description,
73-
tool_wrapper.func,
74-
tool_wrapper.args,
75-
)
76-
return function_declaration
77-
else:
87+
"""Build the function declaration for the tool.
88+
89+
Returns:
90+
A FunctionDeclaration object that describes the tool's interface.
91+
92+
Raises:
93+
ValueError: If the tool schema cannot be correctly parsed.
94+
"""
95+
try:
96+
# There are two types of tools:
97+
# 1. BaseTool: the tool is defined in langchain_core.tools.
98+
# 2. Other tools: the tool doesn't inherit any class but follow some
99+
# conventions, like having a "run" method.
100+
# Handle BaseTool type (preferred Langchain approach)
101+
if isinstance(self._langchain_tool, BaseTool):
102+
tool_wrapper = Tool(
103+
name=self.name,
104+
func=self.func,
105+
description=self.description,
106+
)
107+
108+
# Add schema if available
109+
if (
110+
hasattr(self._langchain_tool, 'args_schema')
111+
and self._langchain_tool.args_schema
112+
):
113+
tool_wrapper.args_schema = self._langchain_tool.args_schema
114+
115+
return _automatic_function_calling_util.build_function_declaration_for_langchain(
116+
False,
117+
self.name,
118+
self.description,
119+
tool_wrapper.func,
120+
getattr(tool_wrapper, 'args', None),
121+
)
122+
78123
# Need to provide a way to override the function names and descriptions
79124
# as the original function names are mostly ".run" and the descriptions
80-
# may not meet users' needs.
81-
function_declaration = (
82-
_automatic_function_calling_util.build_function_declaration(
83-
func=self.tool.run,
84-
)
125+
# may not meet users' needs
126+
return _automatic_function_calling_util.build_function_declaration(
127+
func=self._langchain_tool.run,
85128
)
86-
return function_declaration
129+
130+
except Exception as e:
131+
raise ValueError(
132+
f'Failed to build function declaration for Langchain tool: {e}'
133+
) from e

0 commit comments

Comments
 (0)
0