8000 Add `*args` to `Middleware` and improve its type hints (#2381) · encode/starlette@866a15f · GitHub
[go: up one dir, main page]

Skip to content

Commit 866a15f

Browse files
pawelrubinPaweł Rubin
and
Paweł Rubin
authored
Add *args to Middleware and improve its type hints (#2381)
Co-authored-by: Paweł Rubin <pawel.rubin@ocado.com>
1 parent 23c81da commit 866a15f

File tree

7 files changed

+77
-36
lines changed

7 files changed

+77
-36
lines changed

starlette/applications.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import typing
44
import warnings
55

6+
from typing_extensions import ParamSpec
7+
68
from starlette.datastructures import State, URLPath
7-
from starlette.middleware import Middleware
9+
from starlette.middleware import Middleware, _MiddlewareClass
810
from starlette.middleware.base import BaseHTTPMiddleware
911
from starlette.middleware.errors import ServerErrorMiddleware
1012
from starlette.middleware.exceptions import ExceptionMiddleware
@@ -15,6 +17,7 @@
1517
from starlette.websockets import WebSocket
1618

1719
AppType = typing.TypeVar("AppType", bound="Starlette")
20+
P = ParamSpec("P")
1821

1922

2023
class Starlette:
@@ -98,8 +101,8 @@ def build_middleware_stack(self) -> ASGIApp:
98101
)
99102

100103
app = self.router
101-
for cls, options in reversed(middleware):
102-
app = cls(app=app, **options)
104+
for cls, args, kwargs in reversed(middleware):
105+
app = cls(app=app, *args, **kwargs)
103106
return app
104107

105108
@property
@@ -124,10 +127,15 @@ def mount(self, path: str, app: ASGIApp, name: str | None = None) -> None:
124127
def host(self, host: str, app: ASGIApp, name: str | None = None) -> None:
125128
self.router.host(host, app=app, name=name) # pragma: no cover
126129

127-
def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
130+
def add_middleware(
131+
self,
132+
middleware_class: typing.Type[_MiddlewareClass[P]],
133+
*args: P.args,
134+
**kwargs: P.kwargs,
135+
) -> None:
128136
if self.middleware_stack is not None: # pragma: no cover
129137
raise RuntimeError("Cannot add middleware after an application has started")
130-
self.user_middleware.insert(0, Middleware(middleware_class, **options))
138+
self.user_middleware.insert(0, Middleware(middleware_class, *args, **kwargs))
131139

132140
def add_exception_handler(
133141
self,

starlette/middleware/__init__.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,38 @@
1-
import typing
1+
from typing import Any, Iterator, Protocol, Type
2+
3+
from typing_extensions import ParamSpec
4+
5+
from starlette.types import ASGIApp, Receive, Scope, Send
6+
7+
P = ParamSpec("P")
8+
9+
10+
class _MiddlewareClass(Protocol[P]):
11+
def __init__(self, app: ASGIApp, *args: P.args, **kwargs: P.kwargs) -> None:
12+
... # pragma: no cover
13+
14+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
15+
... # pragma: no cover
216

317

418
class Middleware:
5-
def __init__(self, cls: type, **options: typing.Any) -> None:
19+
def __init__(
20+
self,
21+
cls: Type[_MiddlewareClass[P]],
22+
*args: P.args,
23+
**kwargs: P.kwargs,
24+
) -> None:
625
self.cls = cls
7-
self.options = options
26+
self.args = args
27+
self.kwargs = kwargs
828

9-
def __iter__(self) -> typing.Iterator[typing.Any]:
10-
as_tuple = (self.cls, self.options)
29+
def __iter__(self) -> Iterator[Any]:
30+
as_tuple = (self.cls, self.args, self.kwargs)
1131
return iter(as_tuple)
1232

1333
def __repr__(self) -> str:
1434
class_name = self.__class__.__name__
15-
option_strings = [f"{key}={value!r}" for key, value in self.options.items()]
16-
args_repr = ", ".join([self.cls.__name__] + option_strings)
35+
args_strings = [f"{value!r}" for value in self.args]
36+
option_strings = [f"{key}={value!r}" for key, value in self.kwargs.items()]
37+
args_repr = ", ".join([self.cls.__name__] + args_strings + option_strings)
1738
return f"{class_name}({args_repr})"

starlette/routing.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def __init__(
238238
self.app = endpoint
239239

240240
if middleware is not None:
241-
for cls, options in reversed(middleware):
242-
self.app = cls(app=self.app, **options)
241+
for cls, args, kwargs in reversed(middleware):
242+
self.app = cls(app=self.app, *args, **kwargs)
243243

244244
if methods is None:
245245
self.methods = None
@@ -335,8 +335,8 @@ def __init__(
335335
self.app = endpoint
336336

337337
if middleware is not None:
338-
for cls, options in reversed(middleware):
339-
self.app = cls(app=self.app, **options)
338+
for cls, args, kwargs in reversed(middleware):
339+
self.app = cls(app=self.app, *args, **kwargs)
340340

341341
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
342342

@@ -404,8 +404,8 @@ def __init__(
404404
self._base_app = Router(routes=routes)
405405
self.app = self._base_app
406406
if middleware is not None:
407-
for cls, options in reversed(middleware):
408-
self.app = cls(app=self.app, **options)
407+
for cls, args, kwargs in reversed(middleware):
408+
self.app = cls(app=self.app, *args, **kwargs)
409409
self.name = name
410410
self.path_regex, self.path_format, self.param_convertors = compile_path(
411411
self.path + "/{path:path}"
@@ -672,8 +672,8 @@ def __init__(
672672

673673
self.middleware_stack = self.app
674674
if middleware:
675-
for cls, options in reversed(middleware):
676-
self.middleware_stack = cls(self.middleware_stack, **options)
675+
for cls, args, kwargs in reversed(middleware):
676+
self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)
677677

678678
async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
679679
if scope["type"] == "websocket":

tests/middleware/test_base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import contextvars
22
from contextlib import AsyncExitStack
3-
from typing import AsyncGenerator, Awaitable, Callable, List, Union
3+
from typing import Any, AsyncGenerator, Awaitable, Callable, List, Type, Union
44

55
import anyio
66
import pytest
77

88
from starlette.applications import Starlette
99
from starlette.background import BackgroundTask
10-
from starlette.middleware import Middleware
10+
from starlette.middleware import Middleware, _MiddlewareClass
1111
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
1212
from starlette.requests import Request
1313
from starlette.responses import PlainTextResponse, Response, StreamingResponse
@@ -196,7 +196,7 @@ async def dispatch(self, request, call_next):
196196
),
197197
],
198198
)
199-
def test_contextvars(test_client_factory, middleware_cls: type):
199+
def test_contextvars(test_client_factory, middleware_cls: Type[_MiddlewareClass[Any]]):
200200
# this has to be an async endpoint because Starlette calls run_in_threadpool
201201
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
202202
# contextvars (it propagates them forwards but not backwards)

tests/middleware/test_middleware.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
from starlette.middleware import Middleware
2+
from starlette.types import ASGIApp, Receive, Scope, Send
23

34

4-
class CustomMiddleware:
5-
pass
5+
class CustomMiddleware: # pragma: no cover
6+
def __init__(self, app: ASGIApp, foo: str, *, bar: int) -> None:
7+
self.app = app
8+
self.foo = foo
9< F438 /code>+
self.bar = bar
610

11+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
12+
await self.app(scope, receive, send)
713

8-
def test_middleware_repr():
9-
middleware = Middleware(CustomMiddleware)
10-
assert repr(middleware) == "Middleware(CustomMiddleware)"
14+
15+
def test_middleware_repr() -> None:
16+
middleware = Middleware(CustomMiddleware, "foo", bar=123)
17+
assert repr(middleware) == "Middleware(CustomMiddleware, 'foo', bar=123)"
18+
19+
20+
def test_middleware_iter() -> None:
21+
cls, args, kwargs = Middleware(CustomMiddleware, "foo", bar=123)
22+
assert (cls, args, kwargs) == (CustomMiddleware, ("foo",), {"bar": 123})

tests/test_applications.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from contextlib import asynccontextmanager
3-
from typing import Any, AsyncIterator, Callable
3+
from typing import AsyncIterator, Callable
44

55
import anyio
66
import httpx
@@ -15,7 +15,7 @@
1515
from starlette.responses import JSONResponse, PlainTextResponse
1616
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
1717
from starlette.staticfiles import StaticFiles
18-
from starlette.types import ASGIApp
18+
from starlette.types import ASGIApp, Receive, Scope, Send
1919
from starlette.websockets import WebSocket
2020

2121

@@ -499,8 +499,8 @@ class NoOpMiddleware:
499499
def __init__(self, app: ASGIApp):
500500
self.app = app
501501

502-
async def __call__(self, *args: Any):
503-
await self.app(*args)
502+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
503+
await self.app(scope, receive, send)
504504

505505
class SimpleInitializableMiddleware:
506506
counter = 0
@@ -509,8 +509,8 @@ def __init__(self, app: ASGIApp):
509509
self.app = app
510510
SimpleInitializableMiddleware.counter += 1
511511

512-
async def __call__(self, *args: Any):
513-
await self.app(*args)
512+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
513+
await self.app(scope, receive, send)
514514

515515
def get_app() -> ASGIApp:
516516
app = Starlette()

tests/test_authentication.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from starlette.endpoints import HTTPEndpoint
1616
from starlette.middleware import Middleware
1717
from starlette.middleware.authentication import AuthenticationMiddleware
18-
from starlette.requests import Request
18+
from starlette.requests import HTTPConnection
1919
from starlette.responses import JSONResponse
2020
from starlette.routing import Route, WebSocketRoute
2121
from starlette.websockets import WebSocketDisconnect
@@ -327,7 +327,7 @@ def test_authentication_redirect(test_client_factory):
327327
assert response.json() == {"authenticated": True, "user": "tomchristie"}
328328

329329

330-
def on_auth_error(request: Request, exc: Exception):
330+
def on_auth_error(request: HTTPConnection, exc: AuthenticationError):
331331
return JSONResponse({"error": str(exc)}, status_code=401)
332332

333333

0 commit comments

Comments
 (0)
0