8000 Merge pull request #26 from modern-python/21-feature-add-cors-instrument · modern-python/lite-bootstrap@85f9e7c · GitHub
[go: up one dir, main page]

Skip to content

Commit 85f9e7c

Browse files
authored
Merge pull request #26 from modern-python/21-feature-add-cors-instrument
add cors instrument
2 parents 5e05482 + 1e80e82 commit 85f9e7c

File tree

5 files changed

+71
-2
lines changed

5 files changed

+71
-2
lines changed

lite_bootstrap/bootstrappers/fastapi_bootstrapper.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from lite_bootstrap import import_checker
55
from lite_bootstrap.bootstrappers.base import BaseBootstrapper
6+
from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument
67
from lite_bootstrap.instruments.healthchecks_instrument import (
78
HealthChecksConfig,
89
HealthChecksInstrument,
@@ -16,6 +17,7 @@
1617

1718
if import_checker.is_fastapi_installed:
1819
import fastapi
20+
from fastapi.middleware.cors import CORSMiddleware
1921

2022
if import_checker.is_opentelemetry_installed:
2123
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
@@ -26,14 +28,31 @@
2628

2729

2830
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
29-
class FastAPIConfig(HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusConfig, SentryConfig):
31+
class FastAPIConfig(CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusConfig, SentryConfig):
3032
application: "fastapi.FastAPI" = dataclasses.field(default_factory=lambda: fastapi.FastAPI())
3133
opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list)
3234
prometheus_instrumentator_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
3335
prometheus_instrument_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
3436
prometheus_expose_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
3537

3638

39+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
40+
class FastApiCorsInstrument(CorsInstrument):
41+
bootstrap_config: FastAPIConfig
42+
43+
def bootstrap(self) -> None:
44+
self.bootstrap_config.application.add_middleware(
45+
CORSMiddleware,
46+
allow_origins=self.bootstrap_config.cors_allowed_origins,
47+
allow_methods=self.bootstrap_config.cors_allowed_methods,
48+
allow_headers=self.bootstrap_config.cors_allowed_headers,
49+
allow_credentials=self.bootstrap_config.cors_allowed_credentials,
50+
allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex,
51+
expose_headers=self.bootstrap_config.cors_exposed_headers,
52+
max_age=self.bootstrap_config.cors_max_age,
53+
)
54+
55+
3756
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
3857
class FastAPIHealthChecksInstrument(HealthChecksInstrument):
3958
bootstrap_config: FastAPIConfig
@@ -114,6 +133,7 @@ class FastAPIBootstrapper(BaseBootstrapper["fastapi.FastAPI"]):
114133
__slots__ = "bootstrap_config", "instruments"
115134

