142 lines
3.6 KiB
Python
142 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
import secrets
|
|
import time
|
|
from base64 import urlsafe_b64encode
|
|
from dataclasses import dataclass
|
|
from hashlib import sha256
|
|
from typing import Any
|
|
|
|
from fastapi import Request, Response
|
|
|
|
from settings import Settings
|
|
|
|
# Cookie that carries the opaque server-side session id.
SESSION_COOKIE_NAME = "robot_u_session"
# Lifetime of a login session, measured from its creation time: 14 days.
SESSION_MAX_AGE_SECONDS = 14 * 24 * 60 * 60
# How long a pending OAuth ``state`` value stays redeemable: 10 minutes.
OAUTH_STATE_MAX_AGE_SECONDS = 10 * 60
|
|
|
|
|
|
@dataclass()
class SessionRecord:
    """Server-side state for one logged-in browser session."""

    # Forgejo API token obtained at login; may be None.
    forgejo_token: str | None
    # User profile mapping captured at login (schema defined by the caller).
    user: dict[str, Any]
    # time.time() timestamp of session creation, used for expiry checks.
    created_at: float
|
|
|
|
|
|
@dataclass()
class OAuthStateRecord:
    """Bookkeeping for one in-flight OAuth authorization attempt."""

    # Callback URI recorded when the attempt was started.
    redirect_uri: str
    # Local path to send the user to after a successful login.
    return_to: str
    # PKCE code verifier whose challenge was handed out with the state.
    code_verifier: str
    # time.time() timestamp of when the state was issued.
    created_at: float
|
|
|
|
|
|
# In-memory stores: session id -> session record, and OAuth state -> pending
# attempt. They live only in this process's memory.
_SESSIONS: dict[str, SessionRecord] = dict()
_OAUTH_STATES: dict[str, OAuthStateRecord] = dict()
|
|
|
|
|
|
def resolve_forgejo_token(request: Request, settings: Settings) -> tuple[str | None, str]:
    """Choose the Forgejo token for this request and report where it came from.

    Sources are tried in priority order; the first truthy token wins:

    1. an ``Authorization: Bearer|token <value>`` header -> ``"authorization"``
    2. the token stored on the caller's unexpired login session -> ``"session"``
    3. ``settings.forgejo_token`` -> ``"server"``

    Returns:
        ``(token, source)``, or ``(None, "none")`` when no source yields a token.
    """
    explicit = _authorization_token(request.headers.get("authorization"))
    if explicit:
        return explicit, "authorization"

    # The session store is only consulted when no header token was supplied.
    stored = _session_from_request(request)
    stored_token = stored.forgejo_token if stored else None
    if stored_token:
        return stored_token, "session"

    fallback = settings.forgejo_token
    return (fallback, "server") if fallback else (None, "none")
|
|
|
|
|
|
def current_session_user(request: Request) -> dict[str, Any] | None:
    """Return the user mapping stored on the request's live session, else ``None``."""
    record = _session_from_request(request)
    if record is None:
        return None
    return record.user
|
|
|
|
|
|
def create_login_session(
    response: Response,
    forgejo_token: str | None,
    user: dict[str, Any],
) -> None:
    """Start a new server-side login session and attach its cookie to ``response``.

    Also evicts every stored session older than ``SESSION_MAX_AGE_SECONDS``.

    Args:
        response: response that receives the ``SESSION_COOKIE_NAME`` cookie
            (HttpOnly, SameSite=Lax, path ``/``, max-age of the session lifetime).
        forgejo_token: Forgejo API token to remember for this session, if any.
        user: user mapping returned later by ``current_session_user``.
    """
    now = time.time()

    # Sessions are otherwise only evicted when their own cookie is presented
    # again, so abandoned ones would accumulate forever. Sweep expired entries
    # here. Iterate a snapshot so concurrent inserts can't break the loop.
    for stale_id, record in list(_SESSIONS.items()):
        if now - record.created_at > SESSION_MAX_AGE_SECONDS:
            _SESSIONS.pop(stale_id, None)

    session_id = secrets.token_urlsafe(32)
    _SESSIONS[session_id] = SessionRecord(
        forgejo_token=forgejo_token,
        user=user,
        created_at=now,
    )
    # NOTE(review): cookie is not marked Secure; consider secure=True when the
    # app is only served over HTTPS.
    response.set_cookie(
        SESSION_COOKIE_NAME,
        session_id,
        httponly=True,
        samesite="lax",
        max_age=SESSION_MAX_AGE_SECONDS,
        path="/",
    )
