From 717d3f8bd64bc654b150ae7294a298745acbd4b3 Mon Sep 17 00:00:00 2001 From: Maxime Alves LIRMM Date: Sat, 14 Jan 2023 10:26:31 +0100 Subject: [PATCH] [responses] use a wrapper function for exception handling (fix starlette 0.20) --- halfapi/halfapi.py | 12 ++++++------ halfapi/lib/responses.py | 8 ++++++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/halfapi/halfapi.py b/halfapi/halfapi.py index 7b4b76d..944af91 100644 --- a/halfapi/halfapi.py +++ b/halfapi/halfapi.py @@ -35,7 +35,7 @@ from .lib.timing import HTimingClient from .lib.jwt_middleware import JWTAuthenticationBackend from .lib.responses import (ORJSONResponse, UnauthorizedResponse, NotFoundResponse, InternalServerErrorResponse, NotImplementedResponse, - ServiceUnavailableResponse) + ServiceUnavailableResponse, gen_exception_route) from .lib.domain import NoDomainsException from .lib.routes import gen_schema_routes, JSONRoute from .lib.schemas import schema_json @@ -90,11 +90,11 @@ class HalfAPI(Starlette): debug=not PRODUCTION, routes=routes, exception_handlers={ - 401: UnauthorizedResponse, - 404: NotFoundResponse, - 500: HalfAPI.exception, - 501: NotImplementedResponse, - 503: ServiceUnavailableResponse + 401: gen_exception_route(UnauthorizedResponse), + 404: gen_exception_route(NotFoundResponse), + 500: gen_exception_route(HalfAPI.exception), + 501: gen_exception_route(NotImplementedResponse), + 503: gen_exception_route(ServiceUnavailableResponse) }, on_startup=startup_fcts ) diff --git a/halfapi/lib/responses.py b/halfapi/lib/responses.py index 28a3794..4709205 100644 --- a/halfapi/lib/responses.py +++ b/halfapi/lib/responses.py @@ -24,6 +24,8 @@ import orjson # asgi framework from starlette.responses import PlainTextResponse, Response, JSONResponse, HTMLResponse +from starlette.requests import Request +from starlette.exceptions import HTTPException from .user import JWTUser, Nobody from ..logging import logger @@ -157,3 +159,9 @@ class ODSResponse(Response): class XLSXResponse(ODSResponse): file_type = 'xlsx' + +def gen_exception_route(response_cls): + async def exception_route(req: Request, exc: HTTPException): + return response_cls() + + return exception_route