|
| 1 | +from typing import Optional, Protocol, TYPE_CHECKING, List, Type, Dict, Any, TypeVar, ParamSpec, Callable, Generic |
| 2 | + |
| 3 | +from strands.types.tools import AgentTool |
| 4 | + |
| 5 | + |
| 6 | +if TYPE_CHECKING: |
| 7 | + from strands import Agent |
| 8 | + |
| 9 | +T = TypeVar('T') |
| <
8000
code>10 | + |
| 11 | +class Reference(Generic[T]): |
| 12 | + def __init__(self, value: T) -> None: |
| 13 | + self.value = value |
| 14 | + |
| 15 | +class AgentInitialized(Protocol): |
| 16 | + def __call__(self, *, agent: "Agent") -> None: ... |
| 17 | + |
| 18 | +class ToolTransformer(Protocol): |
| 19 | + def __call__(self, agent: "Agent", tool: Reference[AgentTool]) -> None: ... |
| 20 | + |
| 21 | +class AgentHook(Protocol): |
| 22 | + def register_hooks(self, hooks: "AgentHookManager", agent: "Agent") -> None: ... |
| 23 | + |
| 24 | + |
| 25 | +class AgentHookManager: |
| 26 | + |
| 27 | + registered_hooks: Dict[Type, List[Any]] = {} |
| 28 | + |
| 29 | + def __init__(self, agent: "Agent", hooks: Optional[List[AgentHook]] = None) -> None: |
| 30 | + self.agent = agent |
| 31 | + self.hooks = hooks |
| 32 | + |
| 33 | + def add(self, hook: AgentHook): |
| 34 | + hook.register_hooks(hooks=self, agent=self.agent) |
| 35 | + |
| 36 | + def add_hook(self, hook_type: T, callback: T): |
| 37 | + if hook_type not in self.registered_hooks: |
| 38 | + self.registered_hooks[hook_type] = [] |
| 39 | + |
| 40 | + self.registered_hooks[hook_type].append(callback) |
| 41 | + |
| 42 | + def get_hook(self, hook_type: Type[T]) -> T: |
| 43 | + return lambda *args, **kwargs: self._invoke_hook(hook_type, *args, **kwargs) |
| 44 | + |
| 45 | + def _invoke_hook(self, hook_type: Type[Callable], *args: Any, **kwargs: Any) -> None: |
| 46 | + if hook_type not in self.registered_hooks: |
| 47 | + return |
| 48 | + |
| 49 | + for hook in self.registered_hooks[hook_type]: |
| 50 | + hook(*args, **kwargs) |
| 51 | + |
| 52 | + return |
| 53 | + |
| 54 | + |
| 55 | + |
| 56 | + |
0 commit comments