"""
``OIDCAuth`` — the single entry-point class for the notebook.
Responsibilities
~~~~~~~~~~~~~~~~
* Build the PKCE authorization URL
* Launch the local callback server
* Wait for the callback and exchange the code for tokens
* Enforce that the authenticated identity matches the expected email
* Return a valid access token, refreshing automatically if needed
* Refresh tokens explicitly on demand
"""
import time
from typing import Optional, Union
import requests
from authlib.common.security import generate_token
from authlib.integrations.requests_client import OAuth2Session
from .auth_base import BaseAuth
from .callback_server import get_callback_url, launch_callback_listener
from .environments import Environment, get_env_config
from .token_store import TokenStore
class IdentityMismatchError(Exception):
    """
    Raised when the email in the returned token does not match the
    ``login_hint`` supplied to :meth:`OIDCAuth.start_login`.

    This prevents a user from clearing the pre-filled email in the
    browser and authenticating as a different identity.

    It is also raised when an email is expected but no ``email`` claim
    can be found in the returned ``id_token`` or ``access_token``.
    """
class OIDCAuth(BaseAuth):
    """
    Manages the full OIDC Authorization Code + PKCE flow against an
    OIDC provider, designed for use inside a Jupyter notebook.

    Typical usage::

        auth = OIDCAuth(environment="dev")
        auth.start_login(login_hint="user@example.com")
        # open the printed URL in a browser, then in the next cell:
        auth.await_callback()

    Parameters
    ----------
    environment:
        Target deployment environment. Accepts an :class:`Environment`
        member, a plain string (``"prod"`` or ``"dev"``), or ``None``.
        When ``None`` the value of the ``SATORBIS_ENV`` environment
        variable is used; if that is also absent, ``prod`` is the
        default. Example: ``OIDCAuth(environment="dev")``.
    application:
        Application name used to look up a Zitadel project ID from the
        environment's ``project_ids`` mapping. When a project ID is found,
        ``urn:zitadel:iam:org:project:id:{project_id}:aud`` is appended to
        the requested scopes so the token carries the project audience claim.
        Example: ``OIDCAuth(application="spatial_engine")``.
    redirect_port:
        Local port for the OAuth callback server. Must match the
        redirect URI registered with the provider. Default: ``4200``.
    scopes:
        Space-separated OIDC scopes. Include ``offline_access`` to
        receive a refresh token.
    callback_timeout:
        Seconds to wait for the browser callback before raising
        ``TimeoutError``. Default: 300 (5 minutes).
    authorization_url:
        Optional. The authorization endpoint URL. If not provided,
        it is discovered from the OIDC discovery document.
    token_url:
        Optional. The token endpoint URL. If not provided,
        it is discovered from the OIDC discovery document.
        Explicitly supplied endpoint URLs are never overwritten by
        discovery.
    verify_jwt:
        If ``True``, JWT signatures are verified against the provider's
        JWKS when extracting the authenticated email. The ID token's
        ``aud`` claim must also contain the client ID, and its ``nonce``
        claim must match the value sent by :meth:`start_login`.
        If ``False``, token payloads are decoded without verification.
        Default: ``True``.
    """

    # ------------------------------------------------------------------ #
    # Construction                                                       #
    # ------------------------------------------------------------------ #
    def __init__(
        self,
        environment: Union[Environment, str, None] = None,
        application: Optional[str] = None,
        redirect_port: int = 4200,
        scopes: str = "openid profile email offline_access",
        callback_timeout: int = 300,
        authorization_url: Optional[str] = None,
        token_url: Optional[str] = None,
        verify_jwt: bool = True,
    ) -> None:
        super().__init__()
        if isinstance(environment, str):
            environment = Environment(environment)
        resolved_env, config = get_env_config(environment)
        self.environment = resolved_env
        self.domain = config.domain.rstrip("/")
        self.client_id = config.oidc_client_id

        # Append project audience scope if the application has a project_id mapping
        project_id = config.project_ids.get(application) if application else None
        if project_id:
            scopes = f"{scopes} urn:zitadel:iam:org:project:id:{project_id}:aud"

        self.redirect_port = redirect_port
        self.scopes = scopes
        self.callback_timeout = callback_timeout
        self.verify_jwt = verify_jwt
        self.redirect_uri = f"http://localhost:{redirect_port}/logincallback"
        self.discovery_url = f"{self.domain}/.well-known/openid-configuration"

        # Internal PKCE / session state — populated by start_login().
        # Initialised BEFORE discovery below so discovered values (notably
        # the JWKS URI) are not reset afterwards.
        self._code_verifier: Optional[str] = None
        self._state: Optional[str] = None
        self._nonce: Optional[str] = None  # checked against the id_token
        self._oauth_session: Optional[OAuth2Session] = None
        self._expected_email: Optional[str] = None  # enforced in await_callback()
        # Cached JWKS for JWT verification (lazy-loaded)
        self._jwks = None
        self._jwks_uri: Optional[str] = None

        # Endpoints — explicit values win; only missing ones are discovered.
        if authorization_url and token_url:
            self.authorization_url: str = authorization_url
            self.token_url: str = token_url
        else:
            discovered = self._fetch_discovery_document()
            self.authorization_url = (
                authorization_url or discovered["authorization_endpoint"]
            )
            self.token_url = token_url or discovered["token_endpoint"]
            self._jwks_uri = discovered.get("jwks_uri")

    # ------------------------------------------------------------------ #
    # OIDC Discovery                                                     #
    # ------------------------------------------------------------------ #
    def _fetch_discovery_document(self) -> dict:
        """
        Fetch and return the provider's OIDC discovery document.

        Raises
        ------
        requests.HTTPError
            If the discovery endpoint returns an error status.
        """
        resp = requests.get(self.discovery_url, timeout=10)
        resp.raise_for_status()
        return resp.json()

    def _discover_endpoints(self) -> None:
        """
        Fetch the OIDC discovery document and extract endpoint URLs.

        Uses the ``/.well-known/openid-configuration`` endpoint to set
        ``authorization_url``, ``token_url`` and the JWKS URI. This
        overwrites any previously configured endpoints; the constructor
        and JWT verification only fill in missing values instead.
        """
        discovered = self._fetch_discovery_document()
        self.authorization_url = discovered["authorization_endpoint"]
        self.token_url = discovered["token_endpoint"]
        self._jwks_uri = discovered.get("jwks_uri")

    # ------------------------------------------------------------------ #
    # Step 1 — Generate login URL                                        #
    # ------------------------------------------------------------------ #
    def start_login(self, login_hint: str = "") -> str:
        """
        Generate the PKCE authorization URL and start the local
        callback server in a background thread.

        The ``login_hint`` email is remembered and **enforced** in
        :meth:`await_callback`. If the user authenticates as a
        different identity, :class:`IdentityMismatchError` is raised
        and no tokens are stored.

        Parameters
        ----------
        login_hint:
            Email address to pre-fill on the provider login page.
            When provided (non-blank), the authenticated token must
            match this email. Surrounding whitespace is ignored.

        Returns
        -------
        str
            The authorization URL. The user must open it in a browser.
        """
        hint = login_hint.strip() if login_hint else ""
        # Store for enforcement after callback (None disables enforcement)
        self._expected_email = hint.lower() or None

        # Fresh PKCE verifier and nonce for every login attempt; the nonce
        # is verified against the id_token to bind it to this session.
        self._code_verifier = generate_token(64)
        self._nonce = generate_token(16)

        self._oauth_session = OAuth2Session(
            client_id=self.client_id,
            redirect_uri=self.redirect_uri,
            scope=self.scopes,
            code_challenge_method="S256",
        )

        # Only send login_hint when there is one; an empty value would
        # otherwise appear in the URL as a bare ``login_hint=`` parameter.
        extra_params = {"nonce": self._nonce}
        if hint:
            extra_params["login_hint"] = hint

        login_url, self._state = self._oauth_session.create_authorization_url(
            self.authorization_url,
            code_verifier=self._code_verifier,
            **extra_params,
        )

        # Start the local HTTP listener in the background
        launch_callback_listener(
            port=self.redirect_port,
            timeout_seconds=self.callback_timeout,
        )
        self._print_login_banner(login_url, hint)
        return login_url

    # ------------------------------------------------------------------ #
    # Step 2 — Await callback + exchange code                            #
    # ------------------------------------------------------------------ #
    def await_callback(self) -> TokenStore:
        """
        Block until the browser callback is received and exchange the
        authorization code for tokens. Then **verify** that the
        authenticated email matches the ``login_hint`` from
        :meth:`start_login`.

        Returns
        -------
        TokenStore
            The populated token store (:attr:`tokens`).

        Raises
        ------
        RuntimeError
            If :meth:`start_login` has not been called first.
        TimeoutError
            If no callback is received within ``callback_timeout`` seconds.
        IdentityMismatchError
            If the authenticated user's email does not match the expected
            ``login_hint``. Tokens are **not** stored in this case.
        """
        if self._oauth_session is None or self._code_verifier is None:
            raise RuntimeError(
                "Call start_login() first to generate the authorization URL."
            )

        print("⏳ Waiting for browser callback …")
        callback_url = get_callback_url(timeout_seconds=self.callback_timeout)
        print("✅ Callback received — exchanging authorization code …")

        # Restore state so authlib can validate the `state` parameter
        self._oauth_session.state = self._state
        raw_token = self._oauth_session.fetch_token(
            self.token_url,
            authorization_response=callback_url,
            code_verifier=self._code_verifier,
            grant_type="authorization_code",
        )

        # ── Identity checks ──────────────────────────────────────────
        # Run exactly once and BEFORE storing any tokens. A mismatched
        # login, or a token that fails JWT verification, then leaves the
        # TokenStore untouched.
        if self._expected_email:
            authenticated_email = self._enforce_email_identity(
                raw_token, self._expected_email
            )
        else:
            authenticated_email = self._extract_email_from_raw(raw_token)

        # All checks passed — persist tokens
        self.tokens.update_from_raw(raw_token)

        exp_fmt = self._format_expiry(self.tokens.expires_at)
        print()
        print("🎉 Tokens stored successfully!")
        print(f" Authenticated as : {authenticated_email or 'unknown'}")
        print(f" Access token : {self.tokens.access_token[:40]}…")
        print(
            f" Refresh token : "
            f"{'✅ present' if self.tokens.refresh_token else '❌ not returned (check offline_access scope)'}"
        )
        print(f" Expires at : {exp_fmt}")
        return self.tokens

    # ------------------------------------------------------------------ #
    # authenticate() — required by BaseAuth                              #
    # ------------------------------------------------------------------ #
    def authenticate(self) -> TokenStore:
        """
        Not supported for interactive OIDC flows.

        Use :meth:`start_login` followed by :meth:`await_callback`
        instead. This method is not applicable to interactive
        browser-based flows.

        Raises
        ------
        NotImplementedError
            Always. Use ``start_login()`` + ``await_callback()`` instead.
        """
        raise NotImplementedError(
            "Use start_login() + await_callback() for interactive OIDC flows."
        )

    # ------------------------------------------------------------------ #
    # Token refresh                                                      #
    # ------------------------------------------------------------------ #
    def refresh_token(self) -> TokenStore:
        """
        Exchange the refresh token for a new access token (and possibly
        a new refresh token). Updates :attr:`tokens` in-place.

        With providers that use **refresh token rotation**, the new
        refresh token returned by each call is saved automatically.

        Returns
        -------
        TokenStore
            The updated token store.

        Raises
        ------
        ValueError
            If no refresh token is available (e.g. the ``offline_access``
            scope was not requested).
        """
        if not self.tokens.refresh_token:
            raise ValueError(
                "No refresh token available.\n"
                "Ensure 'offline_access' is included in scopes and re-run the login flow."
            )

        print("🔄 Refreshing access token …")
        refresh_session = OAuth2Session(
            client_id=self.client_id,
            token={
                "access_token": self.tokens.access_token,
                "refresh_token": self.tokens.refresh_token,
                "token_type": self.tokens.token_type,
                "expires_at": self.tokens.expires_at,
            },
        )
        new_token = refresh_session.refresh_token(
            self.token_url,
            refresh_token=self.tokens.refresh_token,
            # Public (PKCE) clients must include client_id in the refresh request
            client_id=self.client_id,
        )
        self.tokens.update_from_raw(new_token)

        exp_fmt = self._format_expiry(self.tokens.expires_at)
        print(f"✅ Token refreshed! New expiry: {exp_fmt}")
        print(f" New access token : {self.tokens.access_token[:40]}…")
        return self.tokens

    # ------------------------------------------------------------------ #
    # Private — email identity enforcement                               #
    # ------------------------------------------------------------------ #
    def _enforce_email_identity(self, raw_token: dict, expected_email: str) -> str:
        """
        Verify that the identity returned in the token matches
        ``expected_email``. Checks, in order of preference:

        1. ``id_token`` JWT payload — most reliable. It is issued by the
           provider and, when ``verify_jwt`` is ``True``, bound to this
           auth session via the ``nonce`` claim.
        2. ``access_token`` JWT payload — fallback if the id_token has
           no email (opaque access tokens are skipped).

        If ``verify_jwt`` is ``True``, JWT signature verification is
        performed using the provider's JWKS.

        Returns
        -------
        str
            The authenticated email, as found in the token.

        Raises
        ------
        IdentityMismatchError
            If the authenticated email does not match, or if no email
            claim can be found in either token.
        """
        actual_email = self._extract_email_from_raw(raw_token)
        if actual_email is None:
            raise IdentityMismatchError(
                "Could not determine the authenticated user's email.\n"
                "Ensure the 'email' scope is included and that the provider returns "
                "an 'email' claim in the id_token or access_token."
            )
        if actual_email.lower() != expected_email.lower():
            raise IdentityMismatchError(
                f"\n{'═' * 60}\n"
                f" 🚫 Identity mismatch — login rejected\n"
                f"{'═' * 60}\n"
                f" Expected : {expected_email}\n"
                f" Got : {actual_email}\n"
                f"{'═' * 60}\n"
                "The authenticated account does not match the required email.\n"
                "Please re-run start_login() and log in with the correct account."
            )
        return actual_email

    def _extract_email_from_raw(self, raw_token: dict) -> Optional[str]:
        """
        Extract the email claim from a raw token response dict.

        Checks the ``id_token`` first (most authoritative), then falls
        back to the ``access_token``. Tokens that are absent or not in
        compact JWT form (``header.payload.signature``, e.g. opaque
        access tokens) are skipped. Returns ``None`` if no email is found.

        If ``verify_jwt`` is ``True``, signature-verified JWT decoding is
        used; otherwise the payload is decoded without verification.
        """
        for key in ("id_token", "access_token"):
            token_str = raw_token.get(key)
            # Only a compact JWS has readable claims; opaque tokens would
            # make the decoder raise instead of falling through.
            if not token_str or token_str.count(".") != 2:
                continue
            if self.verify_jwt:
                claims = self._verify_and_decode_jwt(
                    token_str, is_id_token=(key == "id_token")
                )
            else:
                claims = TokenStore.decode_jwt_payload_str(token_str)
            email = claims.get("email")
            if email:
                return email.strip()
        return None

    def _verify_and_decode_jwt(self, token_str: str, *, is_id_token: bool = False) -> dict:
        """
        Decode a JWT with signature verification using the provider's JWKS.

        Lazily fetches and caches the JWKS from the provider's
        ``jwks_uri``. If discovery has not run yet (explicit endpoints
        were supplied), only the JWKS URI is looked up; the configured
        endpoint URLs are left untouched.

        Parameters
        ----------
        token_str:
            The raw JWT string to verify and decode.
        is_id_token:
            When ``True``, additionally require the ``aud`` claim to
            contain this client's ID. The ``nonce`` claim must also equal
            the nonce sent by :meth:`start_login` (OIDC Core ID Token
            validation).

        Returns
        -------
        dict
            The decoded and verified JWT claims.

        Raises
        ------
        RuntimeError
            If no JWKS URI can be obtained from the provider.
        authlib.jose.errors.JoseError
            If the signature or a validated claim is invalid (including
            a ``nonce`` mismatch, reported as ``InvalidClaimError``).
        """
        from authlib.jose import JsonWebKey
        from authlib.jose import jwt as jose_jwt
        from authlib.jose.errors import InvalidClaimError

        if self._jwks is None:
            if not self._jwks_uri:
                # Fetch only the JWKS URI so explicit endpoints survive.
                self._jwks_uri = self._fetch_discovery_document().get("jwks_uri")
            if not self._jwks_uri:
                raise RuntimeError(
                    "Cannot verify JWT: JWKS URI not available from the provider."
                )
            jwks_resp = requests.get(self._jwks_uri, timeout=10)
            jwks_resp.raise_for_status()
            # NOTE(review): the key set is cached for the object's lifetime;
            # provider key rotation will require a new OIDCAuth instance.
            self._jwks = JsonWebKey.import_key_set(jwks_resp.json())

        claims_options = None
        if is_id_token:
            # The ID token audience must include our client_id.
            claims_options = {"aud": {"essential": True, "value": self.client_id}}

        claims = jose_jwt.decode(token_str, self._jwks, claims_options=claims_options)
        claims.validate()

        # Bind the ID token to the login attempt that generated the nonce,
        # preventing injection/replay of an ID token from another session.
        if is_id_token and self._nonce is not None:
            if claims.get("nonce") != self._nonce:
                raise InvalidClaimError("nonce")

        return dict(claims)

    # ------------------------------------------------------------------ #
    # Internal helpers                                                   #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _format_expiry(expires_at) -> str:
        """Format an epoch-seconds expiry as a local ``YYYY-mm-dd HH:MM:SS`` string."""
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(expires_at))

    @staticmethod
    def _print_login_banner(url: str, expected_email: str) -> None:
        """Print the login URL, and the required identity if any, for the notebook user."""
        width = 65
        print("═" * width)
        print(" 🔑 Click the link below to log in via your browser")
        if expected_email:
            print(f" 👤 You must log in as: {expected_email}")
        print("═" * width)
        print()
        print(url)
        print()
        print("After login completes, call auth.await_callback() in the next cell.")

    def __repr__(self) -> str:
        return (
            f"OIDCAuth(domain={self.domain!r}, "
            f"client_id={self.client_id!r}, "
            f"redirect_port={self.redirect_port})\n"
            f" {self.tokens}"
        )