from __future__ import annotations

import json
import secrets
import time
from base64 import urlsafe_b64encode
from dataclasses import dataclass
from hashlib import sha256
from typing import Any

from cryptography.fernet import Fernet, InvalidToken
from fastapi import Request, Response

from settings import Settings

# Helpers for encrypted login-session cookies and in-memory OAuth (PKCE) state.

SESSION_COOKIE_NAME = "robot_u_session"
SESSION_MAX_AGE_SECONDS = 60 * 60 * 24 * 14  # 14 days
OAUTH_STATE_MAX_AGE_SECONDS = 60 * 10  # 10 minutes


@dataclass
class SessionRecord:
    """Decrypted contents of the login-session cookie."""

    # Forgejo API token tied to the session; None when the session has no token.
    forgejo_token: str | None
    # User payload stored alongside the token (its schema is not defined here).
    user: dict[str, Any]
    # Unix timestamps (seconds, from time.time()).
    created_at: float
    expires_at: float


@dataclass
class OAuthStateRecord:
    """Server-side data remembered for one pending OAuth authorization."""

    redirect_uri: str
    # Sanitized local path to send the user to after login.
    return_to: str
    # PKCE verifier whose challenge was handed out by create_oauth_state().
    code_verifier: str
    # Unix timestamp (seconds) used to expire the record.
    created_at: float


# Pending OAuth states keyed by the opaque ``state`` value. Kept in process
# memory only, so it is not shared between workers and is lost on restart.
_OAUTH_STATES: dict[str, OAuthStateRecord] = {}


def resolve_forgejo_token(request: Request, settings: Settings) -> tuple[str | None, str]:
    """Pick the Forgejo token to use for this request.

    Sources are tried in order: the ``Authorization`` header, the encrypted
    session cookie, then the server-wide token from ``settings``.

    Returns:
        ``(token, source)`` where ``source`` is one of ``"authorization"``,
        ``"session"``, ``"server"`` or ``"none"`` (with ``token`` set to
        ``None`` in the last case).
    """
    header_token = _authorization_token(request.headers.get("authorization"))
    if header_token:
        return header_token, "authorization"

    session = _session_from_request(request, settings)
    if session and session.forgejo_token:
        return session.forgejo_token, "session"

    if settings.forgejo_token:
        return settings.forgejo_token, "server"

    return None, "none"


def current_session_user(request: Request, settings: Settings) -> dict[str, Any] | None:
    """Return the user stored in the session cookie, or None without a valid session."""
    session = _session_from_request(request, settings)
    return session.user if session else None


def create_login_session(
    response: Response,
    settings: Settings,
    forgejo_token: str | None,
    user: dict[str, Any],
) -> None:
    """Encrypt a new session and attach it to ``response`` as a cookie.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is not set (raised by
            ``_session_cipher``).
    """
    created_at = time.time()
    session = SessionRecord(
        forgejo_token=forgejo_token,
        user=user,
        created_at=created_at,
        expires_at=created_at + SESSION_MAX_AGE_SECONDS,
    )
    encrypted_session = _encrypt_session(session, settings)
    response.set_cookie(
        SESSION_COOKIE_NAME,
        encrypted_session,
        httponly=True,
        secure=settings.auth_cookie_secure,
        samesite="lax",
        max_age=SESSION_MAX_AGE_SECONDS,
        path="/",
    )


def clear_login_session(request: Request, response: Response) -> None:
    """Delete the session cookie on ``response``.

    ``request`` is accepted for interface compatibility but is not used.
    """
    response.delete_cookie(SESSION_COOKIE_NAME, path="/")


def create_oauth_state(redirect_uri: str, return_to: str) -> tuple[str, str]:
    """Register a pending OAuth authorization and return its public values.

    Expired entries are pruned first so that abandoned login attempts do not
    accumulate in ``_OAUTH_STATES`` for the life of the process.

    Returns:
        ``(state, challenge)``: the random ``state`` key and the PKCE code
        challenge derived from the stored code verifier.
    """
    now = time.time()
    _prune_expired_oauth_states(now)

    state = secrets.token_urlsafe(32)
    code_verifier = secrets.token_urlsafe(64)
    _OAUTH_STATES[state] = OAuthStateRecord(
        redirect_uri=redirect_uri,
        return_to=_safe_return_path(return_to),
        code_verifier=code_verifier,
        created_at=now,
    )
    return state, code_challenge(code_verifier)


def consume_oauth_state(state: str) -> OAuthStateRecord | None:
    """Remove and return the record for ``state``.

    The record is removed even when it turns out to be expired, so a state
    can be consumed at most once. Returns None for unknown or expired states.
    """
    record = _OAUTH_STATES.pop(state, None)
    if not record:
        return None
    if time.time() - record.created_at > OAUTH_STATE_MAX_AGE_SECONDS:
        return None
    return record


