|
1 | 1 | import logging
|
2 |
| -import subprocess |
| 2 | +import threading |
| 3 | +from typing import Any, Dict |
3 | 4 |
|
4 | 5 | from localstack import config
|
5 |
| -from localstack.aws.accounts import get_aws_account_id |
6 |
| -from localstack.aws.connect import connect_to |
7 |
| -from localstack.services.infra import do_run, log_startup_message |
8 | 6 | from localstack.services.stepfunctions.packages import stepfunctions_local_package
|
9 | 7 | from localstack.utils.aws import aws_stack
|
10 |
| -from localstack.utils.common import wait_for_port_open |
11 |
| -from localstack.utils.net import wait_for_port_closed |
12 |
| -from localstack.utils.run import ShellCommandThread, wait_for_process_to_be_killed |
13 |
| -from localstack.utils.sync import retry |
| 8 | +from localstack.utils.net import get_free_tcp_port |
| 9 | +from localstack.utils.run import ShellCommandThread |
| 10 | +from localstack.utils.serving import Server |
| 11 | +from localstack.utils.threads import TMP_THREADS, FuncThread |
14 | 12 |
|
15 | 13 | LOG = logging.getLogger(__name__)
|
16 | 14 |
|
17 | 15 | # max heap size allocated for the Java process
|
18 | 16 | MAX_HEAP_SIZE = "256m"
|
19 | 17 |
|
20 |
| -# todo: will be replaced with plugin mechanism |
21 |
| -PROCESS_THREAD: ShellCommandThread | subprocess.Popen | None = None |
22 |
| - |
23 |
| - |
24 |
| -# TODO: pass env more explicitly |
25 |
| -def get_command(backend_port): |
26 |
| - install_dir_stepfunctions = stepfunctions_local_package.get_installed_dir() |
27 |
| - cmd = ( |
28 |
| - "cd %s; PORT=%s java " |
29 |
| - "-javaagent:aspectjweaver-1.9.7.jar " |
30 |
| - "-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " |
31 |
| - "-Dcom.amazonaws.sdk.disableCertChecking -Xmx%s " |
32 |
| - "-jar StepFunctionsLocal.jar --aws-account %s" |
33 |
| - ) % ( |
34 |
| - install_dir_stepfunctions, |
35 |
| - backend_port, |
36 |
| - MAX_HEAP_SIZE, |
37 |
| - get_aws_account_id(), |
38 |
| - ) |
39 |
| - if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default": |
40 |
| - lambda_endpoint = config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url( |
41 |
| - "lambda" |
| 18 | + |
| 19 | +class StepFunctionsServer(Server): |
| 20 | + def __init__(self, port: int, account_id: str, host: str = "localstack") -> None: |
| 21 | + self.account_id = account_id |
| 22 | + super().__init__(port, host) |
| 23 | + |
| 24 | + def do_start_thread(self) -> FuncThread: |
| 25 | + cmd = self.generate_shell_command() |
| 26 | + env_vars = self.generate_env_vars() |
| 27 | + LOG.debug("Starting StepFunctions process %s with env vars %s", cmd, env_vars) |
| 28 | + t = ShellCommandThread( |
| 29 | + cmd, |
| 30 | + strip_color=True, |
| 31 | + env_vars=env_vars, |
| 32 | + log_listener=self._log_listener, |
| 33 | + name="stepfunctions", |
42 | 34 | )
|
43 |
| - cmd += f" --lambda-endpoint {lambda_endpoint}" |
44 |
| - # add service endpoint flags |
45 |
| - services = [ |
46 |
| - "athena", |
47 |
| - "batch", |
48 |
| - "dynamodb", |
49 |
| - "ecs", |
50 |
| - "eks", |
51 |
| - "events", |
52 |
| - "glue", |
53 |
| - "sagemaker", |
54 |
| - "sns", |
55 |
| - "sqs", |
56 |
| - "stepfunctions", |
57 |
| - ] |
58 |
| - for service in services: |
59 |
| - flag = f"--{service}-endpoint" |
60 |
| - if service == "stepfunctions": |
61 |
| - flag = "--step-functions-endpoint" |
62 |
| - elif service == "events": |
63 |
| - flag = "--eventbridge-endpoint" |
64 |
| - elif service in ["athena", "eks"]: |
65 |
| - flag = f"--step-functions-{service}" |
66 |
| - endpoint = aws_stack.get_local_service_url(service) |
67 |
| - cmd += f" {flag} {endpoint}" |
68 |
| - |
69 |
| - return cmd |
70 |
| - |
71 |
| - |
72 |
| -def start_stepfunctions(asynchronous: bool = True, persistence_path: str | None = None): |
73 |
| - # TODO: introduce Server abstraction for StepFunctions process |
74 |
| - global PROCESS_THREAD |
75 |
| - backend_port = config.LOCAL_PORT_STEPFUNCTIONS |
76 |
| - stepfunctions_local_package.install() |
77 |
| - cmd = get_command(backend_port) |
78 |
| -<
F438
div class="diff-text-inner"> log_startup_message("StepFunctions") |
79 |
| - # TODO: change ports in stepfunctions.jar, then update here |
80 |
| - PROCESS_THREAD = do_run( |
81 |
| - cmd, |
82 |
| - asynchronous, |
83 |
| - strip_color=True, |
84 |
| - env_vars={ |
| 35 | + TMP_THREADS.append(t) |
| 36 | + t.start() |
| 37 | + return t |
| 38 | + |
| 39 | + def generate_env_vars(self) -> Dict[str, Any]: |
| 40 | + return { |
85 | 41 | "EDGE_PORT": config.EDGE_PORT_HTTP or config.EDGE_PORT,
|
86 | 42 | "EDGE_PORT_HTTP": config.EDGE_PORT_HTTP or config.EDGE_PORT,
|
87 |
| - "DATA_DIR": persistence_path or config.dirs.data, |
88 |
| - }, |
89 |
| - ) |
90 |
| - return PROCESS_THREAD |
91 |
| - |
92 |
| - |
93 |
| -def wait_for_stepfunctions(): |
94 |
| - retry(check_stepfunctions, sleep=0.5, retries=15) |
95 |
| - |
96 |
| - |
97 |
| -def stop_stepfunctions(): |
98 |
| - if not PROCESS_THREAD or not PROCESS_THREAD.process: |
99 |
| - return |
100 |
| - LOG.debug("Restarting StepFunctions process ...") |
101 |
| - |
102 |
| - pid = PROCESS_THREAD.process.pid |
103 |
| - PROCESS_THREAD.stop() |
104 |
| - wait_for_port_closed(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=0.5, retries=15) |
105 |
| - try: |
106 |
| - # TODO: currently failing in CI (potentially due to a defunct process) - need to investigate! |
107 |
| - wait_for_process_to_be_killed(pid, sleep=0.3, retries=10) |
108 |
| - except Exception as e: |
109 |
| - LOG.warning("StepFunctions process not properly terminated: %s", e) |
110 |
| - |
111 |
| - |
112 |
| -def check_stepfunctions(expect_shutdown: bool = False, print_error: bool = False) -> None: |
113 |
| - out = None |
114 |
| - try: |
115 |
| - wait_for_port_open(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=2) |
116 |
| - endpoint_url = f"http://127.0.0.1:{config.LOCAL_PORT_STEPFUNCTIONS}" |
117 |
| - out = connect_to(endpoint_url=endpoint_url).stepfunctions.list_state_machines() |
118 |
| - except Exception: |
119 |
| - if print_error: |
120 |
| - LOG.exception("StepFunctions health check failed") |
121 |
| - |
122 |
| - if expect_shutdown: |
123 |
| - assert out is None |
124 |
| - else: |
125 |
| - assert out and isinstance(out.get("stateMachines"), list) |
| 43 | + "DATA_DIR": config.dirs.data, |
| 44 | + } |
| 45 | + |
| 46 | + def generate_shell_command(self, port: int, account_id: str) -> str: |
| 47 | + install_dir_stepfunctions = stepfunctions_local_package.get_installed_dir() |
| 48 | + |
| 49 | + cmd = ( |
| 50 | + "cd %s; PORT=%s java " |
| 51 | + "-javaagent:aspectjweaver-1.9.7.jar " |
| 52 | + "-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " |
| 53 | + "-Dcom.amazonaws.sdk.disableCertChecking -Xmx%s " |
| 54 | + "-jar StepFunctionsLocal.jar --aws-account %s" |
| 55 | + ) % ( |
| 56 | + install_dir_stepfunctions, |
| 57 | + port, |
| 58 | + MAX_HEAP_SIZE, |
| 59 | + account_id, |
| 60 | + ) |
| 61 | + |
| 62 | + if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default": |
| 63 | + lambda_endpoint = ( |
| 64 | + config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url("lambda") |
| 65 | + ) |
| 66 | + cmd += f" --lambda-endpoint {lambda_endpoint}" |
| 67 | + |
| 68 | + # add service endpoint flags |
| 69 | + services = [ |
| 70 | + "athena", |
| 71 | + "batch", |
| 72 | + "dynamodb", |
| 73 | + "ecs", |
| 74 | + "eks", |
| 75 | + "events", |
| 76 | + "glue", |
| 77 | + "sagemaker", |
| 78 | + "sns", |
| 79 | + "sqs", |
| 80 | + "stepfunctions", |
| 81 | + ] |
| 82 | + |
| 83 | + for service in services: |
| 84 | + flag = f"--{service}-endpoint" |
| 85 | + if service == "stepfunctions": |
| 86 | + flag = "--step-functions-endpoint" |
| 87 | + elif service == "events": |
| 88 | + flag = "--eventbridge-endpoint" |
| 89 | + elif service in ["athena", "eks"]: |
| 90 | + flag = f"--step-functions-{service}" |
| 91 | + endpoint = aws_stack.get_local_service_url(service) |
| 92 | + cmd += f" {flag} {endpoint}" |
| 93 | + |
| 94 | + return cmd |
| 95 | + |
| 96 | + def _log_listener(self, line, **kwargs): |
| 97 | + LOG.debug(line.rstrip()) |
| 98 | + |
| 99 | + |
| 100 | +class StepFunctionsServerManager: |
| 101 | + def __init__(self): |
| 102 | + self._lock = threading.RLock() |
| 103 | + self._servers = dict[str, StepFunctionsServer] = {} |
| 104 | + |
| 105 | + def get_server_for_account(self, account_id: str) -> StepFunctionsServer: |
| 106 | + if account_id in self._servers: |
| 107 | + return self._servers[account_id] |
| 108 | + |
| 109 | + with self._lock: |
| 110 | + if account_id in self._servers: |
| 111 | + return self._servers[account_id] |
| 112 | + |
| 113 | + LOG.info("Creating StepFunctions server for account %s", account_id) |
| 114 | + self._servers[account_id] = self._create_stepfunctions_server(account_id) |
| 115 | + self._servers[account_id].start() |
| 116 | + if not self._servers[account_id].wait_is_up(timeout=self.default_startup_timeout): |
| 117 | + raise TimeoutError("gave up waiting for StepFunctions server to start up") |
| 118 | + return self._servers[account_id] |
| 119 | + |
| 120 | + def shutdown_all(self): |
| 121 | + with self._lock: |
| 122 | + while self._servers: |
| 123 | + account_id, server = self._servers.popitem() |
| 124 | + LOG.info("Shutting down StepFunctions for account %s", account_id) |
| 125 | + server.shutdown() |
| 126 | + |
| 127 | + def _create_stepfunctions_server(self, account_id: str) -> StepFunctionsServer: |
| 128 | + port = get_free_tcp_port() |
| 129 | + stepfunctions_local_package.install() |
| 130 | + |
| 131 | + server = StepFunctionsServer( |
| 132 | + port=port, |
| 133 | + account_id=account_id, |
| 134 | + ) |
| 135 | + return server |
0 commit comments