8000 Use `MiddlewareType` as a type hint for middlewares by Viicos · Pull Request #1987 · encode/starlette · GitHub
[go: up one dir, main page]

Skip to content

Use MiddlewareType as a type hint for middlewares #1987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Router
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import ASGIApp, Lifespan, MiddlewareType, Receive, Scope, Send

AppType = typing.TypeVar("AppType", bound="Starlette")

Expand Down Expand Up @@ -134,7 +134,9 @@ def host(
) -> None: # pragma: no cover
self.router.host(host, app=app, name=name)

def add_middleware(self, middleware_class: type, **options: typing.Any) -> None:
def add_middleware(
self, middleware_class: MiddlewareType, **options: typing.Any
) -> None:
if self.middleware_stack is not None: # pragma: no cover
raise RuntimeError("Cannot add middleware after an application has started")
self.user_middleware.insert(0, Middleware(middleware_class, **options))
Expand Down
4 changes: 3 additions & 1 deletion starlette/middleware/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing

from starlette.types import MiddlewareType


class Middleware:
def __init__(self, cls: type, **options: typing.Any) -> None:
def __init__(self, cls: MiddlewareType, **options: typing.Any) -> None:
self.cls = cls
self.options = options

Expand Down
18 changes: 18 additions & 0 deletions starlette/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import sys
import typing

if sys.version_info < (3, 8): # pragma: no cover
from typing_extensions import Protocol
else: # pragma: no cover
from typing import Protocol

AppType = typing.TypeVar("AppType")

Scope = typing.MutableMapping[str, typing.Any]
Expand All @@ -15,3 +21,15 @@
[AppType], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
]
Lifespan = typing.Union[StatelessLifespan[AppType], StatefulLifespan[AppType]]


# This callable protocol can both be used to represent a function returning
# an ASGIApp, or a class with an __init__ method matching this __call__ signature
# and a __call__ method matching the ASGIApp signature.
class MiddlewareType(Protocol):
__name__: str

def __call__(
self, *args: typing.Any, **kwargs: typing.Any
) -> ASGIApp: # pragma: no cover
...
4 changes: 2 additions & 2 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, MiddlewareType, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -193,7 +193,7 @@ async def dispatch(self, request, call_next):
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
def test_contextvars(test_client_factory, middleware_cls: MiddlewareType):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
Expand Down
4 changes: 3 additions & 1 deletion tests/middleware/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from starlette.middleware import Middleware
from starlette.types import Receive, Scope, Send


class CustomMiddleware:
pass
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return None # pragma: no cover


def test_middleware_repr():
Expand Down
0