2026-04-12 20:15:33 -04:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2026-04-12 22:02:47 -04:00
|
|
|
import json
|
2026-04-12 20:15:33 -04:00
|
|
|
import secrets
|
|
|
|
|
import time
|
|
|
|
|
from base64 import urlsafe_b64encode
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
from hashlib import sha256
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
2026-04-12 22:02:47 -04:00
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
2026-04-12 20:15:33 -04:00
|
|
|
from fastapi import Request, Response
|
|
|
|
|
|
|
|
|
|
from settings import Settings
|
|
|
|
|
|
|
|
|
|
# Name of the cookie that carries the encrypted login session.
SESSION_COOKIE_NAME = "robot_u_session"
# Two weeks, in seconds. Used as the cookie max-age, the Fernet decrypt ttl,
# and the session's expires_at offset.
SESSION_MAX_AGE_SECONDS = 14 * 24 * 60 * 60
# Ten minutes, in seconds: how long a pending OAuth ``state`` stays valid.
OAUTH_STATE_MAX_AGE_SECONDS = 10 * 60
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SessionRecord:
    """Login session data that is serialized and encrypted into the session cookie."""

    # Forgejo access token for this login, or None when none was stored.
    forgejo_token: str | None
    # User information captured at login (a JSON-serializable dict).
    user: dict[str, Any]
    # Epoch seconds (time.time()) when the session was created.
    created_at: float
    # Epoch seconds after which the session is rejected when read back.
    expires_at: float
|
2026-04-12 20:15:33 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class OAuthStateRecord:
|
|
|
|
|
redirect_uri: str
|
|
|
|
|
return_to: str
|
|
|
|
|
code_verifier: str
|
|
|
|
|
created_at: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_OAUTH_STATES: dict[str, OAuthStateRecord] = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_forgejo_token(request: Request, settings: Settings) -> tuple[str | None, str]:
    """Pick the Forgejo token to use for this request and report where it came from.

    Sources are tried in priority order, and later ones are only consulted when
    earlier ones yield nothing:

    1. a ``Bearer``/``token`` value in the ``Authorization`` header -> ``"authorization"``
    2. the token stored in the encrypted session cookie -> ``"session"``
    3. ``settings.forgejo_token`` -> ``"server"``

    Returns ``(None, "none")`` when no source provides a token.
    """
    bearer = _authorization_token(request.headers.get("authorization"))
    if bearer:
        return bearer, "authorization"

    # Only decrypt the session cookie when the header did not supply a token.
    session = _session_from_request(request, settings)
    session_token = session.forgejo_token if session else None
    if session_token:
        return session_token, "session"

    server_token = settings.forgejo_token
    if server_token:
        return server_token, "server"

    return None, "none"