|
|
|
|
|
|
def clear_login_session(request: Request, response: Response) -> None:
    """Log the caller out.

    Forgets the server-side session named by the request cookie (if any), then
    tells the browser to delete the session cookie.
    """
    sid = request.cookies.get(SESSION_COOKIE_NAME) or None
    if sid is not None:
        _SESSIONS.pop(sid, None)
    response.delete_cookie(SESSION_COOKIE_NAME, path="/")
|
|
|
|
|
|
def create_oauth_state(redirect_uri: str, return_to: str) -> tuple[str, str]:
    """Register a new OAuth authorization attempt and return its PKCE parameters.

    Also evicts every pending state older than ``OAUTH_STATE_MAX_AGE_SECONDS``.

    Args:
        redirect_uri: callback URI recorded for this attempt (stored verbatim).
        return_to: post-login destination; passed through ``_safe_return_path``
            so only local paths are kept.

    Returns:
        ``(state, code_challenge)``: the opaque one-time ``state`` value and the
        S256 PKCE challenge derived from a freshly generated code verifier.
    """
    now = time.time()

    # States are normally removed by consume_oauth_state, but abandoned login
    # attempts never come back, so without this sweep the in-memory store grows
    # without bound. Iterate a snapshot so concurrent inserts are harmless.
    for stale_state, record in list(_OAUTH_STATES.items()):
        if now - record.created_at > OAUTH_STATE_MAX_AGE_SECONDS:
            _OAUTH_STATES.pop(stale_state, None)

    state = secrets.token_urlsafe(32)
    code_verifier = secrets.token_urlsafe(64)
    _OAUTH_STATES[state] = OAuthStateRecord(
        redirect_uri=redirect_uri,
        return_to=_safe_return_path(return_to),
        code_verifier=code_verifier,
        created_at=now,
    )
    return state, code_challenge(code_verifier)
|
|
|
|
|
|
def consume_oauth_state(state: str) -> OAuthStateRecord | None:
    """Redeem a one-time OAuth ``state`` value.

    The stored record is removed on every lookup, so each state can be redeemed
    at most once, even if it turns out to be expired.

    Returns:
        The pending record, or ``None`` when the state is unknown or older than
        ``OAUTH_STATE_MAX_AGE_SECONDS``.
    """
    pending = _OAUTH_STATES.pop(state, None)
    if pending is None:
        return None
    age = time.time() - pending.created_at
    return None if age > OAUTH_STATE_MAX_AGE_SECONDS else pending
|
|
|
|
|
|
def code_challenge(code_verifier: str) -> str:
|
|
digest = sha256(code_verifier.encode("ascii")).digest()
|
|
return urlsafe_b64encode(digest).decode("ascii").rstrip("=")
|
|
|
|
|
|
def _authorization_token(value: str | None) -> str | None:
    """Extract the credential from an ``Authorization`` header value.

    Only the ``Bearer`` and ``token`` schemes are accepted (case-insensitively).
    A missing or empty header, an unknown scheme, or a scheme with no credential
    yields ``None``.
    """
    pieces = (value or "").strip().split(None, 1)
    if len(pieces) != 2:
        return None
    scheme, credential = pieces
    if scheme.lower() not in ("bearer", "token"):
        return None
    return credential.strip() or None
|
|
|
|
|
|
def _safe_return_path(value: str) -> str:
    """Return ``value`` if it is a safe same-origin path, otherwise ``"/"``.

    Only absolute local paths (starting with a single ``/``) are accepted. This
    rejects:

    - empty values and absolute URLs;
    - scheme-relative URLs (``//host``);
    - anything containing a backslash or an ASCII control character.
    """
    if not value or not value.startswith("/"):
        return "/"
    # Browsers treat "\" like "/" and silently drop tab/CR/LF inside URLs, so
    # "/\evil.example" or "/\t/evil.example" would be followed as the
    # scheme-relative "//evil.example". Refuse such characters outright rather
    # than trying to normalize them.
    if "\\" in value or any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in value):
        return "/"
    if value.startswith("//"):
        return "/"
    return value
|
|
|
|
|
|
def _session_from_request(request: Request) -> SessionRecord | None:
    """Return the live session named by the request's session cookie, if any.

    A session older than ``SESSION_MAX_AGE_SECONDS`` (measured from
    ``created_at``) is evicted from the store and reported as missing.
    """
    sid = request.cookies.get(SESSION_COOKIE_NAME)
    record = _SESSIONS.get(sid) if sid else None
    if record is None:
        return None
    if time.time() - record.created_at <= SESSION_MAX_AGE_SECONDS:
        return record
    # Expired: drop it so the store doesn't keep dead sessions around.
    _SESSIONS.pop(sid, None)
    return None
|