8000 ✨ Update internal `AsyncExitStack` to fix context for dependencies wi… · JeanArhancet/fastapi@c0a5e28 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit c0a5e28

Browse files
tiangoloJeanArhancet
authored andcommitted
✨ Update internal AsyncExitStack to fix context for dependencies with yield (fastapi#4575)
1 parent b176778 commit c0a5e28

File tree

7 files changed

+272
-16
lines changed

7 files changed

+272
-16
lines changed

docs/en/docs/tutorial/dependencies/dependencies-with-yield.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ You saw that you can use dependencies with `yield` and have `try` blocks that ca
9999

100100
It might be tempting to raise an `HTTPException` or similar in the exit code, after the `yield`. But **it won't work**.
101101

102-
The exit code in dependencies with `yield` is executed *after* [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
102+
The exit code in dependencies with `yield` is executed *after* the response is sent, so [Exception Handlers](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank} will have already run. There's nothing catching exceptions thrown by your dependencies in the exit code (after the `yield`).
103103

104104
So, if you raise an `HTTPException` after the `yield`, the default (or any custom) exception handler that catches `HTTPException`s and returns an HTTP 400 response won't be there to catch that exception anymore.
105105

@@ -138,9 +138,11 @@ participant tasks as Background tasks
138138
end
139139
dep ->> operation: Run dependency, e.g. DB session
140140
opt raise
141-
operation -->> handler: Raise HTTPException
141+
operation -->> dep: Raise HTTPException
142+
dep -->> handler: Auto forward exception
142143
handler -->> client: HTTP error response
143144
operation -->> dep: Raise other exception
145+
dep -->> handler: Auto forward exception
144146
end
145147
operation ->> client: Return response to client
146148
Note over client,operation: Response is already sent, can't change it anymore
@@ -162,9 +164,9 @@ participant tasks as Background tasks
162164
After one of those responses is sent, no other response can be sent.
163165

164166
!!! tip
165-
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}. And that exception would be handled by that custom exception handler instead of the dependency exit code.
167+
This diagram shows `HTTPException`, but you could also raise any other exception for which you create a [Custom Exception Handler](../handling-errors.md#install-custom-exception-handlers){.internal-link target=_blank}.
166168

167-
But if you raise an exception that is not handled by the exception handlers, it will be handled by the exit code of the dependency.
169+
If you raise any exception, it will be passed to the dependencies with yield, including `HTTPException`, and then **again** to the exception handlers. If there's no exception handler for that exception, it will then be handled by the default internal `ServerErrorMiddleware`, returning a 500 HTTP status code, to let the client know that there was an error in the server.
168170

169171
## Context Managers
170172

fastapi/applications.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any, Callable, Coroutine, Dict, List, Optional, Sequence, Type, Union
33

44
from fastapi import routing
5-
from fastapi.concurrency import AsyncExitStack
65
from fastapi.datastructures import Default, DefaultPlaceholder
76
from fastapi.encoders import DictIntStrAny, SetIntStr
87
from fastapi.exception_handlers import (
@@ -11,6 +10,7 @@
1110
)
1211
from fastapi.exceptions import RequestValidationError
1312
from fastapi.logger import logger
13+
from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware
1414
from fastapi.openapi.docs import (
1515
get_redoc_html,
1616
get_swagger_ui_html,
@@ -21,8 +21,9 @@
2121
from fastapi 9E88 .types import DecoratedCallable
2222
from starlette.applications import Starlette
2323
from starlette.datastructures import State
24-
from starlette.exceptions import HTTPException
24+
from starlette.exceptions import ExceptionMiddleware, HTTPException
2525
from starlette.middleware import Middleware
26+
from starlette.middleware.errors import ServerErrorMiddleware
2627
from starlette.requests import Request
2728
from starlette.responses import HTMLResponse, JSONResponse, Response
2829
from starlette.routing import BaseRoute
@@ -134,6 +135,55 @@ def __init__(
134135
self.openapi_schema: Optional[Dict[str, Any]] = None
135136
self.setup()
136137

138+
def build_middleware_stack(self) -> ASGIApp:
139+
# Duplicate/override from Starlette to add AsyncExitStackMiddleware
140+
# inside of ExceptionMiddleware, inside of custom user middlewares
141+
debug = self.debug
142+
error_handler = None
143+
exception_handlers = {}
144+
145+
for key, value in self.exception_handlers.items():
146+
if key in (500, Exception):
147+
error_handler = value
148+
else:
149+
exception_handlers[key] = value
150+
151+
middleware = (
152+
[Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)]
153+
+ self.user_middleware
154+
+ [
155+
Middleware(
156+
ExceptionMiddleware, handlers=exception_handlers, debug=debug
157+
),
158+
# Add FastAPI-specific AsyncExitStackMiddleware for dependencies with
159+
# contextvars.
160+
# This needs to happen after user middlewares because those create a
161+
# new contextvars context copy by using a new AnyIO task group.
162+
# The initial part of dependencies with yield is executed in the
163+
# FastAPI code, inside all the middlewares, but the teardown part
164+
# (after yield) is executed in the AsyncExitStack in this middleware,
165+
# if the AsyncExitStack lived outside of the custom middlewares and
166+
# contextvars were set in a dependency with yield in that internal
167+
# contextvars context, the values would not be available in the
168+
# outside context of the AsyncExitStack.
169+
# By putting the middleware and the AsyncExitStack here, inside all
170+
# user middlewares, the code before and after yield in dependencies
171+
# with yield is executed in the same contextvars context, so all values
172+
# set in contextvars before yield is still available after yield as
173+
# would be expected.
174+
# Additionally, by having this AsyncExitStack here, after the
175+
# ExceptionMiddleware, now dependencies can catch handled exceptions,
176+
# e.g. HTTPException, to customize the teardown code (e.g. DB session
177+
# rollback).
178+
Middleware(AsyncExitStackMiddleware),
179+
]
180+
)
181+
182+
app = self.router
183+
for cls, options in reversed(middleware):
184+
app = cls(app=app, **options)
185+
return app
186+
137187
def openapi(self) -> Dict[str, Any]:
138188
if not self.openapi_schema:
139189
self.openapi_schema = get_openapi(
@@ -206,12 +256,7 @@ async def redoc_html(req: Request) -> HTMLResponse:
206256
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
207257
if self.root_path:
208258
scope["root_path"] = self.root_path
209-
if AsyncExitStack:
210-
async with AsyncExitStack() as stack:
211-
scope["fastapi_astack"] = stack
212-
await super().__call__(scope, receive, send)
213-
else:
214-
await super().__call__(scope, receive, send) # pragma: no cover
259+
await super().__call__(scope, receive, send)
215260

216261
def add_api_route(
217262
self,
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from typing import Optional
2+
3+
from fastapi.concurrency import AsyncExitStack
4+
from starlette.types import ASGIApp, Receive, Scope, Send
5+
6+
7+
class AsyncExitStackMiddleware:
< 9E88 /td>
8+
def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None:
9+
self.app = app
10+
self.context_name = context_name
11+
12+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
13+
if AsyncExitStack:
14+
dependency_exception: Optional[Exception] = None
15+
async with AsyncExitStack() as stack:
16+
scope[self.context_name] = stack
17+
try:
18+
await self.app(scope, receive, send)
19+
except Exception as e:
20+
dependency_exception = e
21+
raise e
22+
if dependency_exception:
23+
# This exception was possibly handled by the dependency but it should
24+
# still bubble up so that the ServerErrorMiddleware can return a 500
25+
# or the ExceptionMiddleware can catch and handle any other exceptions
26+
raise dependency_exception
27+
else:
28+
await self.app(scope, receive, send) # pragma: no cover

tests/test_dependency_contextmanager.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,16 @@ def test_sync_raise_other():
235235
assert "/sync_raise" not in errors
236236

237237

238-
def test_async_raise():
238+
def test_async_raise_raises():
239+
with pytest.raises(AsyncDependencyError):
240+
client.get("/async_raise")
241+
assert state["/async_raise"] == "asyncgen raise finalized"
242+
assert "/async_raise" in errors
243+
errors.clear()
244+
245+
246+
def test_async_raise_server_error():
247+
client = TestClient(app, raise_server_exceptions=False)
239248
response = client.get("/async_raise")
240249
assert response.status_code == 500, response.text
241250
assert state["/async_raise"] == "asyncgen raise finalized"
@@ -270,7 +279,16 @@ def test_background_tasks():
270279
assert state["bg"] == "bg set - b: started b - a: started a"
271280

272281

273-
def test_sync_raise():
282+
def test_sync_raise_raises():
283+
with pytest.raises(SyncDependencyError):
284+
client.get("/sync_raise")
285+
assert state["/sync_raise"] == "generator raise finalized"
286+
assert "/sync_raise" in errors
287+
errors.clear()
288+
289+
290+
def test_sync_raise_server_error():
291+
client = TestClient(app, raise_server_exceptions=False)
274292
response = client.get("/sync_raise")
275293
assert response.status_code == 500, response.text
276294
assert state["/sync_raise"] == "generator raise finalized"
@@ -306,15 +324,33 @@ def test_sync_sync_raise_other():
306324
assert "/sync_raise" not in errors
307325

308326

309-
def test_sync_async_raise():
327+
def test_sync_async_raise_raises():
328+
with pytest.raises(AsyncDependencyError):
329+
client.get("/sync_async_raise")
330+
assert state["/async_raise"] == "asyncgen raise finalized"
331+
assert "/async_raise" in errors
332+
errors.clear()
333+
334+
335+
def test_sync_async_raise_server_error():
336+
client = TestClient(app, raise_server_exceptions=False)
310337
response = client.get("/sync_async_raise")
311338
assert response.status_code == 500, response.text
312339
assert state["/async_raise"] == "asyncgen raise finalized"
313340
assert "/async_raise" in errors
314341
errors.clear()
315342

316343

317-
def test_sync_sync_raise():
344+
def test_sync_sync_raise_raises():
345+
with pytest.raises(SyncDependencyError):
346+
client.get("/sync_sync_raise")
347+
assert state["/sync_raise"] == "generator raise finalized"
348+
assert "/sync_raise" in errors
349+
errors.clear()
350+
351+
352+
def test_sync_sync_raise_server_error():
353+
client = TestClient(app, raise_server_exceptions=False)
318354
response = client.get("/sync_sync_raise")
319355
assert response.status_code == 500, response.text
320356
assert state["/sync_raise"] == "generator raise finalized"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from contextvars import ContextVar
2+
from typing import Any, Awaitable, Callable, Dict, Optional
3+
4+
from fastapi import Depends, FastAPI, Request, Response
5+
from fastapi.testclient import TestClient
6+
7+
legacy_request_state_context_var: ContextVar[Optional[Dict[str, Any]]] = ContextVar(
8+
"legacy_request_state_context_var", default=None
9+
)
10+
11+
app = FastAPI()
12+
13+
14+
async def set_up_request_state_dependency():
15+
request_state = {"user": "deadpond"}
16+
contextvar_token = legacy_request_state_context_var.set(request_state)
17+
yield request_state
18+
legacy_request_state_context_var.reset(contextvar_token)
19+
20+
21+
@app.middleware("http")
22+
async def custom_middleware(
23+
request: Request, call_next: Callable[[Request], Awaitable[Response]]
24+
):
25+
response = await call_next(request)
26+
response.headers["custom"] = "foo"
27+
return response
28+
29+
30+
@app.get("/user", dependencies=[Depends(set_up_request_state_dependency)])
31+
def get_user():
32+
request_state = legacy_request_state_context_var.get()
33+
assert request_state
34+
return request_state["user"]
35+
36+
37+
client = TestClient(app)
38+
39+
40+
def test_dependency_contextvars():
41+
"""
42+
Check that custom middlewares don't affect the contextvar context for dependencies.
43+
44+
The code before yield and the code after yield should be run in the same contextvar
45+
context, so that request_state_context_var.reset(contextvar_token).
46+
47+
If they are run in a different context, that raises an error.
48+
"""
49+
response = client.get("/user")
50+
assert response.json() == "deadpond"
51+
assert response.headers["custom"] == "foo"
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
from fastapi import Body, Depends, FastAPI, HTTPException
3+
from fastapi.testclient import TestClient
4+
5+
initial_fake_database = {"rick": "Rick Sanchez"}
6+
7+
fake_database = initial_fake_database.copy()
8+
9+
initial_state = {"except": False, "finally": False}
10+
11+
state = initial_state.copy()
12+
13+
app = FastAPI()
14+
15+
16+
async def get_database():
17+
temp_database = fake_database.copy()
18+
try:
19+
yield temp_database
20+
fake_database.update(temp_database)
21+
except HTTPException:
22+
state["except"] = True
23+
finally:
24+
state["finally"] = True
25+
26+
27+
@app.put("/invalid-user/{user_id}")
28+
def put_invalid_user(
29+
user_id: str, name: str = Body(...), db: dict = Depends(get_database)
30+
):
31+
db[user_id] = name
32+
raise HTTPException(status_code=400, detail="Invalid user")
33+
34+
35+
@app.put("/user/{user_id}")
36+
def put_user(user_id: str, name: str = Body(...), db: dict = Depends(get_database)):
37+
db[user_id] = name
38+
return {"message": "OK"}
39+
40+
41+
@pytest.fixture(autouse=True)
42+
def reset_state_and_db():
43+
global fake_database
44+
global state
45+
fake_database = initial_fake_database.copy()
46+
state = initial_state.copy()
47+
48+
49+
client = TestClient(app)
50+
51+
52+
def test_dependency_gets_exception():
53+
assert state["except"] is False
54+
assert state["finally"] is False
55+
response = client.put("/invalid-user/rick", json="Morty")
56+
assert response.status_code == 400, response.text
57+
assert response.json() == {"detail": "Invalid user"}
58+
assert state["except"] is True
59+
assert state["finally"] is True
60+
assert fake_database["rick"] == "Rick Sanchez"
61+
62+
63+
def test_dependency_no_exception():
64+
assert state["except"] is False
65+
assert state["finally"] is False
66+
response = client.put("/user/rick", json="Morty")
67+
assert response.status_code == 200, response.text
68+
assert response.json() == {"message": "OK"}
69+
assert state["except"] is False
70+
assert state["finally"] is True
71+
assert fake_database["rick"] == "Morty"

0 commit comments

Comments
 (0)
0