diff --git a/halfapi/lib/acl.py b/halfapi/lib/acl.py index e6f3dd7..71f9713 100644 --- a/halfapi/lib/acl.py +++ b/halfapi/lib/acl.py @@ -66,26 +66,31 @@ def args_check(fct): return ', '.join(array) - args_d = kwargs.get('args', {}) - required = args_d.get('required', set()) + args_d = kwargs.get('args', None) + if args_d is not None: + required = args_d.get('required', set()) - missing = [] - data = {} + missing = [] + data = {} - for key in required: - data[key] = data_.pop(key, None) - if data[key] is None: - missing.append(key) + for key in required: + data[key] = data_.pop(key, None) + if data[key] is None: + missing.append(key) - if missing: - raise HTTPException( - 400, - f"Missing value{plural(missing)} for: {comma_list(missing)}!") + if missing: + raise HTTPException( + 400, + f"Missing value{plural(missing)} for: {comma_list(missing)}!") - optional = args_d.get('optional', set()) - for key in optional: - if key in data_: - data[key] = data_[key] + optional = args_d.get('optional', set()) + for key in optional: + if key in data_: + data[key] = data_[key] + else: + """ Unsafe mode, without specified arguments + """ + data = data_ kwargs['data'] = data diff --git a/halfapi/lib/domain.py b/halfapi/lib/domain.py index 09b5aea..10babab 100644 --- a/halfapi/lib/domain.py +++ b/halfapi/lib/domain.py @@ -6,16 +6,33 @@ lib/domain.py The domain-scoped utility functions import os import sys import importlib +import inspect import logging from types import ModuleType, FunctionType -from typing import Generator, Dict, List +from typing import Callable, Generator, Dict, List from halfapi.lib import acl +from halfapi.lib.responses import ORJSONResponse logger = logging.getLogger("uvicorn.asgi") VERBS = ('GET', 'POST', 'PUT', 'PATCH', 'DELETE') + +def route_decorator(fct: Callable = None, ret_type: str = 'json'): + """ Returns an async function that can be mounted on a router + """ + if ret_type == 'json': + @acl.args_check + async def wrapped(request, *args, **kwargs): + return ORJSONResponse( + fct(**request.path_params, data=kwargs.get('data'))) + else: + raise Exception('Return type not available') + + return wrapped + + def get_fct_name(http_verb: str, path: str) -> str: """ Returns the predictable name of the function for a route @@ -107,6 +124,9 @@ def gen_routes(route_params: Dict, path: List, m_router: ModuleType) -> Generato logger.error('%s is not defined in %s', fct_name, m_router.__name__) continue + if not inspect.iscoroutinefunction(fct): + fct = route_decorator(fct) + d_res[verb] = {'fct': fct, 'params': params} if len(d_res.keys()) > 1: