8000 Fix: Enable underscores in direct method invocations to match hyphens by zastrowm · Pull Request #178 · strands-agents/sdk-python · GitHub
[go: up one dir, main page]

Skip to content

Fix: Enable underscores in direct method invocations to match hyphens #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
"""Call tool as a function.

This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').

Args:
name: The name of the attribute (tool) being accessed.
Expand All @@ -92,9 +93,34 @@ def __getattr__(self, name: str) -> Callable[..., Any]:
A function that when called will execute the named tool.

Raises:
AttributeError: If no tool with the given name exists.
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
"""

def find_normalized_tool_name() -> Optional[str]:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
tool_registry = self._agent.tool_registry.registry

if tool_registry.get(name, None):
return name

# If the desired name contains underscores, it might be a placeholder for characters that can't be
# represented as python identifiers but are valid as tool names, such as dashes. In that case, find
# all tools that can be represented with the normalized name
if "_" in name:
filtered_tools = [
tool_name
for (tool_name, tool) in tool_registry.items()
if tool_name.replace("-", "_") == name
]

if len(filtered_tools) > 1:
raise AttributeError(f"Multiple tools matching '{name}' found: {', '.join(filtered_tools)}")

if filtered_tools:
return filtered_tools[0]

raise AttributeError(f"Tool '{name}' not found")

def caller(**kwargs: Any) -> Any:
"""Call a tool directly by name.

Expand All @@ -115,14 +141,13 @@ def caller(**kwargs: Any) -> Any:
Raises:
AttributeError: If the tool doesn't exist.
"""
if name not in self._agent.tool_registry.registry:
raise AttributeError(f"Tool '{name}' not found")
normalized_name = find_normalized_tool_name()

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use = {
"toolUseId": tool_id,
"name": name,
"name": normalized_name,
"input": kwargs.copy(),
}

Expand Down
62 changes: 62 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,68 @@ def function(system_prompt: str) -> str:
)


def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()

tool_name = "system-prompter"

@strands.tools.tool(name=tool_name)
def function(system_prompt: str) -> str:
return system_prompt

tool = strands.tools.tools.FunctionTool(function)
agent.tool_registry.register_tool(tool)

mock_randint.return_value = 1

agent.tool.system_prompter(system_prompt="tool prompt")

# Verify the correct tool was invoked
assert agent.tool_handler.process.call_count == 1
tool_call = agent.tool_handler.process.call_args.kwargs.get("tool")

assert tool_call == {
# Note that the tool-use uses the "python safe" name
"toolUseId": "tooluse_system_prompter_1",
# But the name of the tool is the one in the registry
"name": tool_name,
"input": {"system_prompt": "tool prompt"},
}


def test_agent_tool_with_multiple_normalized_matches(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()

@strands.tools.tool(name="system-prompter_1")
def function1(system_prompt: str) -> str:
return system_prompt

@strands.tools.tool(name="system-prompter-1")
def function2(system_prompt: str) -> str:
return system_prompt

agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function1)) 75EF
agent.tool_registry.register_tool(strands.tools.tools.FunctionTool(function2))

mock_randint.return_value = 1

with pytest.raises(AttributeError) as err:
agent.tool.system_prompter_1(system_prompt="tool prompt")

assert str(err.value) == "Multiple tools matching 'system_prompter_1' found: system-prompter_1, system-prompter-1"


def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint):
agent.tool_handler = unittest.mock.Mock()

mock_randint.return_value = 1

with pytest.raises(AttributeError) as err:
agent.tool.system_prompter_1(system_prompt="tool prompt")

assert str(err.value) == "Tool 'system_prompter_1' not found"


def test_agent_with_none_callback_handler_prints_nothing():
agent = Agent()

Expand Down
0