robot-u-site/auth.py

197 lines
5.5 KiB
Python

from __future__ import annotations
import json
import secrets
import time
from base64 import urlsafe_b64encode
from dataclasses import dataclass
from hashlib import sha256
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from fastapi import Request, Response
from settings import Settings
# Name of the cookie that carries the encrypted login session.
SESSION_COOKIE_NAME = "robot_u_session"
# Lifetime of a login session in seconds (two weeks).
SESSION_MAX_AGE_SECONDS = 14 * 24 * 60 * 60
# Lifetime of a pending OAuth state in seconds (ten minutes).
OAUTH_STATE_MAX_AGE_SECONDS = 10 * 60
@dataclass
class SessionRecord:
    """Login session data serialized (as JSON) into the encrypted session cookie."""

    # Forgejo access token stored for the logged-in user; None when none was stored.
    forgejo_token: str | None
    # User data dict passed to create_login_session (presumably the Forgejo
    # user profile — confirm with caller); must be JSON-serializable.
    user: dict[str, Any]
    # time.time() value when the session was created.
    created_at: float
    # time.time() value at/after which the session is rejected.
    expires_at: float
@dataclass
class OAuthStateRecord:
    """Server-side data remembered between starting and finishing an OAuth login."""

    # Redirect URI supplied when the state was created (presumably the OAuth
    # callback URL — confirm with caller).
    redirect_uri: str
    # Local path to return the user to afterwards; sanitized by _safe_return_path.
    return_to: str
    # PKCE code verifier whose S256 challenge was handed out with the state.
    code_verifier: str
    # time.time() value when the state was created; used to expire stale states.
    created_at: float
# Pending OAuth states keyed by the random `state` value: filled by
# create_oauth_state and drained by consume_oauth_state. This is a plain
# module-level dict, so it exists per process only; presumably a multi-worker
# deployment must route the callback to the process that issued the state.
_OAUTH_STATES: dict[str, OAuthStateRecord] = {}
def resolve_forgejo_token(request: Request, settings: Settings) -> tuple[str | None, str]:
    """Pick the Forgejo token to use for this request and report where it came from.

    Precedence: an ``Authorization: Bearer|token ...`` header, then the token
    stored in the encrypted session cookie, then the server-wide token from
    settings.

    Returns:
        ``(token, source)`` where source is one of ``"authorization"``,
        ``"session"``, ``"server"``; or ``(None, "none")`` if nothing applies.
    """
    explicit_token = _authorization_token(request.headers.get("authorization"))
    if explicit_token:
        return explicit_token, "authorization"

    current = _session_from_request(request, settings)
    stored_token = current.forgejo_token if current else None
    if stored_token:
        return stored_token, "session"

    if not settings.forgejo_token:
        return None, "none"
    return settings.forgejo_token, "server"
def current_session_user(request: Request, settings: Settings) -> dict[str, Any] | None:
    """Return the user dict from the request's valid session cookie, or None."""
    active_session = _session_from_request(request, settings)
    if active_session is None:
        return None
    return active_session.user
def create_login_session(
    response: Response,
    settings: Settings,
    forgejo_token: str | None,
    user: dict[str, Any],
) -> None:
    """Encrypt a fresh session and attach it to *response* as the session cookie.

    The session expires SESSION_MAX_AGE_SECONDS after now. The cookie is
    HttpOnly, SameSite=Lax, path "/", and marked Secure according to
    ``settings.auth_cookie_secure``.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is not configured.
    """
    issued_at = time.time()
    cookie_value = _encrypt_session(
        SessionRecord(
            forgejo_token=forgejo_token,
            user=user,
            created_at=issued_at,
            expires_at=issued_at + SESSION_MAX_AGE_SECONDS,
        ),
        settings,
    )
    response.set_cookie(
        SESSION_COOKIE_NAME,
        cookie_value,
        max_age=SESSION_MAX_AGE_SECONDS,
        path="/",
        httponly=True,
        secure=settings.auth_cookie_secure,
        samesite="lax",
    )
def clear_login_session(request: Request, response: Response) -> None:
    """Instruct the client to drop the session cookie (path "/").

    ``request`` is not used; it is accepted only to keep the call signature.
    """
    cookie_name = SESSION_COOKIE_NAME
    response.delete_cookie(cookie_name, path="/")
def create_oauth_state(redirect_uri: str, return_to: str) -> tuple[str, str]:
    """Register a new OAuth login attempt and return ``(state, code_challenge)``.

    A random ``state`` and PKCE code verifier are generated and remembered in
    ``_OAUTH_STATES`` together with the redirect URI and a sanitized
    ``return_to`` path. The returned challenge is the S256 transform of the
    verifier.

    Expired entries are pruned first. Without this, states that are never
    consumed (abandoned logins, or repeated hits on the login route) would
    accumulate in memory forever.
    """
    now = time.time()
    _prune_expired_oauth_states(now)
    state = secrets.token_urlsafe(32)
    code_verifier = secrets.token_urlsafe(64)
    _OAUTH_STATES[state] = OAuthStateRecord(
        redirect_uri=redirect_uri,
        return_to=_safe_return_path(return_to),
        code_verifier=code_verifier,
        created_at=now,
    )
    return state, code_challenge(code_verifier)


def _prune_expired_oauth_states(now: float) -> None:
    """Drop pending OAuth states older than OAUTH_STATE_MAX_AGE_SECONDS."""
    # Snapshot the items so removing entries does not disturb the iteration.
    for key, record in list(_OAUTH_STATES.items()):
        if now - record.created_at > OAUTH_STATE_MAX_AGE_SECONDS:
            _OAUTH_STATES.pop(key, None)