|
|
|
|
|
|
|
|
|
|
|
2026-04-12 22:02:47 -04:00
|
|
|
def current_session_user(request: Request, settings: Settings) -> dict[str, Any] | None:
    """Return the user dict from a valid session cookie, or ``None`` if there is none.

    The cookie is considered absent when it is missing, cannot be decrypted,
    or has expired.
    """
    session = _session_from_request(request, settings)
    if session is None:
        return None
    return session.user
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_login_session(
    response: Response,
    settings: Settings,
    forgejo_token: str | None,
    user: dict[str, Any],
) -> None:
    """Attach a freshly encrypted login-session cookie to *response*.

    The session expires ``SESSION_MAX_AGE_SECONDS`` after creation. The cookie
    is HttpOnly, ``SameSite=Lax``, scoped to ``/``, and marked ``Secure``
    according to ``settings.auth_cookie_secure``.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is not configured (raised
            by ``_session_cipher``).
    """
    now = time.time()
    record = SessionRecord(
        forgejo_token=forgejo_token,
        user=user,
        created_at=now,
        expires_at=now + SESSION_MAX_AGE_SECONDS,
    )
    # NOTE(review): the whole user dict is embedded in the cookie. Browsers
    # commonly drop cookies larger than ~4 KB — confirm the stored user payload
    # stays small.
    cookie_value = _encrypt_session(record, settings)
    response.set_cookie(
        SESSION_COOKIE_NAME,
        cookie_value,
        path="/",
        max_age=SESSION_MAX_AGE_SECONDS,
        httponly=True,
        samesite="lax",
        secure=settings.auth_cookie_secure,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_login_session(request: Request, response: Response) -> None:
    """Instruct the browser to drop the session cookie (path ``/``).

    ``request`` is currently unused.
    """
    response.delete_cookie(key=SESSION_COOKIE_NAME, path="/")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_oauth_state(redirect_uri: str, return_to: str) -> tuple[str, str]:
    """Begin an OAuth authorization-code flow with PKCE.

    Stores a new random ``state`` in ``_OAUTH_STATES`` along with a fresh PKCE
    code verifier, the redirect URI, and a sanitized ``return_to`` path.

    Expired states are pruned first. A state that is issued but never consumed
    (an abandoned or bot-driven login) would otherwise stay in the map forever.

    Returns:
        ``(state, code_challenge)`` — the state key and the S256 challenge
        derived from the stored verifier.
    """
    now = time.time()
    _prune_expired_oauth_states(now)

    state = secrets.token_urlsafe(32)
    code_verifier = secrets.token_urlsafe(64)
    _OAUTH_STATES[state] = OAuthStateRecord(
        redirect_uri=redirect_uri,
        return_to=_safe_return_path(return_to),
        code_verifier=code_verifier,
        created_at=now,
    )
    return state, code_challenge(code_verifier)


def _prune_expired_oauth_states(now: float) -> None:
    """Drop pending OAuth states older than ``OAUTH_STATE_MAX_AGE_SECONDS``."""
    # Iterate over a snapshot so that an insert from another handler thread (if
    # handlers run concurrently) cannot raise "dictionary changed size during
    # iteration".
    for key, record in list(_OAUTH_STATES.items()):
        if now - record.created_at > OAUTH_STATE_MAX_AGE_SECONDS:
            # Use a default in case the state was consumed in the meantime.
            _OAUTH_STATES.pop(key, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def consume_oauth_state(state: str) -> OAuthStateRecord | None:
    """Remove and return the pending record for *state*, if it is still valid.

    The entry is always removed, so each state can be used at most once.
    Returns ``None`` for unknown states and for states older than
    ``OAUTH_STATE_MAX_AGE_SECONDS``.
    """
    record = _OAUTH_STATES.pop(state, None)
    if record is None:
        return None
    age_seconds = time.time() - record.created_at
    return None if age_seconds > OAUTH_STATE_MAX_AGE_SECONDS else record
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def code_challenge(code_verifier: str) -> str:
    """Compute the PKCE ``S256`` code challenge for *code_verifier*.

    The result is the base64url encoding of SHA-256(ASCII(verifier)), with
    ``=`` padding removed.
    """
    hashed = sha256(code_verifier.encode("ascii"))
    encoded = urlsafe_b64encode(hashed.digest()).decode("ascii")
    return encoded.rstrip("=")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _authorization_token(value: str | None) -> str | None:
|
|
|
|
|
if not value:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
parts = value.strip().split(None, 1)
|
|
|
|
|
if len(parts) == 2 and parts[0].lower() in {"bearer", "token"}:
|
|
|
|
|
return parts[1].strip() or None
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_return_path(value: str) -> str:
|
|
|
|
|
if not value.startswith("/") or value.startswith("//"):
|
|
|
|
|
return "/"
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
2026-04-12 22:02:47 -04:00
|
|
|
def _session_from_request(request: Request, settings: Settings) -> SessionRecord | None:
    """Decrypt and validate the session cookie on *request*.

    Returns ``None`` in any of these cases:

    - the cookie is absent or ``settings.auth_secret_key`` is unset;
    - the Fernet token is invalid or older than ``SESSION_MAX_AGE_SECONDS``;
    - the plaintext is not valid UTF-8 JSON, or has the wrong shape;
    - the embedded ``expires_at`` has passed.
    """
    cookie_value = request.cookies.get(SESSION_COOKIE_NAME)
    if not (cookie_value and settings.auth_secret_key):
        return None

    try:
        cipher = _session_cipher(settings)
        plaintext = cipher.decrypt(cookie_value.encode("utf-8"), ttl=SESSION_MAX_AGE_SECONDS)
        decoded = json.loads(plaintext.decode("utf-8"))
    # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
    except (InvalidToken, ValueError):
        return None

    session = _session_from_payload(decoded)
    # Check the payload's own expiry in addition to Fernet's ttl.
    if session is None or session.expires_at <= time.time():
        return None
    return session
|
2026-04-12 22:02:47 -04:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _encrypt_session(session: SessionRecord, settings: Settings) -> str:
    """Serialize *session* to compact JSON and encrypt it into a Fernet token string.

    Raises:
        ValueError: if ``settings.auth_secret_key`` is unset (raised by
            ``_session_cipher``).
    """
    # Build the cipher first so that a missing key is reported before any
    # serialization error.
    cipher = _session_cipher(settings)
    fields = {
        "forgejo_token": session.forgejo_token,
        "user": session.user,
        "created_at": session.created_at,
        "expires_at": session.expires_at,
    }
    serialized = json.dumps(fields, separators=(",", ":"))
    token = cipher.encrypt(serialized.encode("utf-8"))
    return token.decode("utf-8")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _session_from_payload(payload: object) -> SessionRecord | None:
|
|
|
|
|
if not isinstance(payload, dict):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
forgejo_token = payload.get("forgejo_token")
|
|
|
|
|
user = payload.get("user")
|
|
|
|
|
created_at = payload.get("created_at")
|
|
|
|
|
expires_at = payload.get("expires_at")
|
|
|
|
|
if forgejo_token is not None and not isinstance(forgejo_token, str):
|
|
|
|
|
return None
|
|
|
|
|
if not isinstance(user, dict):
|
|
|
|
|
return None
|
|
|
|
|
if not isinstance(created_at, (int, float)) or not isinstance(expires_at, (int, float)):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
return SessionRecord(
|
|
|
|
|
forgejo_token=forgejo_token,
|
|
|
|
|
user=user,
|
|
|
|
|
created_at=float(created_at),
|
|
|
|
|
expires_at=float(expires_at),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _session_cipher(settings: Settings) -> Fernet:
|
|
|
|
|
if not settings.auth_secret_key:
|
|
|
|
|
raise ValueError("AUTH_SECRET_KEY is required for encrypted login sessions.")
|
|
|
|
|
|
|
|
|
|
key = urlsafe_b64encode(sha256(settings.auth_secret_key.encode("utf-8")).digest())
|
|
|
|
|
return Fernet(key)
|