8000 Add in the concept of hooks · zastrowm/sdk-python@c15d390 · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit c15d390

Browse files
committed
Add in the concept of hooks
1 parent f62dcd5 commit c15d390

File tree

3 files changed

+65
-0
lines changed

3 files changed

+65
-0
lines changed

src/strands/agent/agent.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ..event_loop.event_loop import event_loop_cycle
2525
from ..handlers.callback_handler import CompositeCallbackHandler, PrintingCallbackHandler, null_callback_handler
2626
from ..handlers.tool_handler import AgentToolHandler
27+
from ..hooks.agent_hook import AgentHook, AgentHookManager, AgentInitialized
2728
from ..models.bedrock import BedrockModel
2829
from ..telemetry.metrics import EventLoopMetrics
2930
from ..telemetry.tracer import get_tracer
@@ -186,6 +187,7 @@ def __init__(
186187
model: Union[Model, str, None] = None,
187188
messages: Optional[Messages] = None,
188189
tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
190+
hooks: Optional[List[AgentHook]] = None,
189191
system_prompt: Optional[str] = None,
190192
callback_handler: Optional[
191193
Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
@@ -289,6 +291,13 @@ def __init__(
289291
self.trace_span: Optional[trace.Span] = None
290292

291293
self.tool_caller = Agent.ToolCaller(self)
294+
self.hooks = AgentHookManager(agent=self)
295+
296+
if hooks is not None:
297+
for hook in hooks:
298+
self.hooks.add(hook)
299+
300+
self.hooks.get_hook(AgentInitialized)(agent=self)
292301

293302
@property
294303
def tool(self) -> ToolCaller:

src/strands/hooks/__init__.py

Whitespace-only changes.

src/strands/hooks/agent_hook.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from typing import Optional, Protocol, TYPE_CHECKING, List, Type, Dict, Any, TypeVar, ParamSpec, Callable, Generic
2+
3+
from strands.types.tools import AgentTool
4+
5+
6+
if TYPE_CHECKING:
7+
from strands import Agent
8+
9+
T = TypeVar('T')
< 8000 code>10+
11+
class Reference(Generic[T]):
12+
def __init__(self, value: T) -> None:
13+
self.value = value
14+
15+
class AgentInitialized(Protocol):
16+
def __call__(self, *, agent: "Agent") -> None: ...
17+
18+
class ToolTransformer(Protocol):
19+
def __call__(self, agent: "Agent", tool: Reference[AgentTool]) -> None: ...
20+
21+
class AgentHook(Protocol):
22+
def register_hooks(self, hooks: "AgentHookManager", agent: "Agent") -> None: ...
23+
24+
25+
class AgentHookManager:
26+
27+
registered_hooks: Dict[Type, List[Any]] = {}
28+
29+
def __init__(self, agent: "Agent", hooks: Optional[List[AgentHook]] = None) -> None:
30+
self.agent = agent
31+
self.hooks = hooks
32+
33+
def add(self, hook: AgentHook):
34+
hook.register_hooks(hooks=self, agent=self.agent)
35+
36+
def add_hook(self, hook_type: T, callback: T):
37+
if hook_type not in self.registered_hooks:
38+
self.registered_hooks[hook_type] = []
39+
40+
self.registered_hooks[hook_type].append(callback)
41+
42+
def get_hook(self, hook_type: Type[T]) -> T:
43+
return lambda *args, **kwargs: self._invoke_hook(hook_type, *args, **kwargs)
44+
45+
def _invoke_hook(self, hook_type: Type[Callable], *args: Any, **kwargs: Any) -> None:
46+
if hook_type not in self.registered_hooks:
47+
return
48+
49+
for hook in self.registered_hooks[hook_type]:
50+
hook(*args, **kwargs)
51+
52+
return
53+
54+
55+
56+

0 commit comments

Comments
 (0)
0