diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0651d452..5854fba6 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -84,6 +84,7 @@ def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). Args: name: The name of the attribute (tool) being accessed. @@ -92,9 +93,34 @@ def __getattr__(self, name: str) -> Callable[..., Any]: A function that when called will execute the named tool. Raises: - AttributeError: If no tool with the given name exists. + AttributeError: If no tool with the given name exists or if multiple tools match the given name. """ + def find_normalized_tool_name() -> Optional[str]: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name + for (tool_name, tool) in tool_registry.items() + if tool_name.replace("-", "_") == name + ] + + if len(filtered_tools) > 1: + raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}") + + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def caller(**kwargs: Any) -> Any: """Call a tool directly by name. @@ -115,14 +141,13 @@ def caller(**kwargs: Any) -> Any: Raises: AttributeError: If the tool doesn't exist. """ - if name not in self._agent.tool_registry.registry: - raise AttributeError(f"Tool '{name}' not found") + normalized_name = find_normalized_tool_name() # Create unique tool ID and set up the tool request tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" tool_use = { "toolUseId": tool_id, - "name": name, + "name": normalized_name, "input": kwargs.copy(), } diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 0ea20b64..02b1470b 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -674,6 +674,68 @@ def function(system_prompt: str) -> str: ) +def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + tool_name = "system-prompter" + + @strands.tools.tool(name=tool_name) + def function(system_prompt: str) -> str: + return system_prompt + + tool = strands.tools.tools.FunctionTool(function) + agent.tool_registry.register_tool(tool) + + mock_randint.return_value = 1 + + agent.tool.system_prompter(system_prompt="tool prompt") + + # Verify the correct tool was invoked + assert agent.tool_handler.process.call_count == 1 + tool_call = agent.tool_handler.process.call_args.kwargs.get("tool") + + assert tool_call == { + # Note that the tool-use uses the "python safe" name + "toolUseId": "tooluse_system_prompter_1", + # But the name of the tool is the one in the registry + "name": tool_name, + "input": {"system_prompt": "tool prompt"}, + } + + +def test_agent_tool_with_multiple_normalized_matches(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + @strands.tools.tool(name="system-prompter_1") + def function1(system_prompt: str) -> str: + return system_prompt + + @strands.tools.tool(name="system-prompter-1") + def function2(system_prompt: str) -> str: + return system_prompt + + agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function1)) + agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function2)) + + mock_randint.return_value = 1 + + with pytest.raises(AttributeError) as err: + agent.tool.system_prompter_1(system_prompt="tool prompt") + + assert str(err.value) == "Multiple tools matching 'system_prompter_1' found: system-prompter_1, system-prompter-1" + + +def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint): + agent.tool_handler = unittest.mock.Mock() + + mock_randint.return_value = 1 + + with pytest.raises(AttributeError) as err: + agent.tool.system_prompter_1(system_prompt="tool prompt") + + assert str(err.value) == "Tool 'system_prompter_1' not found" + + def test_agent_with_none_callback_handler_prints_nothing(): agent = Agent()