def consume_oauth_state(state: str) -> OAuthStateRecord | None:
    """Remove and return the pending record for *state*.

    The entry is removed whether or not it is still fresh, so a state value can
    be redeemed at most once. Returns None for unknown states and for states
    older than OAUTH_STATE_MAX_AGE_SECONDS.
    """
    pending = _OAUTH_STATES.pop(state, None)
    if pending is None:
        return None
    age_seconds = time.time() - pending.created_at
    return pending if age_seconds <= OAUTH_STATE_MAX_AGE_SECONDS else None
def code_challenge(code_verifier: str) -> str:
    """Return the PKCE S256 challenge for *code_verifier*.

    This is the unpadded base64url encoding of SHA-256 over the ASCII bytes
    of the verifier.
    """
    verifier_bytes = code_verifier.encode("ascii")
    encoded_digest = urlsafe_b64encode(sha256(verifier_bytes).digest())
    return encoded_digest.rstrip(b"=").decode("ascii")
def _authorization_token(value: str | None) -> str | None:
    """Extract the credential from an ``Authorization`` header value.

    Accepts the ``Bearer`` and ``token`` schemes, case-insensitively. Returns
    None for a missing or empty header, any other scheme, or an empty
    credential.
    """
    if not value:
        return None
    pieces = value.strip().split(None, 1)
    if len(pieces) != 2:
        return None
    scheme, credential = pieces
    if scheme.lower() not in {"bearer", "token"}:
        return None
    credential = credential.strip()
    return credential if credential else None
def _safe_return_path(value: str) -> str:
    """Return *value* if it is a safe local path to redirect to, otherwise "/".

    Only absolute paths on this site are allowed. Anything a browser could
    resolve to another host is replaced by "/":
      * values not starting with "/", e.g. "https://evil.example";
      * scheme-relative values, "//evil.example";
      * "/\\evil.example" — browsers treat "\\" like "/" in URLs, making this
        scheme-relative too;
      * any control character — browsers strip tab/CR/LF from URLs, so
        "/\\t/evil.example" would become "//evil.example".
    An empty or missing value also yields "/".
    """
    if not value or not value.startswith("/"):
        return "/"
    if value[1:2] in ("/", "\\"):
        return "/"
    if any(ord(char) < 0x20 or ord(char) == 0x7F for char in value):
        return "/"
    return value
def _session_from_request(request: Request, settings: Settings) -> SessionRecord | None:
    """Decrypt and validate the session cookie on *request*.

    Returns None in any of these cases:
      * the cookie is missing, or no auth secret is configured;
      * the Fernet token is invalid or older than SESSION_MAX_AGE_SECONDS;
      * the payload is not valid UTF-8 JSON of the expected shape;
      * the payload's ``expires_at`` has passed.
    """
    cookie_value = request.cookies.get(SESSION_COOKIE_NAME)
    if not (cookie_value and settings.auth_secret_key):
        return None
    try:
        plaintext = _session_cipher(settings).decrypt(
            cookie_value.encode("utf-8"), ttl=SESSION_MAX_AGE_SECONDS
        )
        decoded = json.loads(plaintext.decode("utf-8"))
    # json.JSONDecodeError, UnicodeDecodeError and UnicodeEncodeError are all
    # ValueError subclasses, so ValueError covers every decode failure.
    except (InvalidToken, ValueError):
        return None
    record = _session_from_payload(decoded)
    if record is None or record.expires_at <= time.time():
        return None
    return record
def _encrypt_session(session: SessionRecord, settings: Settings) -> str:
    """Serialize *session* to compact JSON and return it as a Fernet token string.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is not configured.
    """
    fields = {
        "forgejo_token": session.forgejo_token,
        "user": session.user,
        "created_at": session.created_at,
        "expires_at": session.expires_at,
    }
    cipher = _session_cipher(settings)
    plaintext = json.dumps(fields, separators=(",", ":")).encode("utf-8")
    return cipher.encrypt(plaintext).decode("utf-8")
def _session_from_payload(payload: object) -> SessionRecord | None:
    """Validate a decrypted session payload and convert it to a SessionRecord.

    Expected shape: a dict with ``forgejo_token`` (str or None), ``user``
    (dict), and ``created_at`` / ``expires_at`` (finite numbers). Returns None
    when any field is missing or has the wrong type.
    """
    if not isinstance(payload, dict):
        return None
    forgejo_token = payload.get("forgejo_token")
    user = payload.get("user")
    created_at = payload.get("created_at")
    expires_at = payload.get("expires_at")
    if forgejo_token is not None and not isinstance(forgejo_token, str):
        return None
    if not isinstance(user, dict):
        return None
    if not (_is_timestamp(created_at) and _is_timestamp(expires_at)):
        return None
    return SessionRecord(
        forgejo_token=forgejo_token,
        user=user,
        created_at=float(created_at),
        expires_at=float(expires_at),
    )


def _is_timestamp(value: object) -> bool:
    """Return True if *value* is a non-bool int/float that converts to a finite float."""
    import math

    # bool subclasses int, so True/False would otherwise pass as 1.0/0.0.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    try:
        # json.loads accepts NaN/Infinity. A NaN expires_at never compares
        # <= now, so it would never expire. Huge JSON ints overflow float().
        return math.isfinite(float(value))
    except OverflowError:
        return False
def _session_cipher(settings: Settings) -> Fernet:
    """Build the Fernet cipher used for session cookies from the configured secret.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is empty or unset.
    """
    if not settings.auth_secret_key:
        raise ValueError("AUTH_SECRET_KEY is required for encrypted login sessions.")
    # Fernet needs a urlsafe-base64 encoded 32-byte key. A SHA-256 digest of
    # the configured secret is exactly 32 bytes, so any secret string works.
    derived_key = sha256(settings.auth_secret_key.encode("utf-8")).digest()
    return Fernet(urlsafe_b64encode(derived_key))