8000 Fix: Enable underscores in direct method invocations to match hyphens… · NMsby/sdk-python@aff928c · GitHub
[go: up one dir, main page]

Skip to content

Commit aff928c

Browse files
authored
Fix: Enable underscores in direct method invocations to match hyphens (strands-agents#178)
Enable direct method tool invocations of `example_tool` to match tools with the name of `example-tool`, which fixes strands-agents#139. In the case where no direct match is found but multiple tools would match, we throw an error to avoid ambiguous errors Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 9006105 commit aff928c

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

src/strands/agent/agent.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
8484
"""Call tool as a function.
8585
8686
This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
87+
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').
8788
8889
Args:
8990
name: The name of the attribute (tool) being accessed.
@@ -92,9 +93,34 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
9293
A function that when called will execute the named tool.
9394
9495
Raises:
95-
AttributeError: If no tool with the given name exists.
96+
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
9697
"""
9798

99+
def find_normalized_tool_name() -> Optional[str]:
100+
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
101+
tool_registry = self._agent.tool_registry.registry
102+
103+
if tool_registry.get(name, None):
104+
return name
105+
106+
# If the desired name contains underscores, it might be a placeholder for characters that can't be
107+
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
108+
# all tools that can be represented with the normalized name
109+
if "_" in name:
110+
filtered_tools = [
111+
tool_name
112+
for (tool_name, tool) in tool_registry.items()
113+
if tool_name.replace("-", "_") == name
114+
]
115+
116+
if len(filtered_tools) > 1:
117+
raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}")
118+
119+
if filtered_tools:
120+
return filtered_tools[0]
121+
122+
raise AttributeError(f"Tool '{name}' not found")
123+
98124
def caller(**kwargs: Any) -> Any:
99125
"""Call a tool directly by name.
100126
@@ -115,14 +141,13 @@ def caller(**kwargs: Any) -> Any:
115141
Raises:
116142
AttributeError: If the tool doesn't exist.
117143
"""
118-
if name not in self._agent.tool_registry.registry:
119-
raise AttributeError(f"Tool '{name}' not found")
144+
normalized_name = find_normalized_tool_name()
120145

121146
# Create unique tool ID and set up the tool request
122147
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
123148
tool_use = {
124149
"toolUseId": tool_id,
125-
"name": name,
150+
"name": normalized_name,
126151
"input": kwargs.copy(),
127152
}
128153

tests/strands/agent/test_agent.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,68 @@ def function(system_prompt: str) -> str:
674674
)
675675

676676

677+
def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint):
678+
agent.tool_handler = unittest.mock.Mock()
679+
680+
tool_name = "system-prompter"
681+
682+
@strands.tools.tool(name=tool_name)
683+
def function(system_prompt: str) -> str:
684+
return system_prompt
685+
686+
tool = strands.tools.tools.FunctionTool(function)
687+
agent.tool_registry.register_tool(tool)
688+
689+
mock_randint.return_value = 1
690+
691+
agent.tool.system_prompter(system_prompt="tool prompt")
692+
693+
# Verify the correct tool was invoked
694+
assert agent.tool_handler.process.call_count == 1
695+
tool_call = agent.tool_handler.process.call_args.kwargs.get("tool")
696+
697+
assert tool_call == {
698+
# Note that the tool-use uses the "python safe" name
699+
"toolUseId": "tooluse_system_prompter_1",
700+
# But the name of the tool is the one in the registry
701+
"name": tool_name,
702+
"input": {"system_prompt": "tool prompt"},
703+
}
704+
705+
706+
def test_agent_tool_with_multiple_normalized_matches(agent, tool_registry, mock_randint):
707+
agent.tool_handler = unittest.mock.Mock()
708+
709+
@strands.tools.tool(name="system-prompter_1")
710+
def function1(system_prompt: str) -> str:
711+
return system_prompt
712+
713+
@strands.tools.tool(name="system-prompter-1")
714+
def function2(system_prompt: str) -> str:
715+
return system_prompt
716+
717+
agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function1))
718+
agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function2))
719+
720+
mock_randint.return_value = 1
721+
722+
with pytest.raises(AttributeError) as err:
723+
agent.tool.system_prompter_1(system_prompt="tool prompt")
724+
725+
assert str(err.value) == "Multiple tools matching 'system_prompter_1' found: system-prompter_1, system-prompter-1"
726+
727+
728+
def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint):
729+
agent.tool_handler = unittest.mock.Mock()
730+
731+
mock_randint.return_value = 1
732+
733+
with pytest.raises(AttributeError) as err:
734+
agent.tool.system_prompter_1(system_prompt="tool prompt")
735+
736+
assert str(err.value) == "Tool 'system_prompter_1' not found"
737+
738+
677739
def test_agent_with_none_callback_handler_prints_nothing():
678740
agent = Agent()
679741

0 commit comments

Comments
 (0)
0