Source code for satorbis_kit.auth.auth_oidc

"""
``OIDCAuth`` — the single entry-point class for the notebook.

Responsibilities
~~~~~~~~~~~~~~~~
* Build the PKCE authorization URL
* Launch the local callback server
* Wait for the callback and exchange the code for tokens
* Enforce that the authenticated identity matches the expected email
* Return a valid access token, refreshing automatically if needed
* Refresh tokens explicitly on demand
"""

import time
from typing import Optional, Union

import requests
from authlib.common.security import generate_token
from authlib.integrations.requests_client import OAuth2Session

from .auth_base import BaseAuth
from .callback_server import get_callback_url, launch_callback_listener
from .environments import Environment, get_env_config
from .token_store import TokenStore

class IdentityMismatchError(Exception):
    """
    Signals that the account which completed the browser login is not the
    account requested through the ``login_hint`` argument of
    :meth:`OIDCAuth.start_login`.

    Exists so that a user who erases the pre-filled e-mail address on the
    provider's login page, and then signs in as somebody else, is rejected
    instead of silently receiving tokens for the wrong identity.
    """
class OIDCAuth(BaseAuth):
    """
    Manages the full OIDC Authorization Code + PKCE flow against an OIDC
    provider, designed for use inside a Jupyter notebook.

    Parameters
    ----------
    environment:
        Target deployment environment. Accepts an :class:`Environment`
        member, a plain string (``"prod"`` or ``"dev"``), or ``None``.
        When ``None`` the value of the ``SATORBIS_ENV`` environment variable
        is used; if that is also absent, ``prod`` is the default.
        Example: ``OIDCAuth(environment="dev")``.
    application:
        Application name used to look up a Zitadel project ID from the
        environment's ``project_ids`` mapping. When a project ID is found,
        ``urn:zitadel:iam:org:project:id:{project_id}:aud`` is appended to
        the requested scopes so the token carries the project audience claim.
        Example: ``OIDCAuth(application="spatial_engine")``.
    redirect_port:
        Local port for the OAuth callback server. Must match the redirect URI
        registered with the provider. Default: ``4200``.
    scopes:
        Space-separated OIDC scopes. Include ``offline_access`` to receive a
        refresh token.
    callback_timeout:
        Seconds to wait for the browser callback before raising
        ``TimeoutError``. Default: 300 (5 minutes).
    authorization_url:
        Optional. The authorization endpoint URL. If not provided, it is
        discovered from the OIDC discovery document. An explicit value is
        always kept, even when the other endpoint has to be discovered.
    token_url:
        Optional. The token endpoint URL. If not provided, it is discovered
        from the OIDC discovery document. An explicit value is always kept,
        even when the other endpoint has to be discovered.
    verify_jwt:
        If ``True``, JWT signatures are verified against the provider's JWKS
        whenever token claims are read (ID-token nonce check and email
        enforcement). Default: ``True``.
    """

    # ------------------------------------------------------------------ #
    #  Construction                                                       #
    # ------------------------------------------------------------------ #

    def __init__(
        self,
        environment: Union[Environment, str, None] = None,
        application: Optional[str] = None,
        redirect_port: int = 4200,
        scopes: str = "openid profile email offline_access",
        callback_timeout: int = 300,
        authorization_url: Optional[str] = None,
        token_url: Optional[str] = None,
        verify_jwt: bool = True,
    ) -> None:
        super().__init__()

        if isinstance(environment, str):
            environment = Environment(environment)
        resolved_env, config = get_env_config(environment)

        self.environment = resolved_env
        self.domain = config.domain.rstrip("/")
        self.client_id = config.oidc_client_id

        # Append project audience scope if the application has a project_id mapping
        project_id = config.project_ids.get(application) if application else None
        if project_id:
            scopes = f"{scopes} urn:zitadel:iam:org:project:id:{project_id}:aud"

        self.redirect_port = redirect_port
        self.scopes = scopes
        self.callback_timeout = callback_timeout
        self.verify_jwt = verify_jwt

        self.redirect_uri = f"http://localhost:{redirect_port}/logincallback"
        self.discovery_url = f"{self.domain}/.well-known/openid-configuration"

        # Cached JWKS for JWT verification (lazy-loaded). These MUST be
        # initialised before discovery runs: _discover_endpoints() records
        # jwks_uri, and resetting it afterwards would throw that value away.
        self._jwks = None
        self._jwks_uri: Optional[str] = None

        # Endpoints — use explicit values, discovering only what is missing
        if authorization_url and token_url:
            self.authorization_url = authorization_url
            self.token_url = token_url
        else:
            self._discover_endpoints()
            # Honour whichever endpoint the caller did supply explicitly.
            if authorization_url:
                self.authorization_url = authorization_url
            if token_url:
                self.token_url = token_url

        # Internal PKCE / session state — set during start_login()
        self._code_verifier: Optional[str] = None
        self._state: Optional[str] = None
        self._nonce: Optional[str] = None  # checked against the id_token
        self._oauth_session: Optional[OAuth2Session] = None
        self._expected_email: Optional[str] = None  # enforced in await_callback()

    # ------------------------------------------------------------------ #
    #  OIDC Discovery                                                     #
    # ------------------------------------------------------------------ #

    def _fetch_discovery_document(self) -> dict:
        """
        Download and return the provider's OIDC discovery document.

        Raises
        ------
        requests.HTTPError
            If the discovery endpoint returns an error status.
        """
        resp = requests.get(self.discovery_url, timeout=10)
        resp.raise_for_status()
        return resp.json()

    def _discover_endpoints(self) -> None:
        """
        Fetch the OIDC discovery document and extract endpoint URLs.

        Uses the ``/.well-known/openid-configuration`` endpoint to set
        ``authorization_url``, ``token_url`` and the (optional) JWKS URI.
        """
        config = self._fetch_discovery_document()
        self.authorization_url = config["authorization_endpoint"]
        self.token_url = config["token_endpoint"]
        self._jwks_uri = config.get("jwks_uri")

    # ------------------------------------------------------------------ #
    #  Step 1 — Generate login URL                                        #
    # ------------------------------------------------------------------ #

    def start_login(self, login_hint: str = "") -> str:
        """
        Generate the PKCE authorization URL and start the local callback
        server in a background thread.

        The ``login_hint`` email is remembered and **enforced** in
        :meth:`await_callback` — if the user authenticates as a different
        identity, :class:`IdentityMismatchError` is raised and no tokens are
        stored. A fresh ``nonce`` is also generated and remembered so the
        returned ID token can be bound to this login attempt.

        Parameters
        ----------
        login_hint:
            Email address to pre-fill on the provider login page. When
            provided, the authenticated token must match this email.
            Surrounding whitespace is ignored; an empty hint is not sent.

        Returns
        -------
        str
            The authorization URL. The user must open it in a browser.
        """
        hint = login_hint.strip() if login_hint else ""
        # Store for enforcement after callback (None = no enforcement)
        self._expected_email = hint.lower() or None

        # Fresh PKCE code_verifier and nonce for every login attempt
        self._code_verifier = generate_token(64)
        self._nonce = generate_token(16)

        self._oauth_session = OAuth2Session(
            client_id=self.client_id,
            redirect_uri=self.redirect_uri,
            scope=self.scopes,
            code_challenge_method="S256",
        )

        extra_params = {"nonce": self._nonce}
        if hint:
            # Only send login_hint when there is one; an empty value would
            # otherwise be added to the URL as ``login_hint=``.
            extra_params["login_hint"] = hint

        login_url, self._state = self._oauth_session.create_authorization_url(
            self.authorization_url,
            code_verifier=self._code_verifier,
            **extra_params,
        )

        # Start the local HTTP listener in the background
        launch_callback_listener(
            port=self.redirect_port,
            timeout_seconds=self.callback_timeout,
        )

        self._print_login_banner(login_url, hint)
        return login_url

    # ------------------------------------------------------------------ #
    #  Step 2 — Await callback + exchange code                            #
    # ------------------------------------------------------------------ #

    def await_callback(self) -> TokenStore:
        """
        Block until the browser callback is received, exchange the
        authorization code for tokens, then validate them before storing.

        Validation (performed **before** anything is written to
        :attr:`tokens`):

        1. If an ``id_token`` is returned, its ``nonce`` claim must equal the
           nonce sent by :meth:`start_login`.
        2. If a ``login_hint`` was given, the authenticated email must match it.

        Raises
        ------
        RuntimeError
            If :meth:`start_login` has not been called first.
        TimeoutError
            If no callback is received within ``callback_timeout`` seconds.
        IdentityMismatchError
            If the ID token's nonce does not match, or the authenticated
            user's email does not match the expected ``login_hint``.
            Tokens are **not** stored in either case.
        """
        if self._oauth_session is None or self._code_verifier is None:
            raise RuntimeError(
                "Call start_login() first to generate the authorization URL."
            )

        print("⏳ Waiting for browser callback …")
        callback_url = get_callback_url(timeout_seconds=self.callback_timeout)
        print("✅ Callback received — exchanging authorization code …")

        # Restore state so authlib can validate the `state` parameter
        self._oauth_session.state = self._state

        raw_token = self._oauth_session.fetch_token(
            self.token_url,
            authorization_response=callback_url,
            code_verifier=self._code_verifier,
            grant_type="authorization_code",
        )

        # ── Token validation ───────────────────────────────────────────
        # Validate BEFORE storing any tokens so a rejected login leaves
        # the TokenStore empty and cannot be used for API calls.
        self._validate_nonce(raw_token)
        if self._expected_email:
            authenticated_email: Optional[str] = self._enforce_email_identity(
                raw_token, self._expected_email
            )
        else:
            # Display only — resolved up front so a claim-decoding problem
            # cannot surface as an exception after tokens are stored.
            authenticated_email = self._email_for_display(raw_token)

        # All checks passed — persist tokens
        self.tokens.update_from_raw(raw_token)

        exp_fmt = time.strftime(
            "%Y-%m-%d %H:%M:%S", time.localtime(self.tokens.expires_at)
        )
        print()
        print("🎉 Tokens stored successfully!")
        print(f" Authenticated as : {authenticated_email or 'unknown'}")
        print(f" Access token : {self.tokens.access_token[:40]}…")
        print(
            f" Refresh token : "
            f"{'✅ present' if self.tokens.refresh_token else '❌ not returned (check offline_access scope)'}"
        )
        print(f" Expires at : {exp_fmt}")

        return self.tokens

    # ------------------------------------------------------------------ #
    #  authenticate() — required by BaseAuth                              #
    # ------------------------------------------------------------------ #

    def authenticate(self) -> TokenStore:
        """
        For interactive OIDC flows, use :meth:`start_login` followed by
        :meth:`await_callback`. This method is not directly applicable to
        interactive browser-based flows.

        Raises
        ------
        NotImplementedError
            Always. Use ``start_login()`` + ``await_callback()`` instead.
        """
        raise NotImplementedError(
            "Use start_login() + await_callback() for interactive OIDC flows."
        )

    # ------------------------------------------------------------------ #
    #  Token refresh                                                      #
    # ------------------------------------------------------------------ #

    def refresh_token(self) -> TokenStore:
        """
        Exchange the refresh token for a new access token (and possibly a
        new refresh token). Updates :attr:`tokens` in-place.

        Providers with **refresh token rotation** return a new refresh token
        with each call, which is saved automatically.

        Returns
        -------
        TokenStore
            The updated token store.

        Raises
        ------
        ValueError
            If no refresh token is available (e.g. ``offline_access`` scope
            was not requested).
        """
        if not self.tokens.refresh_token:
            raise ValueError(
                "No refresh token available.\n"
                "Ensure 'offline_access' is included in scopes and re-run the login flow."
            )

        print("🔄 Refreshing access token …")

        refresh_session = OAuth2Session(
            client_id=self.client_id,
            token={
                "access_token": self.tokens.access_token,
                "refresh_token": self.tokens.refresh_token,
                "token_type": self.tokens.token_type,
                "expires_at": self.tokens.expires_at,
            },
        )

        new_token = refresh_session.refresh_token(
            self.token_url,
            refresh_token=self.tokens.refresh_token,
            # Public (PKCE) clients must include client_id in the refresh request
            client_id=self.client_id,
        )

        self.tokens.update_from_raw(new_token)

        exp_fmt = time.strftime(
            "%Y-%m-%d %H:%M:%S", time.localtime(self.tokens.expires_at)
        )
        print(f"✅ Token refreshed! New expiry: {exp_fmt}")
        print(f" New access token : {self.tokens.access_token[:40]}…")

        return self.tokens

    # ------------------------------------------------------------------ #
    #  Private — token validation                                         #
    # ------------------------------------------------------------------ #

    def _validate_nonce(self, raw_token: dict) -> None:
        """
        Check that the returned ``id_token`` carries the nonce sent by
        :meth:`start_login`.

        OIDC Core requires this check whenever a nonce was included in the
        authentication request. Skipped when no ``id_token`` is returned or
        no nonce was generated.

        Raises
        ------
        IdentityMismatchError
            If the ``nonce`` claim is missing or differs from the one sent.
        """
        id_token = raw_token.get("id_token")
        if not id_token or self._nonce is None:
            return
        claims = self._decode_claims(id_token)
        if claims.get("nonce") != self._nonce:
            raise IdentityMismatchError(
                "The id_token nonce does not match this login attempt — "
                "login rejected.\n"
                "Please re-run start_login() and complete the login again."
            )

    def _enforce_email_identity(self, raw_token: dict, expected_email: str) -> str:
        """
        Verify that the identity returned in the token matches
        ``expected_email``.

        Checks (in order of preference):

        1. ``id_token`` JWT payload.
        2. ``access_token`` JWT payload — fallback if the id_token is absent
           or has no email claim. Opaque (non-JWT) access tokens are skipped.

        If ``verify_jwt`` is ``True``, JWT signature verification is performed
        using the provider's JWKS.

        Returns
        -------
        str
            The authenticated email (as found in the token).

        Raises
        ------
        IdentityMismatchError
            If the authenticated email does not match, or if no email claim
            can be found in either token.
        """
        actual_email = self._extract_email_from_raw(raw_token)

        if actual_email is None:
            raise IdentityMismatchError(
                "Could not determine the authenticated user's email.\n"
                "Ensure the 'email' scope is included and that the provider returns "
                "an 'email' claim in the id_token or access_token."
            )

        if actual_email.lower() != expected_email.lower():
            raise IdentityMismatchError(
                f"\n{'═' * 60}\n"
                f" 🚫 Identity mismatch — login rejected\n"
                f"{'═' * 60}\n"
                f" Expected : {expected_email}\n"
                f" Got : {actual_email}\n"
                f"{'═' * 60}\n"
                "The authenticated account does not match the required email.\n"
                "Please re-run start_login() and log in with the correct account."
            )

        return actual_email

    def _email_for_display(self, raw_token: dict) -> Optional[str]:
        """
        Best-effort email lookup used only for the success banner when no
        ``login_hint`` is being enforced. Never raises.
        """
        try:
            return self._extract_email_from_raw(raw_token)
        except Exception:  # display only: must not fail an unenforced login
            return None

    def _extract_email_from_raw(self, raw_token: dict) -> Optional[str]:
        """
        Extract the email claim from a raw token response dict.

        Checks the ``id_token`` first, then falls back to the
        ``access_token``. Tokens that are not compact JWTs (i.e. do not have
        exactly three dot-separated segments, such as opaque access tokens)
        carry no readable claims and are skipped. Returns ``None`` if no
        email claim is found.

        If ``verify_jwt`` is ``True``, uses signature-verified JWT decoding.
        Otherwise, uses unverified base64 decoding.
        """
        for key in ("id_token", "access_token"):
            token_str = raw_token.get(key)
            if not token_str or token_str.count(".") != 2:
                continue
            email = self._decode_claims(token_str).get("email")
            if email:
                return email.strip()
        return None

    def _decode_claims(self, token_str: str) -> dict:
        """
        Decode a JWT's claims, verifying the signature when ``verify_jwt``
        is enabled and falling back to unverified decoding otherwise.
        """
        if self.verify_jwt:
            return self._verify_and_decode_jwt(token_str)
        return TokenStore.decode_jwt_payload_str(token_str)

    def _verify_and_decode_jwt(self, token_str: str) -> dict:
        """
        Decode a JWT with signature verification using the provider's JWKS.

        Lazily fetches and caches the JWKS from the provider's ``jwks_uri``.
        If the URI is not yet known, only ``jwks_uri`` is read from the
        discovery document; the configured authorization/token endpoints are
        left untouched.

        Parameters
        ----------
        token_str:
            The raw JWT string to verify and decode.

        Returns
        -------
        dict
            The decoded and verified JWT claims.

        Raises
        ------
        RuntimeError
            If the provider does not advertise a JWKS URI.
        """
        from authlib.jose import JsonWebKey
        from authlib.jose import jwt as jose_jwt

        if self._jwks is None:
            if not self._jwks_uri:
                # Read only jwks_uri — calling _discover_endpoints() here
                # would overwrite explicitly configured endpoint URLs.
                self._jwks_uri = self._fetch_discovery_document().get("jwks_uri")
            if not self._jwks_uri:
                raise RuntimeError(
                    "Cannot verify JWT: JWKS URI not available from the provider."
                )
            jwks_resp = requests.get(self._jwks_uri, timeout=10)
            jwks_resp.raise_for_status()
            # NOTE(review): the key set is cached for the object's lifetime;
            # a provider key rotation would require a new OIDCAuth instance.
            self._jwks = JsonWebKey.import_key_set(jwks_resp.json())

        claims = jose_jwt.decode(token_str, self._jwks)
        claims.validate()
        return dict(claims)

    # ------------------------------------------------------------------ #
    #  Internal helpers                                                   #
    # ------------------------------------------------------------------ #

    @staticmethod
    def _print_login_banner(url: str, expected_email: str) -> None:
        """Print the login URL (and required account, if any) for the user."""
        width = 65
        print("═" * width)
        print(" 🔑 Click the link below to log in via your browser")
        if expected_email:
            print(f" 👤 You must log in as: {expected_email}")
        print("═" * width)
        print()
        print(url)
        print()
        print("After login completes, call auth.await_callback() in the next cell.")

    def __repr__(self) -> str:
        return (
            f"OIDCAuth(domain={self.domain!r}, "
            f"client_id={self.client_id!r}, "
            f"redirect_port={self.redirect_port})\n"
            f" {self.tokens}"
        )