def code_challenge(code_verifier: str) -> str:
    """Return the unpadded URL-safe base64 SHA-256 digest of ``code_verifier``.

    Raises:
        UnicodeEncodeError: if ``code_verifier`` contains non-ASCII characters.
    """
    digest = sha256(code_verifier.encode("ascii")).digest()
    return urlsafe_b64encode(digest).decode("ascii").rstrip("=")


def _prune_expired_oauth_states(now: float) -> None:
    """Drop every pending OAuth state older than OAUTH_STATE_MAX_AGE_SECONDS."""
    # Snapshot the items so the dict is not mutated while being iterated.
    expired = [
        state
        for state, record in list(_OAUTH_STATES.items())
        if now - record.created_at > OAUTH_STATE_MAX_AGE_SECONDS
    ]
    for state in expired:
        _OAUTH_STATES.pop(state, None)


def _authorization_token(value: str | None) -> str | None:
    """Extract the credential from a ``Bearer``/``Token`` Authorization header.

    The scheme is matched case-insensitively. Returns None when the header is
    missing, uses another scheme, or carries an empty credential.
    """
    if not value:
        return None
    parts = value.strip().split(None, 1)
    if len(parts) == 2 and parts[0].lower() in {"bearer", "token"}:
        return parts[1].strip() or None
    return None


def _safe_return_path(value: str) -> str:
    """Return ``value`` if it is a safe same-site path, otherwise ``"/"``.

    Only paths starting with a single ``/`` are accepted. Backslashes and
    ASCII control characters are rejected as well: browsers treat ``\\`` like
    ``/`` and strip tab/newline characters while parsing URLs, so values such
    as ``/\\evil.example`` or ``/<TAB>/evil.example`` would otherwise be
    followed as the protocol-relative URL ``//evil.example``.
    """
    if not value.startswith("/") or value.startswith("//"):
        return "/"
    if "\\" in value or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in value):
        return "/"
    return value


def _session_from_request(request: Request, settings: Settings) -> SessionRecord | None:
    """Decrypt and validate the session cookie carried by ``request``.

    Returns None when the cookie or the secret key is missing, when the token
    cannot be decrypted or is older than SESSION_MAX_AGE_SECONDS, when the
    payload is malformed, or when the stored ``expires_at`` has passed.
    """
    encrypted_session = request.cookies.get(SESSION_COOKIE_NAME)
    if not encrypted_session or not settings.auth_secret_key:
        return None

    try:
        payload = _session_cipher(settings).decrypt(
            encrypted_session.encode("utf-8"),
            ttl=SESSION_MAX_AGE_SECONDS,
        )
        raw_session = json.loads(payload.decode("utf-8"))
    except (InvalidToken, json.JSONDecodeError, UnicodeDecodeError, ValueError):
        # Any undecodable cookie is treated as "no session" rather than an error.
        return None

    session = _session_from_payload(raw_session)
    if session is None or session.expires_at <= time.time():
        return None
    return session


def _encrypt_session(session: SessionRecord, settings: Settings) -> str:
    """Serialize ``session`` to compact JSON and return it encrypted as text.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is not set.
    """
    payload = {
        "forgejo_token": session.forgejo_token,
        "user": session.user,
        "created_at": session.created_at,
        "expires_at": session.expires_at,
    }
    serialized = json.dumps(payload, separators=(",", ":")).encode("utf-8")
    return _session_cipher(settings).encrypt(serialized).decode("utf-8")


def _session_from_payload(payload: object) -> SessionRecord | None:
    """Build a SessionRecord from decoded JSON, or None if its shape is wrong."""
    if not isinstance(payload, dict):
        return None

    forgejo_token = payload.get("forgejo_token")
    user = payload.get("user")
    created_at = payload.get("created_at")
    expires_at = payload.get("expires_at")

    if forgejo_token is not None and not isinstance(forgejo_token, str):
        return None
    if not isinstance(user, dict):
        return None
    if not isinstance(created_at, (int, float)) or not isinstance(expires_at, (int, float)):
        return None

    return SessionRecord(
        forgejo_token=forgejo_token,
        user=user,
        created_at=float(created_at),
        expires_at=float(expires_at),
    )


def _session_cipher(settings: Settings) -> Fernet:
    """Build the Fernet cipher used for session cookies.

    The Fernet key is the URL-safe base64 encoding of the SHA-256 digest of
    ``settings.auth_secret_key``, so a secret of any length yields a 32-byte key.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is empty or unset.
    """
    if not settings.auth_secret_key:
        raise ValueError("AUTH_SECRET_KEY is required for encrypted login sessions.")
    key = urlsafe_b64encode(sha256(settings.auth_secret_key.encode("utf-8")).digest())
    return Fernet(key)