Skip to content

Commit

Permalink
feat: better handle errors (#37)
Browse files Browse the repository at this point in the history
* feat: better handle errors

* feat: updated tests
  • Loading branch information
cofin committed Jul 27, 2024
1 parent b69c001 commit 42ffc3f
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 37 deletions.
38 changes: 22 additions & 16 deletions litestar_vite/inertia/exception_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import re
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -44,7 +43,7 @@ class _HTTPConflictException(HTTPException):
status_code = HTTP_409_CONFLICT


def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]: # noqa: PLR0911
def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]:
"""Handler for all exceptions subclassed from HTTPException."""
inertia_enabled = getattr(request, "inertia_enabled", False) or getattr(request, "is_inertia", False)
if isinstance(exc, NotFoundError):
Expand All @@ -57,26 +56,33 @@ def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exce
if request.app.debug and http_exc not in (PermissionDeniedException, NotFoundError):
return cast("Response[Any]", create_debug_response(request, exc))
return cast("Response[Any]", create_exception_response(request, http_exc(detail=str(exc.__cause__))))
return create_inertia_exception_response(request, exc)


def create_inertia_exception_response(request: Request[UserT, AuthT, StateT], exc: Exception) -> Response[Any]:
"""Create the inertia exception response"""
is_inertia = getattr(request, "is_inertia", False)
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
preferred_type = MediaType.HTML if inertia_enabled and not is_inertia else MediaType.JSON
preferred_type = MediaType.HTML if not is_inertia else MediaType.JSON
detail = getattr(exc, "detail", "") # litestar exceptions
extras = getattr(exc, "extra", "") # msgspec exceptions
content = {"status_code": status_code, "message": getattr(exc, "detail", "")}
inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
if extras:
content.update({"extra": extras})
with contextlib.suppress(Exception):
try:
flash(request, detail, category="error")
except Exception as flash_exc:
msg = f"Failed to set the `flash` session state. Reason: {flash_exc.__cause__!s}"
request.logger.exception(msg)
if extras and len(extras) >= 1:
message = extras[0]
default_field = f"root.{message.get('key')}" if message.get("key", None) is not None else "root" # type: ignore
error_detail = cast("str", message.get("message", detail)) # type: ignore[union-attr] # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
match = FIELD_ERR_RE.search(error_detail)
field = match.group(1) if match else default_field
if isinstance(message, dict):
with contextlib.suppress(Exception):
error(request, field, error_detail)
error(request, field, error_detail)
if status_code in {HTTP_422_UNPROCESSABLE_ENTITY, HTTP_400_BAD_REQUEST} or isinstance(
exc,
PermissionDeniedException,
Expand All @@ -85,15 +91,15 @@ def exception_to_http_response(request: Request[UserT, AuthT, StateT], exc: Exce
if isinstance(exc, PermissionDeniedException):
return InertiaBack(request)
if status_code == HTTP_401_UNAUTHORIZED or isinstance(exc, NotAuthorizedException):
redirect_to = (
if (
inertia_plugin.config.redirect_unauthorized_to is not None
and str(request.url) != inertia_plugin.config.redirect_unauthorized_to
)
if redirect_to:
return InertiaRedirect(request, redirect_to=cast("str", inertia_plugin.config.redirect_unauthorized_to))
return InertiaBack(request)
return InertiaResponse[Any](
media_type=preferred_type,
content=content,
status_code=status_code,
)
):
return InertiaRedirect(request, redirect_to=inertia_plugin.config.redirect_unauthorized_to)
if str(request.url) != inertia_plugin.config.redirect_unauthorized_to:
return InertiaResponse[Any](
media_type=preferred_type,
content=content,
status_code=status_code,
)
return InertiaBack(request)
45 changes: 29 additions & 16 deletions litestar_vite/inertia/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,23 @@ def share(
key: str,
value: Any,
) -> None:
connection.session.setdefault("_shared", {}).update({key: value})
try:
connection.session.setdefault("_shared", {}).update({key: value})
except Exception as exc:
msg = f"Failed to set the `share` session state. Reason: {exc.__cause__!s}"
connection.logger.exception(msg)


def error(
connection: ASGIConnection[Any, Any, Any, Any],
key: str,
message: str,
) -> None:
connection.session.setdefault("_errors", {}).update({key: message})
try:
connection.session.setdefault("_errors", {}).update({key: message})
except Exception as exc:
msg = f"Failed to set the `error` session state. Reason: {exc.__cause__!s}"
connection.logger.exception(msg)


def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, Any]: # noqa: UP006
Expand All @@ -63,21 +71,26 @@ def get_shared_props(request: ASGIConnection[Any, Any, Any, Any]) -> Dict[str, A
Be sure to call this before `self.create_template_context` if you would like to include the `flash` message details.
"""
error_bag = request.headers.get("X-Inertia-Error-Bag", None)
errors: dict[str, Any] = request.session.pop("_errors", {})
props: dict[str, Any] = request.session.pop("_shared", {})
flash: dict[str, list[str]] = defaultdict(list)
for message in cast("List[Dict[str,Any]]", request.session.pop("_messages", [])):
flash[message["category"]].append(message["message"])

inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
props.update(inertia_plugin.config.extra_static_page_props)
for session_prop in inertia_plugin.config.extra_session_page_props:
if session_prop not in props and session_prop in request.session:
props[session_prop] = request.session.get(session_prop)
props: dict[str, Any] = {}
try:
error_bag = request.headers.get("X-Inertia-Error-Bag", None)
errors: dict[str, Any] = request.session.pop("_errors", {})
props.update(cast("Dict[str,Any]", request.session.pop("_shared", {})))
flash: dict[str, list[str]] = defaultdict(list)
for message in cast("List[Dict[str,Any]]", request.session.pop("_messages", [])):
flash[message["category"]].append(message["message"])

inertia_plugin = cast("InertiaPlugin", request.app.plugins.get("InertiaPlugin"))
props.update(inertia_plugin.config.extra_static_page_props)
for session_prop in inertia_plugin.config.extra_session_page_props:
if session_prop not in props and session_prop in request.session:
props[session_prop] = request.session.get(session_prop)
props["flash"] = flash
props["errors"] = {error_bag: errors} if error_bag is not None else errors
except Exception as exc:
msg = f"Failed to set the `error` session state. Reason: {exc.__cause__}"
request.logger.exception(msg)
props["csrf_token"] = value_or_default(ScopeState.from_scope(request.scope).csrf_token, "")
props["flash"] = flash
props["errors"] = {error_bag: errors} if error_bag is not None else errors
return props


Expand Down
2 changes: 1 addition & 1 deletion tests/test_inertia/test_inertia_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def handler(request: InertiaRequest[Any, Any, Any]) -> bool:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.text
== '{"component":null,"url":"/","version":"1.0","props":{"content":true,"csrf_token":"","flash":{},"errors":{}}}'
== '{"component":null,"url":"/","version":"1.0","props":{"content":true,"flash":{},"errors":{},"csrf_token":""}}'
)


Expand Down
8 changes: 4 additions & 4 deletions tests/test_inertia/test_inertia_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"csrf_token":"","flash":{},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}'
)


Expand All @@ -72,7 +72,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"csrf_token":"","flash":{"info":["a flash message"]},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}'
)


Expand All @@ -99,7 +99,7 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
response = client.get("/", headers={InertiaHeaders.ENABLED.value: "true"})
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"auth":{"user":"nobody"},"csrf_token":"","flash":{"info":["a flash message"]},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"auth":{"user":"nobody"},"flash":{"info":["a flash message"]},"errors":{},"csrf_token":""}}'
)


Expand Down Expand Up @@ -135,5 +135,5 @@ async def handler(request: Request[Any, Any, Any]) -> Dict[str, Any]:
)
assert (
response.content
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"csrf_token":"","flash":{},"errors":{}}}'
== b'{"component":"Home","url":"/","version":"1.0","props":{"content":{"thing":"value"},"flash":{},"errors":{},"csrf_token":""}}'
)

0 comments on commit 42ffc3f

Please sign in to comment.