116135
instruments_types: typing.ClassVar = [
136+
FastApiCorsInstrument,
117137
FastAPIOpenTelemetryInstrument,
118138
FastAPISentryInstrument,
119139
FastAPIHealthChecksInstrument,

lite_bootstrap/bootstrappers/litestar_bootstrapper.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from lite_bootstrap import import_checker
55
from lite_bootstrap.bootstrappers.base import BaseBootstrapper
6+
from lite_bootstrap.instruments.cors_instrument import CorsConfig, CorsInstrument
67
from lite_bootstrap.instruments.healthchecks_instrument import (
78
HealthChecksConfig,
89
HealthChecksInstrument,
@@ -22,6 +23,7 @@
2223
if import_checker.is_litestar_installed:
2324
import litestar
2425
from litestar.config.app import AppConfig
26+
from litestar.config.cors import CORSConfig
2527
from litestar.contrib.opentelemetry import OpenTelemetryConfig
2628
from litestar.plugins.prometheus import PrometheusConfig, PrometheusController
2729

@@ -31,13 +33,29 @@
3133

3234
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
3335
class LitestarConfig(
34-
HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig
36+
CorsConfig, HealthChecksConfig, LoggingConfig, OpentelemetryConfig, PrometheusBootstrapperConfig, SentryConfig
3537
):
3638
application_config: "AppConfig" = dataclasses.field(default_factory=lambda: AppConfig())
3739
opentelemetry_excluded_urls: list[str] = dataclasses.field(default_factory=list)
3840
prometheus_additional_params: dict[str, typing.Any] = dataclasses.field(default_factory=dict)
3941

4042

43+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
44+
class LitestarCorsInstrument(CorsInstrument):
45+
bootstrap_config: LitestarConfig
46+
47+
def bootstrap(self) -> None:
48+
self.bootstrap_config.application_config.cors_config = CORSConfig(
49+
allow_origins=self.bootstrap_config.cors_allowed_origins,
50+
allow_methods=self.bootstrap_config.cors_allowed_methods, # type: ignore[arg-type]
51+
allow_headers=self.bootstrap_config.cors_allowed_headers,
52+
allow_credentials=self.bootstrap_config.cors_allowed_credentials,
53+
allow_origin_regex=self.bootstrap_config.cors_allowed_origin_regex,
54+
expose_headers=self.bootstrap_config.cors_exposed_headers,
55+
max_age=self.bootstrap_config.cors_max_age,
56+
)
57+
58+
4159
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
4260
class LitestarHealthChecksInstrument(HealthChecksInstrument):
4361
bootstrap_config: LitestarConfig
@@ -116,6 +134,7 @@ class LitestarBootstrapper(BaseBootstrapper["litestar.Litestar"]):
116134
__slots__ = "bootstrap_config", "instruments"
117135

118136
instruments_types: typing.ClassVar = [
137+
LitestarCorsInstrument,
119138
LitestarOpenTelemetryInstrument,
120139
LitestarSentryInstrument,
121140
LitestarHealthChecksInstrument,
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import dataclasses
2+
3+
from lite_bootstrap.instruments.base import BaseConfig, BaseInstrument
4+
5+
6+
@dataclasses.dataclass(kw_only=True, frozen=True)
7+
class CorsConfig(BaseConfig):
8+
cors_allowed_origins: list[str] = dataclasses.field(default_factory=list)
9+
cors_allowed_methods: list[str] = dataclasses.field(default_factory=list)
10+
cors_allowed_headers: list[str] = dataclasses.field(default_factory=list)
11+
cors_exposed_headers: list[str] = dataclasses.field(default_factory=list)
12+
cors_allowed_credentials: bool = False
13+
cors_allowed_origin_regex: str | None = None
14+
cors_max_age: int = 600
15+
16+
17+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
18+
class CorsInstrument(BaseInstrument):
19+
bootstrap_config: CorsConfig
20+
not_ready_message = "cors_allowed_origins or cors_allowed_origin_regex must be provided"
21+
22+
def is_ready(self) -> bool:
23+
return bool(self.bootstrap_config.cors_allowed_origins) or bool(
24+
self.bootstrap_config.cors_allowed_origin_regex,
25+
)

tests/test_fastapi_bootstrap.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def fastapi_config() -> FastAPIConfig:
1818
service_version="2.0.0",
1919
service_environment="test",
2020
service_debug=False,
21+
cors_allowed_origins=["http://test"],
2122
opentelemetry_endpoint="otl",
2223
opentelemetry_instrumentors=[CustomInstrumentor()],
2324
opentelemetry_span_exporter=ConsoleSpanExporter(),

tests/test_litestar_bootstrap.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def litestar_config() -> LitestarConfig:
1818
service_version="2.0.0",
1919
service_environment="test",
2020
service_debug=False,
21+
cors_allowed_origins=["http://test"],
2122
opentelemetry_endpoint="otl",
2223
opentelemetry_instrumentors=[CustomInstrumentor()],
2324
opentelemetry_span_exporter=ConsoleSpanExporter(),
@@ -35,6 +36,9 @@ def test_litestar_bootstrap(litestar_config: LitestarConfig) -> None:
3536
try:
3637
logger.info("testing logging", key="value")
3738

39+
assert application.cors_config
40+
assert application.cors_config.allow_origins == litestar_config.cors_allowed_origins
41+
3842
with TestClient(app=application) as test_client:
3943
response = test_client.get(litestar_config.health_checks_path)
4044
assert response.status_code == status_codes.HTTP_200_OK

0 commit comments

Comments
 (0)
0