Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better handle errors #37

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions litestar_vite/inertia/exception_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import re
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -44,7 +43,7 @@ class _HTTPConflictException(HTTPException):
status_code = HTTP_409_CONFLICT


def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]: # noqa: PLR0911
def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]:
"""Handler for all exceptions subclassed from HTTPException."""
inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
if isinstance(exc, NotFoundError):
Expand All @@ -57,26 +56,33 @@ def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exce
if request.app.debug and http_exc not in (PermissionDeniedException, NotFoundError):
return cast("Response[Any]", create_debug_response(request, exc))
return cast("Response[Any]", create_exception_response(request, http_exc(detail=str(exc.__cause__))))
return create_inertia_exception_response(request, exc)


def create_inertia_exception_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]:
"""Create the inertia exception response"""
is_inertia = getattr(request, "is_inertia", False)
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
preferred_type = MediaType.HTML if inertia_enabled and not is_inertia else MediaType.JSON
preferred_type = MediaType.HTML if not is_inertia else MediaType.JSON
detail = getattr(exc, "detail", "") # litestar exceptions
extras = getattr(exc, "extra", "") # msgspec exceptions
content = {"status_code": status_code, "message": getattr(exc, "detail", "")}
inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
if extras:
content.update({"extra": extras})
with contextlib.suppress(Exception):
try:
flash(request, detail, category="error")
except Exception as flash_exc:
msg = f"Failed to set the `flash` session state. Reason: {flash_exc.__cause__!s}"
request.logger.exception(msg)
if extras and len(extras) >= 1:
message = extras[0]
default_field = f"root.{message.get('key')}" if message.get("key", None) is not None else "root" # type: ignore
error_detail = cast("str", message.get("message", detail)) # type: ignore[union-attr] # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
match = FIELD_ERR_RE.search(error_detail)
field = match.group(1) if match else default_field
if isinstance(message, dict):
with contextlib.suppress(Exception):
error(request, field, error_detail)
error(request, field, error_detail)
if status_code in {HTTP_422_UNPROCESSABLE_ENTITY, HTTP_400_BAD_REQUEST} or isinstance(
exc,
PermissionDeniedException,
Expand All @@ -85,15 +91,15 @@ def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exce
if isinstance(exc, PermissionDeniedException):
return InertiaBack(request)
if status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException):
redirect_to = (
if (
inertia_plugin.config.redirect_unauthorized_to is not None
and str(request.url) != inertia_plugin.config.redirect_unauthorized_to
)
if redirect_to:
return InertiaRedirect(request, redirect_to=cast("str", inertia_plugin.config.redirect_unauthorized_to))
return InertiaBack(request)
return InertiaResponse[Any](
media_type=preferred_type,
content=content,
status_code=status_code,
)
):
return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to)
if str(request.url) != inertia_plugin.config.redirect_unauthorized_to:
return InertiaResponse[Any](
media_type=preferred_type,
content=content,
status_code=status_code,
)
return InertiaBack(request)
45 changes: 29 additions & 16 deletions litestar_vite/inertia/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,23 @@ def share(
key: str,
value: Any,
) -> None:
connection.session.setdefault("_shared", {}).update({key: value})
try:
connection.session.setdefault("_shared", {}).update({key: value})
except Exception as exc:
msg = f"Failed to set the `share` session state. Reason: {exc.__cause__!s}"
connection.logger.exception(msg)


def error(
connection: ASGIConnection[Any, Any, Any, Any],
key: str,
message: str,
) -> None:
connection.session.setdefault("_errors", {}).update({key: message})
try:
connection.session.setdefault("_errors", {}).update({key: message})
except Exception as exc:
msg = f"Failed to set the `error` session state. Reason: {exc.__cause__!s}"
connection.logger.exception(msg)


def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, Any]: # noqa: UP006
Expand All @@ -63,21 +71,26 @@ def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, A

Be sure to call this before `self.create_template_context` if you would like to include the `flash` message details.
"""
error_bag = request.headers.get("X-Inertia-Error-Bag", None)
errors: dict[str, Any] = request.session.pop("_errors", {})
props: dict[str, Any] = request.session.pop("_shared", {})
flash: dict[str, list[str]] = defaultdict(list)
for message in cast("List[Dict[str,Any]]", request.session.pop("_messages", [])):
flash[message["category"]].append(message["message"])

inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
props.update(inertia_plugin.config.extra_static_page_props)
for session_prop in inertia_plugin.config.extra_session_page_props:
if session_prop not in props and session_prop in request.session:
props[session_prop] = request.session.get(session_prop)
props: dict[str, Any] = {}
try:
error_bag = request.headers.get("X-Inertia-Error-Bag", None)
errors: dict[str, Any] = request.session.pop("_errors", {})
props.update(cast("Dict[str,Any]", request.session.pop("_shared", {})))
flash: dict[str, list[str]] = defaultdict(list)
for message in cast("List[Dict[str,Any]]", request.session.pop("_messages", [])):
flash[message["category"]].append(message["message"])

inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
props.update(inertia_plugin.config.extra_static_page_props)
for session_prop in inertia_plugin.config.extra_session_page_props:
if session_prop not in props and session_prop in request.session:
props[session_prop] = request.session.get(session_prop)
props["flash"] = flash
props["errors"] = {error_bag: errors} if error_bag is not None else errors
except Exception as exc:
msg = f"Failed to set the `error` session state. Reason: {exc.__cause__}"
request.logger.exception(msg)
props["csrf_token"] = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
props["flash"] = flash
props["errors"] = {error_bag: errors} if error_bag is not None else errors
return props


Expand Down
2 changes: 1 addition & 1 deletion tests/test_inertia/test_inertia_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def handler(request: InertiaRequest[Any, Any, Any]) -> bool:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.text
== '{"component":null,"url":"/","version":"1.0","props":{"content":true,"csrf_token":"","flash":{},"errors":{}}}'
== '{"component":null,"url":"/","version":"1.0","props":{"content":true,"flash":{},"errors":{},"csrf_token":""}}'
)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_inertia/test_inertia_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"csrf_token":"","flash":{},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}'
)


Expand All @@ -72,7 +72,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"csrf_token":"","flash":{"info":["a flash message"]},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}'
)


Expand All @@ -99,7 +99,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"auth":{"user":"nobody"},"csrf_token":"","flash":{"info":["a flash message"]},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"auth":{"user":"nobody"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}'
)


Expand Down Expand Up @@ -135,5 +135,5 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
)
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"csrf_token":"","flash":{},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}'
)