| 22 | + ) -> None: |
| 23 | + self.account_id = account_id |
| 24 | + self.region_name = region_name |
| 25 | + super().__init__(port, host) |
| 26 | + |
| 27 | + def do_start_thread(self) -> FuncThread: |
| 28 | + cmd = self.generate_shell_command() |
| 29 | + env_vars = self.generate_env_vars() |
| 30 | + cwd = stepfunctions_local_package.get_installed_dir() |
| 31 | + LOG.debug("Starting StepFunctions process %s with env vars %s", cmd, env_vars) |
| 32 | + t = ShellCommandThread( |
| 33 | + cmd, |
| 34 | + strip_color=True, |
| 35 | + env_vars=env_vars, |
| 36 | + log_listener=self._log_listener, |
| 37 | + name="stepfunctions", |
| 38 | + cwd=cwd, |
42 | 39 | )
|
43 |
| - cmd += f" --lambda-endpoint {lambda_endpoint}" |
44 |
| - # add service endpoint flags |
45 |
| - services = [ |
46 |
| - "athena", |
47 |
| - "batch", |
48 |
| - "dynamodb", |
49 |
| - "ecs", |
50 |
| - "eks", |
51 |
| - "events", |
52 |
| - "glue", |
53 |
| - "sagemaker", |
54 |
| - "sns", |
55 |
| - "sqs", |
56 |
| - "stepfunctions", |
57 |
| - ] |
58 |
| - for service in services: |
59 |
| - flag = f"--{service}-endpoint" |
60 |
| - if service == "stepfunctions": |
61 |
| - flag = "--step-functions-endpoint" |
62 |
| - elif service == "events": |
63 |
| - flag = "--eventbridge-endpoint" |
64 |
| - elif service in ["athena", "eks"]: |
65 |
| - flag = f"--step-functions-{service}" |
66 |
| - endpoint = aws_stack.get_local_service_url(service) |
67 |
| - cmd += f" {flag} {endpoint}" |
68 |
| - |
69 |
| - return cmd |
70 |
| - |
71 |
| - |
72 |
| -def start_stepfunctions(asynchronous: bool = True, persistence_path: str | None = None): |
73 |
| - # TODO: introduce Server abstraction for StepFunctions process |
74 |
| - global PROCESS_THREAD |
75 |
| - backend_port = config.LOCAL_PORT_STEPFUNCTIONS |
76 |
| - stepfunctions_local_package.install() |
77 |
| - cmd = get_command(backend_port) |
78 |
| - log_startup_message("StepFunctions") |
79 |
| - # TODO: change ports in stepfunctions.jar, then update here |
80 |
| - PROCESS_THREAD = do_run( |
81 |
| - cmd, |
82 |
| - asynchronous, |
83 |
| - strip_color=True, |
84 |
| - env_vars={ |
| 40 | + TMP_THREADS.append(t) |
| 41 | + t.start() |
| 42 | + return t |
| 43 | + |
| 44 | + def generate_env_vars(self) -> Dict[str, Any]: |
| 45 | + return { |
85 | 46 | "EDGE_PORT": config.EDGE_PORT_HTTP or config.EDGE_PORT,
|
86 | 47 | "EDGE_PORT_HTTP": config.EDGE_PORT_HTTP or config.EDGE_PORT,
|
87 |
| - "DATA_DIR": persistence_path or config.dirs.data, |
88 |
| - }, |
89 |
| - ) |
90 |
| - return PROCESS_THREAD |
91 |
| - |
92 |
| - |
93 |
| -def wait_for_stepfunctions(): |
94 |
| - retry(check_stepfunctions, sleep=0.5, retries=15) |
95 |
| - |
96 |
| - |
97 |
| -def stop_stepfunctions(): |
98 |
| - if not PROCESS_THREAD or not PROCESS_THREAD.process: |
99 |
| - return |
100 |
| - LOG.debug("Restarting StepFunctions process ...") |
101 |
| - |
102 |
| - pid = PROCESS_THREAD.process.pid |
103 |
| - PROCESS_THREAD.stop() |
104 |
| - wait_for_port_closed(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=0.5, retries=15) |
105 |
| - try: |
106 |
| - # TODO: currently failing in CI (potentially due to a defunct process) - need to investigate! |
107 |
| - wait_for_process_to_be_killed(pid, sleep=0.3, retries=10) |
108 |
| - except Exception as e: |
109 |
| - LOG.warning("StepFunctions process not properly terminated: %s", e) |
110 |
| - |
111 |
| - |
112 |
| -def check_stepfunctions(expect_shutdown: bool = False, print_error: bool = False) -> None: |
113 |
| - out = None |
114 |
| - try: |
115 |
| - wait_for_port_open(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=2) |
116 |
| - endpoint_url = f"http://127.0.0.1:{config.LOCAL_PORT_STEPFUNCTIONS}" |
117 |
| - out = connect_to(endpoint_url=endpoint_url).stepfunctions.list_state_machines() |
118 |
| - except Exception: |
119 |
| - if print_error: |
120 |
| - LOG.exception("StepFunctions health check failed") |
121 |
| - |
122 |
| - if expect_shutdown: |
123 |
| - assert out is None |
124 |
| - else: |
125 |
| - assert out and isinstance(out.get("stateMachines"), list) |
| 48 | + "DATA_DIR": config.dirs.data, |
| 49 | + "PORT": self._port, |
| 50 | + } |
| 51 | + |
| 52 | + def generate_shell_command(self) -> str: |
| 53 | + cmd = ( |
| 54 | + "java " |
| 55 | + "-javaagent:aspectjweaver-1.9.7.jar " |
| 56 | + "-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml " |
| 57 | + "-Dcom.amazonaws.sdk.disableCertChecking " |
| 58 | + "-Xmx%s " |
| 59 | + "-jar StepFunctionsLocal.jar " |
| 60 | + "--aws-account %s " |
| 61 | + "--region %s" |
| 62 | + ) % ( |
| 63 | + MAX_HEAP_SIZE, |
| 64 | + self.account_id, |
| 65 | + self.region_name, |
| 66 | + ) |
| 67 | + |
| 68 | + if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default": |
| 69 | + lambda_endpoint = ( |
| 70 | + config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url("lambda") |
| 71 | + ) |
| 72 | + cmd += f" --lambda-endpoint {lambda_endpoint}" |
| 73 | + |
| 74 | + # add service endpoint flags |
| 75 | + services = [ |
| 76 | + "athena", |
| 77 | + "batch", |
| 78 | + "dynamodb", |
| 79 | + "ecs", |
| 80 | + "eks", |
| 81 | + "events", |
| 82 | + "glue", |
| 83 | + "sagemaker", |
| 84 | + "sns", |
| 85 | + "sqs", |
| 86 | + "stepfunctions", |
| 87 | + ] |
| 88 | + |
| 89 | + for service in services: |
| 90 | + flag = f"--{service}-endpoint" |
| 91 | + if service == "stepfunctions": |
| 92 | + flag = "--step-functions-endpoint" |
| 93 | + elif service == "events": |
| 94 | + flag = "--eventbridge-endpoint" |
| 95 | + elif service in ["athena", "eks"]: |
| 96 | + flag = f"--step-functions-{service}" |
| 97 | + endpoint = aws_stack.get_local_service_url(service) |
| 98 | + cmd += f" {flag} {endpoint}" |
| 99 | + |
| 100 | + return cmd |
| 101 | + |
| 102 | + def _log_listener(self, line, **kwargs): |
| 103 | + LOG.debug(line.rstrip()) |
| 104 | + |
| 105 | + |
| 106 | +class StepFunctionsServerManager: |
| 107 | + default_startup_timeout = 20 |
| 108 | + |
| 109 | + def __init__(self): |
| 110 | + self._lock = threading.RLock() |
| 111 | + self._servers: dict[tuple[str, str], StepFunctionsServer] = {} |
| 112 | + |
| 113 | + def get_server_for_account_region( |
| 114 | + self, account_id: str, region_name: str |
| 115 | + ) -> StepFunctionsServer: |
| 116 | + locator = (account_id, region_name) |
| 117 | + |
| 118 | + if locator in self._servers: |
| 119 | + return self._servers[locator] |
| 120 | + |
| 121 | + with self._lock: |
| 122 | + if locator in self._servers: |
| 123 | + return self._servers[locator] |
| 124 | + |
| 125 | + LOG.info("Creating StepFunctions server for %s", locator) |
| 126 | + self._servers[locator] = self._create_stepfunctions_server(account_id, region_name) |
| 127 | + |
| 128 | + self._servers[locator].start() |
| 129 | + |
| 130 | + if not self._servers[locator].wait_is_up(timeout=self.default_startup_timeout): |
| 131 | + raise TimeoutError("Gave up waiting for StepFunctions server to start up") |
| 132 | + |
| 133 | + return self._servers[locator] |
| 134 | + |
| 135 | + def shutdown_all(self): |
| 136 | + with self._lock: |
| 137 | + while self._servers: |
| 138 | + locator, server = self._servers.popitem() |
| 139 | + LOG.info("Shutting down StepFunctions for %s", locator) |
| 140 | + server.shutdown() |
| 141 | + |
| 142 | + def _create_stepfunctions_server( |
| 143 | + self, account_id: str, region_name: str |
| 144 | + ) -> StepFunctionsServer: |
| 145 | + port = get_free_tcp_port() |
| 146 | + stepfunctions_local_package.install() |
| 147 | + |
| 148 | + server = StepFunctionsServer( |
| 149 | + port=port, |
| 150 | + account_id=account_id, |
| 151 | + region_name=region_name, |
| 152 | + ) |
| 153 | + return server |
0 commit comments