10000 Create a WebSocketAPIRoute class to support #166 by jekirl · Pull Request #178 · fastapi/fastapi · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
✨ Refactor/update implementation of WebSockets with dependencies, etc
  • Loading branch information
tiangolo committed May 24, 2019
commit a91824ceb5786301f3bfeecc7d565256483177f0
12 changes: 12 additions & 0 deletions fastapi/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,18 @@ def decorator(func: Callable) -> Callable:

return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable, name: str = None
) -> None:
self.router.add_api_websocket_route(path, endpoint, name=name)

def websocket(self, path: str, name: str = None) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_websocket_route(path, func, name=name)
return func

return decorator

def include_router(
self,
router: routing.APIRouter,
Expand Down
2 changes: 2 additions & 0 deletions fastapi/dependencies/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
name: str = None,
call: Callable = None,
request_param_name: str = None,
websocket_param_name: str = None,
background_tasks_param_name: str = None,
security_scopes_param_name: str = None,
security_scopes: List[str] = None,
Expand All @@ -38,6 +39,7 @@ def __init__(
self.dependencies = dependencies or []
self.security_requirements = security_schemes or []
self.request_param_name = request_param_name
self.websocket_param_name = websocket_param_name
self.background_tasks_param_name = background_tasks_param_name
self.security_scopes = security_scopes
self.security_scopes_param_name = security_scopes_param_name
Expand Down
12 changes: 7 additions & 5 deletions fastapi/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,10 @@ def get_dependant(
add_param_to_fields(
param=param, dependant=dependant, default_schema=params.Query
)
elif lenient_issubclass(param.annotation, Request) or lenient_issubclass(
param.annotation, WebSocket
):
elif lenient_issubclass(param.annotation, Request):
dependant.request_param_name = param_name
elif lenient_issubclass(param.annotation, WebSocket):
dependant.websocket_param_name = param_name
elif lenient_issubclass(param.annotation, BackgroundTasks):
dependant.background_tasks_param_name = param_name
elif lenient_issubclass(param.annotation, SecurityScopes):
Expand Down Expand Up @@ -258,7 +258,7 @@ def is_coroutine_callable(call: Callable) -> bool:

async def solve_dependencies(
*,
request: Request,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Dict[str, Any] = None,
background_tasks: BackgroundTasks = None,
Expand Down Expand Up @@ -305,8 +305,10 @@ async def solve_dependencies(
)
values.update(body_values)
errors.extend(body_errors)
if dependant.request_param_name:
if dependant.request_param_name and isinstance(request, Request):
values[dependant.request_param_name] = request
elif dependant.websocket_param_name and isinstance(request, WebSocket):
values[dependant.websocket_param_name] = request
if dependant.background_tasks_param_name:
if background_tasks is None:
background_tasks = BackgroundTasks()
Expand Down
80 changes: 49 additions & 31 deletions fastapi/routing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import inspect
import logging
import re
from typing import Any, Callable, Dict, List, Optional, Type, Union

from fastapi import params
Expand All @@ -16,9 +17,13 @@
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import compile_path, get_name, request_response
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from starlette.types import ASGIInstance, Receive, Scope, Send
from starlette.routing import (
compile_path,
get_name,
request_response,
websocket_session,
)
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket


Expand Down Expand Up @@ -94,31 +99,33 @@ async def app(request: Request) -> Response:
return app


class WebSocketAPIRoute(routing.WebSocketRoute):
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
super().__init__(path, endpoint, name=name)
self.dependant = get_dependant(path=path, call=self.endpoint)
def get_websocket_app(dependant: Dependant) -> Callable:
async def app(websocket: WebSocket) -> None:
values, errors, _ = await solve_dependencies(
request=websocket, dependant=dependant
)
if errors:
await websocket.close(code=WS_1008_POLICY_VIOLATION)
errors_out = ValidationError(errors)
raise HTTPException(
status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail=errors_out.errors()
)
assert dependant.call is not None, "dependant.call must me a function"
await dependant.call(**values)

def app(scope: Scope) -> ASGIInstance:
async def awaitable(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
values, errors, background_tasks = await solve_dependencies(
request=websocket, dependant=self.dependant
)
if errors:
errors_out = ValidationError(errors)
raise HTTPException(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
detail=errors_out.errors(),
)
assert (
self.dependant.call is not None
), "dependant.call must me a function"
await self.dependant.call(**values)
return app

return awaitable

self.app = app
class APIWebSocketRoute(routing.WebSocketRoute):
def __init__(self, path: str, endpoint: Callable, *, name: str = None) -> None:
self.path = path
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
self.dependant = get_dependant(path=path, call=self.endpoint)
self.app = websocket_session(get_websocket_app(dependant=self.dependant))
regex = "^" + path + "$"
regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex)
self.path_regex, self.path_format, self.param_convertors = compile_path(path)


class APIRoute(routing.Route):
Expand Down Expand Up @@ -295,6 +302,19 @@ def decorator(func: Callable) -> Callable:

return decorator

def add_api_websocket_route(
self, path: str, endpoint: Callable, name: str = None
) -> None:
route = APIWebSocketRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)

def websocket(self, path: str, name: str = None) -> Callable:
def decorator(func: Callable) -> Callable:
self.add_api_websocket_route(path, func, name=name)
return func

return decorator

def include_router(
self,
router: "APIRouter",
Expand Down Expand Up @@ -338,17 +358,15 @@ def include_router(
include_in_schema=route.include_in_schema,
name=route.name,
)
elif isinstance(route, APIWebSocketRoute):
self.add_api_websocket_route(
prefix + route.path, route.endpoint, name=route.name
)
elif isinstance(route, routing.WebSocketRoute):
self.add_websocket_route(
prefix + route.path, route.endpoint, name=route.name
)

def add_websocket_route(
self, path: str, endpoint: Callable, name: str = None
) -> None:
route = WebSocketAPIRoute(path, endpoint=endpoint, name=name)
self.routes.append(route)

def get(
self,
path: str,
Expand Down
2AA9
0