8000 Implement account ID namespacing for legacy stepfunctions provider · localstack/localstack@14eb9a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 14eb9a2

Browse files
Implement account ID namespacing for legacy stepfunctions provider
1 parent 2682b40 commit 14eb9a2

File tree

5 files changed

+200
-151
lines changed

5 files changed

+200
-151
lines changed

localstack/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -941,8 +941,6 @@ def legacy_fallback(envar_name: str, default: T) -> T:
941941
# DEV: sbx_user1051 (default when not provided) Alternative system user or empty string to skip dropping privileges.
942942
LAMBDA_INIT_USER = os.environ.get("LAMBDA_INIT_USER")
943943

944-
# Adding Stepfunctions default port
945-
LOCAL_PORT_STEPFUNCTIONS = int(os.environ.get("LOCAL_PORT_STEPFUNCTIONS") or 8083)
946944
# Stepfunctions lambda endpoint override
947945
STEPFUNCTIONS_LAMBDA_ENDPOINT = os.environ.get("STEPFUNCTIONS_LAMBDA_ENDPOINT", "").strip()
948946

localstack/services/stepfunctions/provider.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import threading
44

55
from localstack import config
6+
from localstack.aws.accounts import get_aws_account_id
67
from localstack.aws.api import RequestContext, handler
78
from localstack.aws.api.stepfunctions import (
89
CreateStateMachineInput,
@@ -16,12 +17,9 @@
1617
from localstack.aws.forwarder import get_request_forwarder_http
1718
from localstack.constants import LOCALHOST
1819
from localstack.services.plugins import ServiceLifecycleHook
19-
from localstack.services.stepfunctions.stepfunctions_starter import (
20-
start_stepfunctions,
21-
stop_stepfunctions,
22-
wait_for_stepfunctions,
23-
)
20+
from localstack.services.stepfunctions.stepfunctions_starter import StepFunctionsServerManager
2421
from localstack.state import AssetDirectory, StateVisitor
22+
from localstack.utils.aws import aws_stack
2523

2624
# lock to avoid concurrency issues when creating state machines in parallel (required for StepFunctions-Local)
2725
CREATION_LOCK = threading.RLock()
@@ -30,33 +28,29 @@
3028

3129

3230
class StepFunctionsProvider(StepfunctionsApi, ServiceLifecycleHook):
31+
server_manager = StepFunctionsServerManager()
32+
3333
def __init__(self):
3434
self.forward_request = get_request_forwarder_http(self.get_forward_url)
3535

3636
def get_forward_url(self) -> str:
3737
"""Return the URL of the backend StepFunctions server to forward requests to"""
38-
return f"http://{LOCALHOST}:{config.LOCAL_PORT_STEPFUNCTIONS}"
38+
account_id = get_aws_account_id()
39+
region_name = aws_stack.get_region()
40+
server = self.server_manager.get_server_for_account_region(account_id, region_name)
41+
return f"http://{LOCALHOST}:{server.port}"
3942

4043
def accept_state_visitor(self, visitor: StateVisitor):
4144
visitor.visit(AssetDirectory(os.path.join(config.dirs.data, self.service)))
4245

43-
def on_before_start(self):
44-
start_stepfunctions()
45-
wait_for_stepfunctions()
46-
47-
def on_before_state_reset(self):
48-
stop_stepfunctions()
49-
5046
def on_before_state_load(self):
51-
stop_stepfunctions()
47+
self.server_manager.shutdown_all()
5248

53-
def on_after_state_reset(self):
54-
start_stepfunctions()
55-
wait_for_stepfunctions()
49+
def on_before_state_reset(self):
50+
self.server_manager.shutdown_all()
5651

57-
def on_after_state_load(self):
58-
start_stepfunctions()
59-
wait_for_stepfunctions()
52+
def on_before_stop(self):
53+
self.server_manager.shutdown_all()
6054

6155
def create_state_machine(
6256
self, context: RequestContext, request: CreateStateMachineInput
Lines changed: 139 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1,125 +1,153 @@
11
import logging
2-
import subprocess
2+
import threading
3+
from typing import Any, Dict
34

45
from localstack import config
5-
from localstack.aws.accounts import get_aws_account_id
6-
from localstack.aws.connect import connect_to
7-
from localstack.services.infra import d F438 o_run, log_startup_message
86
from localstack.services.stepfunctions.packages import stepfunctions_local_package
97
from localstack.utils.aws import aws_stack
10-
from localstack.utils.common import wait_for_port_open
11-
from localstack.utils.net import wait_for_port_closed
12-
from localstack.utils.run import ShellCommandThread, wait_for_process_to_be_killed
13-
from localstack.utils.sync import retry
8+
from localstack.utils.net import get_free_tcp_port
9+
from localstack.utils.run import ShellCommandThread
10+
from localstack.utils.serving import Server
11+
from localstack.utils.threads import TMP_THREADS, FuncThread
1412

1513
LOG = logging.getLogger(__name__)
1614

1715
# max heap size allocated for the Java process
1816
MAX_HEAP_SIZE = "256m"
1917

20-
# todo: will be replaced with plugin mechanism
21-
PROCESS_THREAD: ShellCommandThread | subprocess.Popen | None = None
22-
23-
24-
# TODO: pass env more explicitly
25-
def get_command(backend_port):
26-
install_dir_stepfunctions = stepfunctions_local_package.get_installed_dir()
27-
cmd = (
28-
"cd %s; PORT=%s java "
29-
"-javaagent:aspectjweaver-1.9.7.jar "
30-
"-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml "
31-
"-Dcom.amazonaws.sdk.disableCertChecking -Xmx%s "
32-
"-jar StepFunctionsLocal.jar --aws-account %s"
33-
) % (
34-
install_dir_stepfunctions,
35-
backend_port,
36-
MAX_HEAP_SIZE,
37-
get_aws_account_id(),
38-
)
39-
if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default":
40-
lambda_endpoint = config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url(
41-
"lambda"
18+
19+
class StepFunctionsServer(Server):
20+
def __init__(
21+
self, port: int, account_id: str, region_name: str, host: str = "localhost"
22+
) -> None:
23+
self.account_id = account_id
24+
self.region_name = region_name
25+
super().__init__(port, host)
26+
27+
def do_start_thread(self) -> FuncThread:
28+
cmd = self.generate_shell_command()
29+
env_vars = self.generate_env_vars()
30+
cwd = stepfunctions_local_package.get_installed_dir()
31+
LOG.debug("Starting StepFunctions process %s with env vars %s", cmd, env_vars)
32+
t = ShellCommandThread(
33+
cmd,
34+
strip_color=True,
35+
env_vars=env_vars,
36+
log_listener=self._log_listener,
37+
name="stepfunctions",
38+
cwd=cwd,
4239
)
43-
cmd += f" --lambda-endpoint {lambda_endpoint}"
44-
# add service endpoint flags
45-
services = [
46-
"athena",
47-
"batch",
48-
"dynamodb",
49-
"ecs",
50-
"eks",
51-
"events",
52-
"glue",
53-
"sagemaker",
54-
"sns",
55-
"sqs",
56-
"stepfunctions",
57-
]
58-
for service in services:
59-
flag = f"--{service}-endpoint"
60-
if service == "stepfunctions":
61-
flag = "--step-functions-endpoint"
62-
elif service == "events":
63-
flag = "--eventbridge-endpoint"
64-
elif service in ["athena", "eks"]:
65-
flag = f"--step-functions-{service}"
66-
endpoint = aws_stack.get_local_service_url(service)
67-
cmd += f" {flag} {endpoint}"
68-
69-
return cmd
70-
71-
72-
def start_stepfunctions(asynchronous: bool = True, persistence_path: str | None = None):
73-
# TODO: introduce Server abstraction for StepFunctions process
74-
global PROCESS_THREAD
75-
backend_port = config.LOCAL_PORT_STEPFUNCTIONS
76-
stepfunctions_local_package.install()
77-
cmd = get_command(backend_port)
78-
log_startup_message("StepFunctions")
79-
# TODO: change ports in stepfunctions.jar, then update here
80-
PROCESS_THREAD = do_run(
81-
cmd,
82-
asynchronous,
83-
strip_color=True,
84-
env_vars={
40+
TMP_THREADS.append(t)
41+
t.start()
42+
return t
43+
44+
def generate_env_vars(self) -> Dict[str, Any]:
45+
return {
8546
"EDGE_PORT": config.EDGE_PORT_HTTP or config.EDGE_PORT,
8647
"EDGE_PORT_HTTP": config.EDGE_PORT_HTTP or config.EDGE_PORT,
87-
"DATA_DIR": persistence_path or config.dirs.data,
88-
},
89-
)
90-
return PROCESS_THREAD
91-
92-
93-
def wait_for_stepfunctions():
94-
retry(check_stepfunctions, sleep=0.5, retries=15)
95-
96-
97-
def stop_stepfunctions():
98-
if not PROCESS_THREAD or not PROCESS_THREAD.process:
99-
return
100-
LOG.debug("Restarting StepFunctions process ...")
101-
102-
pid = PROCESS_THREAD.process.pid
103-
PROCESS_THREAD.stop()
104-
wait_for_port_closed(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=0.5, retries=15)
105-
try:
106-
# TODO: currently failing in CI (potentially due to a defunct process) - need to investigate!
107-
wait_for_process_to_be_killed(pid, sleep=0.3, retries=10)
108-
except Exception as e:
109-
LOG.warning("StepFunctions process not properly terminated: %s", e)
110-
111-
112-
def check_stepfunctions(expect_shutdown: bool = False, print_error: bool = False) -> None:
113-
out = None
114-
try:
115-
wait_for_port_open(config.LOCAL_PORT_STEPFUNCTIONS, sleep_time=2)
116-
endpoint_url = f"http://127.0.0.1:{config.LOCAL_PORT_STEPFUNCTIONS}"
117-
out = connect_to(endpoint_url=endpoint_url).stepfunctions.list_state_machines()
118-
except Exception:
119-
if print_error:
120-
LOG.exception("StepFunctions health check failed")
121-
122-
if expect_shutdown:
123-
assert out is None
124-
else:
125-
assert out and isinstance(out.get("stateMachines"), list)
48+
"DATA_DIR": config.dirs.data,
49+
"PORT": self._port,
50+
}
51+
52+
def generate_shell_command(self) -> str:
53+
cmd = (
54+
"java "
55+
"-javaagent:aspectjweaver-1.9.7.jar "
56+
"-Dorg.aspectj.weaver.loadtime.configuration=META-INF/aop.xml "
57+
"-Dcom.amazonaws.sdk.disableCertChecking "
58+
"-Xmx%s "
59+
"-jar StepFunctionsLocal.jar "
60+
"--aws-account %s "
61+
"--region %s"
62+
) % (
63+
MAX_HEAP_SIZE,
64+
self.account_id,
65+
self.region_name,
66+
)
67+
68+
if config.STEPFUNCTIONS_LAMBDA_ENDPOINT.lower() != "default":
69+
lambda_endpoint = (
70+
config.STEPFUNCTIONS_LAMBDA_ENDPOINT or aws_stack.get_local_service_url("lambda")
71+
)
72+
cmd += f" --lambda-endpoint {lambda_endpoint}"
73+
74+
# add service endpoint flags
75+
services = [
76+
"athena",
77+
"batch",
78+
"dynamodb",
79+
"ecs",
80+
"eks",
81+
"events",
82+
"glue",
83+
"sagemaker",
84+
"sns",
85+
"sqs",
86+
"stepfunctions",
87+
]
88+
89+
for service in services:
90+
flag = f"--{service}-endpoint"
91+
if service == "stepfunctions":
92+
flag = "--step-functions-endpoint"
93+
elif service == "events":
94+
flag = "--eventbridge-endpoint"
95+
elif service in ["athena", "eks"]:
96+
flag = f"--step-functions-{service}"
97+
endpoint = aws_stack.get_local_service_url(service)
98+
cmd += f" {flag} {endpoint}"
99+
100+
return cmd
101+
102+
def _log_listener(self, line, **kwargs):
103+
LOG.debug(line.rstrip())
104+
105+
106+
class StepFunctionsServerManager:
107+
default_startup_timeout = 20
108+
109+
def __init__(self):
110+
self._lock = threading.RLock()
111+
self._servers: dict[tuple[str, str], StepFunctionsServer] = {}
112+
113+
def get_server_for_account_region(
114+
self, account_id: str, region_name: str
115+
) -> StepFunctionsServer:
116+
locator = (account_id, region_name)
117+
118+
if locator in self._servers:
119+
return self._servers[locator]
120+
121+
with self._lock:
122+
if locator in self._servers:
123+
return self._servers[locator]
124+
125+
LOG.info("Creating StepFunctions server for %s", locator)
126+
self._servers[locator] = self._create_stepfunctions_server(account_id, region_name)
127+
128+
self._servers[locator].start()
129+
130+
if not self._servers[locator].wait_is_up(timeout=self.default_startup_timeout):
131+
raise TimeoutError("Gave up waiting for StepFunctions server to start up")
132+
133+
return self._servers[locator]
134+
135+
def shutdown_all(self):
136+
with self._lock:
137+
while self._servers:
138+
locator, server = self._servers.popitem()
139+
LOG.info("Shutting down StepFunctions for %s", locator)
140+
server.shutdown()
141+
142+
def _create_stepfunctions_server(
143+
self, account_id: str, region_name: str
144+
) -> StepFunctionsServer:
145+
port = get_free_tcp_port()
146+
stepfunctions_local_package.install()
147+
148+
server = StepFunctionsServer(
149+
port=port,
150+
account_id=account_id,
151+
region_name=region_name,
152+
)
153+
return server

0 commit comments

Comments
 (0)
0