Add unit tests for github auth provider #149

Merged (8 commits, Mar 11, 2024)
107 changes: 65 additions & 42 deletions giftless/auth/github.py
@@ -2,13 +2,14 @@
 import dataclasses
 import functools
 import logging
+import math
 import os
 import threading
 from collections.abc import Callable, Mapping, MutableMapping
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, suppress
 from operator import attrgetter, itemgetter
-from threading import Condition, Lock, RLock
-from typing import Any, cast, overload
+from threading import Lock, RLock
+from typing import Any, Protocol, cast, overload

 import cachetools.keys
 import flask
@@ -23,22 +24,30 @@


 # THREAD SAFE CACHING UTILS
+class _LockType(AbstractContextManager, Protocol):
+    """Generic type for threading.Lock and RLock."""
+
+    def acquire(self, blocking: bool = ..., timeout: float = ...) -> bool:
+        ...
+
+    def release(self) -> None:
+        ...
+
+
 @dataclasses.dataclass(kw_only=True)
 class SingleCallContext:
     """Thread-safety context for the single_call_method decorator."""

-    # condition variable blocking a call with particular arguments
-    cond: Condition = dataclasses.field(default_factory=Condition)
-    # None - call not started, False - call ongoing, True - call done
-    # the three states are needed to cover any spurious (pthread-like) wake-ups
-    call_status: bool | None = None
+    # reentrant lock guarding a call with particular arguments
+    rlock: _LockType = dataclasses.field(default_factory=RLock)
+    start_call: bool = True
     result: Any = None
     error: BaseException | None = None


 def _ensure_lock(
-    existing_lock: Callable[[Any], AbstractContextManager] | None,
-) -> Callable[[Any], AbstractContextManager]:
+    existing_lock: Callable[[Any], _LockType] | None = None,
+) -> Callable[[Any], _LockType]:
     if existing_lock is None:
         default_lock = RLock()
         return lambda _self: default_lock
@@ -54,7 +63,7 @@ def single_call_method(_method: Callable[..., Any]) -> Callable[..., Any]:
 def single_call_method(
     *,
     key: Callable[..., Any] = cachetools.keys.methodkey,
-    lock: Callable[[Any], AbstractContextManager] | None = None,
+    lock: Callable[[Any], _LockType] | None = None,
 ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
     ...

@@ -63,7 +72,7 @@ def single_call_method(
     _method: Callable[..., Any] | None = None,
     *,
     key: Callable[..., Any] = cachetools.keys.methodkey,
-    lock: Callable[[Any], AbstractContextManager] | None = None,
+    lock: Callable[[Any], _LockType] | None = None,
 ) -> Callable[..., Any]:
     """Thread-safe decorator limiting concurrency of an idempotent method call.
     When multiple threads concurrently call the decorated method with the same
@@ -78,7 +87,7 @@ def single_call_method

     It's possible to provide a "getter" callable for the lock guarding the main
     call cache, called as 'lock(self)'. There's a built-in lock by default.
-    Each concurrent call is then guarded by its own lock/conditional variable.
+    Each concurrent call is then guarded by its own reentrant lock variable.
     """
     lock = _ensure_lock(lock)

@@ -97,35 +106,29 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
                     concurrent_calls[k] = ctx = SingleCallContext()
                     # start locked for the current thread, so the following
                     # gap won't let other threads populate the result
-                    ctx.cond.acquire()
+                    ctx.rlock.acquire()

-        with ctx.cond:
-            if ctx.call_status is None:
-                # populating the result
-                ctx.call_status = False
+        with ctx.rlock:
+            if ctx.start_call:
+                ctx.start_call = False
+                ctx.rlock.release()  # unlock the starting lock
                 try:
                     result = method(self, *args, **kwargs)
                 except BaseException as e:
                     ctx.error = e
                     raise
                 finally:
-                    # call is done, cleanup its entry and notify threads
+                    # call is done, cleanup its entry
                     with lck:
                         del concurrent_calls[k]
-                    ctx.cond.release()  # unlock the starting lock
-                    ctx.cond.notify_all()
                 ctx.result = result
-                ctx.call_status = True
                 return result

-            else:
-                # waiting for the result to get populated
-                while True:
-                    if ctx.error:
-                        raise ctx.error
-                    if ctx.call_status:
-                        return ctx.result
-                    ctx.cond.wait()
+            # call is done
+            if ctx.error:
+                raise ctx.error
+            return ctx.result

     return wrapper

@@ -138,7 +141,7 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
 def cachedmethod_threadsafe(
     cache: Callable[[Any], MutableMapping],
     key: Callable[..., Any] = cachetools.keys.methodkey,
-    lock: Callable[[Any], AbstractContextManager] | None = None,
+    lock: Callable[[Any], _LockType] | None = None,
 ) -> Callable[..., Any]:
     """Threadsafe variant of cachetools.cachedmethod."""
     lock = _ensure_lock(lock)
@@ -178,15 +181,14 @@ class Schema(ma.Schema):
     token_max_size = ma.fields.Int(
         load_default=32, validate=ma.validate.Range(min=0)
     )
-    # the auth cache must have at least one valid slot
     auth_max_size = ma.fields.Int(
-        load_default=32, validate=ma.validate.Range(min=1)
+        load_default=32, validate=ma.validate.Range(min=0)
     )
     auth_write_ttl = ma.fields.Float(
-        load_default=15 * 60.0, validate=ma.validate.Range(min=1.0)
+        load_default=15 * 60.0, validate=ma.validate.Range(min=0)
     )
     auth_other_ttl = ma.fields.Float(
-        load_default=30.0, validate=ma.validate.Range(min=1.0)
+        load_default=30.0, validate=ma.validate.Range(min=0)
     )

     @ma.post_load
@@ -224,8 +226,9 @@ class Schema(ma.Schema):

     @ma.post_load
     def make_object(
-        self, data: Mapping[str, Any], **_kwargs: Mapping
+        self, data: MutableMapping[str, Any], **_kwargs: Mapping
     ) -> "Config":
+        data["api_url"] = data["api_url"].rstrip("/")
         return Config(**data)

     @classmethod
@@ -262,6 +265,10 @@ def expiration(_key: Any, value: set[Permission], now: float) -> float:
             )
             return now + ttl

+        # size-unlimited proxy cache to ensure at least one successful hit
+        self._auth_cache_read_proxy: MutableMapping[
+            Any, set[Permission]
+        ] = cachetools.TTLCache(math.inf, 60.0)
         self._auth_cache = cachetools.TLRUCache(cc.auth_max_size, expiration)
         self._auth_cache_lock = Lock()

@@ -280,17 +287,28 @@ def __eq__(self, other: object) -> bool:
     def __hash__(self) -> int:
         return hash((self.login, self.id))

-    def permissions(self, org: str, repo: str) -> set[Permission] | None:
+    def permissions(
+        self, org: str, repo: str, *, authoritative: bool = False
+    ) -> set[Permission] | None:

Collaborator:
Not entirely sure I understand the purpose of "authoritative".

Contributor (author):
This is related to caching and thread safety, within the boundaries of the current design.

The "problem" lies in how this whole authentication is performed in BaseView._is_authorized, which first creates the authenticator, __call__s it to get the Identity, and only then verifies the right permissions with Identity.is_authorized.

These two consecutive steps hold no locks, so they can get interleaved with other threads doing the same thing at the same time. Because I'm authorizing the user via GitHub API calls, my goal is to prevent those calls from repeating, not only by caching the result for some TTL, but also by blocking concurrent identical auth requests from different threads, so that only a single thread performs the API calls while the others just block and receive the result.

Because I wanted to keep all the API calls within the GithubAuthenticator.__call__ method, so as not to drag the requests Session object around (which is otherwise helpful for reusing a single TLS connection to the API), I'm storing only the resulting permissions on the GithubIdentity object, where they can simply expire.

In a different design with a shared requests.Session, the GithubIdentity would itself know how to fetch the user's permissions (and refresh its permission cache), but the Session should have a limited lifetime, so I'd have to manage it externally, which seemed even more complicated.

With all this, and the fact that the caches not only have a configurable TTL but also a rather low maximum size, I was obsessed with a scenario where an attacker could DoS the permission cache by making concurrent calls for the same user but different org/repo. In that scenario, the cached entry stored in GithubAuthenticator.__call__ gets evicted before it is finally accessed via Identity.is_authorized.

So in the end I couldn't resist introducing a tiny in-between cache that remembers the permissions obtained for a user in __call__ until they are successfully read in is_authorized 😮‍💨

This permissions method is called in two contexts:

  • in __call__, to peek whether a certain user's permissions are already cached
  • from is_authorized, which isn't used internally, but means the obtained permissions have been read by the core logic for good and that I'm free to finally forget them.

The authoritative flag is just for that latter case, telling the logic that the permissions have been properly read and can now sit and expire in the regular cache.

I admit this design covers a really niche scenario, but once I realized it was there, my "OCD" couldn't let it stand. If you think there's a chance to considerably simplify the design, I'm all ears, but since this works and it's already done... 😇
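
To make the described flow concrete, here is a rough standalone sketch (made-up names like PermissionCache, store and get; simplified from the actual change) of how the unbounded read proxy and the authoritative read are meant to work together:

```python
import math
import threading
from contextlib import suppress

import cachetools
import cachetools.keys


class PermissionCache:
    """Bounded TTL cache fronted by an unbounded, read-once proxy."""

    def __init__(self, max_size: int = 32, ttl: float = 60.0) -> None:
        self._lock = threading.Lock()
        # size-unlimited proxy: a stored entry cannot be evicted by other keys
        # before its first authoritative read
        self._read_proxy: cachetools.TTLCache = cachetools.TTLCache(math.inf, ttl)
        # regular bounded cache: entries may be evicted under pressure
        self._cache: cachetools.TTLCache = cachetools.TTLCache(max_size, ttl)

    def store(self, org: str, repo: str, permissions: set) -> None:
        # what __call__ does after asking the GitHub API
        key = cachetools.keys.hashkey(org, repo)
        with self._lock:
            self._read_proxy[key] = permissions

    def get(
        self, org: str, repo: str, *, authoritative: bool = False
    ) -> set | None:
        # authoritative=False: __call__ peeking whether permissions are cached
        # authoritative=True: is_authorized consuming them for good
        key = cachetools.keys.hashkey(org, repo)
        with self._lock:
            if authoritative:
                permissions = self._read_proxy.pop(key, None)
            else:
                permissions = self._read_proxy.get(key)
            if permissions is None:
                return self._cache.get(key)
            if authoritative:
                # let the entry age out in the regular cache from now on
                with suppress(ValueError):  # the bounded cache may be size 0
                    self._cache[key] = permissions
            return permissions
```

However many other org/repo entries get written in between, the entry written by store() survives until its one authoritative read; after that it just ages out in the bounded cache.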

Collaborator:
That sounds good to me. My approach would just (if I'd noticed it at all, which I probably wouldn't because I wouldn't have tested high concurrency) have been to throw a mutex around the critical section and accept the performance hit. This seems much more amenable to many callers at once.
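
For contrast, a rough sketch of that simpler alternative (a hypothetical helper, not what this PR implements): a single mutex serializing the whole authenticate-then-authorize sequence, trading throughput for simplicity:

```python
import threading

# one module-level mutex guarding the whole critical section
_auth_mutex = threading.Lock()


def authenticate_and_authorize(authenticator, request, org, repo, permission) -> bool:
    """Serialize __call__ + is_authorized so nothing can be evicted in between."""
    with _auth_mutex:
        identity = authenticator(request)  # may hit the GitHub API
        if identity is None:
            return False
        return identity.is_authorized(org, repo, permission)
```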

         key = cachetools.keys.hashkey(org, repo)
         with self._auth_cache_lock:
-            return self._auth_cache.get(key)
+            if authoritative:
+                permission = self._auth_cache_read_proxy.pop(key, None)
+            else:
+                permission = self._auth_cache_read_proxy.get(key)
+            if permission is None:
+                return self._auth_cache.get(key)
+            if authoritative:
+                with suppress(ValueError):
+                    self._auth_cache[key] = permission
+            return permission

     def authorize(
         self, org: str, repo: str, permissions: set[Permission] | None
     ) -> None:
         key = cachetools.keys.hashkey(org, repo)
         with self._auth_cache_lock:
-            self._auth_cache[key] = (
+            self._auth_cache_read_proxy[key] = (
                 permissions if permissions is not None else set()
             )

@@ -301,7 +319,7 @@ def is_authorized(
         permission: Permission,
         oid: str | None = None,
     ) -> bool:
-        permissions = self.permissions(organization, repo)
+        permissions = self.permissions(organization, repo, authoritative=True)
         return permission in permissions if permissions else False

     def cache_ttl(self, permissions: set[Permission]) -> float:
@@ -351,11 +369,12 @@ def _extract_token(self, request: flask.Request) -> str:
         return token

     def __post_init__(self, request: flask.Request) -> None:
-        self.org, self.repo = request.path.split("/", maxsplit=3)[1:3]
+        org_repo_getter = itemgetter("organization", "repo")
+        self.org, self.repo = org_repo_getter(request.view_args or {})
         self.token = self._extract_token(request)

     def __init__(self, cfg: Config) -> None:
-        self._api_url = cfg.api_url.rstrip("/")
+        self._api_url = cfg.api_url
         self._api_headers = {"Accept": "application/vnd.github+json"}
         if cfg.api_version:
             self._api_headers["X-GitHub-Api-Version"] = cfg.api_version
@@ -471,8 +490,12 @@ def __call__(self, request: flask.Request) -> Identity | None:
         self._authorize(ctx, user)
         return user

+    @property
+    def api_url(self) -> str:
+        return self._api_url
+

-def factory(**options: Mapping[str, Any]) -> GithubAuthenticator:
+def factory(**options: Any) -> GithubAuthenticator:
     """Build GitHub Authenticator from supplied options."""
     config = Config.from_dict(options)
     return GithubAuthenticator(config)
1 change: 1 addition & 0 deletions requirements/dev.in
@@ -9,6 +9,7 @@ pytest-mypy
 pytest-env
 pytest-cov
 pytest-vcr
+responses

 pytz
 types-pytz
5 changes: 5 additions & 0 deletions requirements/dev.txt
@@ -163,13 +163,17 @@ pytz==2023.3.post1
 pyyaml==6.0.1
     # via
     #   -c requirements/main.txt
+    #   responses
     #   vcrpy
 recommonmark==0.7.1
     # via -r requirements/dev.in
 requests==2.31.0
     # via
     #   -c requirements/main.txt
+    #   responses
     #   sphinx
+responses==0.25.0
+    # via -r requirements/dev.in
 rsa==4.9
     # via
     #   -c requirements/main.txt
@@ -241,6 +245,7 @@ urllib3==2.0.7
     # via
     #   -c requirements/main.txt
     #   requests
+    #   responses
     #   types-requests
 vcrpy==5.1.0
     # via pytest-vcr