# Source code for satorbis_kit.clients.airflow

"""Airflow API client for workflow orchestration."""

import json
from typing import Any, Dict, Optional
from urllib.parse import urljoin

import requests


class AirflowAPIError(requests.exceptions.RequestException):
    """Raised when a call to the Airflow API fails.

    The class derives from ``requests.exceptions.RequestException``, so a
    caller may handle it on its own or together with any other ``requests``
    error through the shared base class.
    """
class AirflowClient:
    """Client for communicating with the Airflow API.

    Handles the HTTP communication needed to submit DAG runs and query their
    status. All requests go through a single ``requests.Session`` configured
    with HTTP basic authentication.

    Attributes:
        base_url: Base URL for the Airflow API, always ending with ``/``.
        username: Username for basic authentication.
        password: Password for basic authentication.
        session: Requests session with authentication configured.
    """

    TIMEOUT = 30  # Default request timeout in seconds
    VERIFY_SSL = True  # Default TLS certificate verification setting

    def __init__(
        self,
        base_url: str,
        username: str,
        password: str,
    ):
        """Initialize the Airflow client.

        Args:
            base_url: Base URL for the Airflow API (required).
            username: Username for basic auth (required).
            password: Password for basic auth (required).
        """
        # urljoin() replaces the last path segment of a base that lacks a
        # trailing slash, so normalise the base URL to always end with "/".
        self.base_url = base_url if base_url.endswith("/") else base_url + "/"
        self.username = username
        self.password = password

        # Session carrying the basic-auth credentials for every request.
        self.session = requests.Session()
        self.session.auth = (self.username, self.password)

    def _get_headers(self) -> Dict[str, str]:
        """Return the HTTP headers sent with every request.

        Returns:
            Dictionary of HTTP headers declaring JSON request/response bodies.
        """
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    def trigger_dag_run(
        self,
        dag_id: str,
        config: Dict[str, Any],
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Trigger a DAG run with configuration.

        Args:
            dag_id: DAG identifier.
            config: Request payload, sent as the JSON body unchanged (the
                caller is expected to have formatted it already, including
                any ``conf`` wrapper).
            timeout: Optional timeout override in seconds; a falsy value
                (``None`` or ``0``) selects the class default ``TIMEOUT``.

        Returns:
            Parsed JSON body of the API response.

        Raises:
            AirflowAPIError: If the request cannot be sent, the API answers
                with an HTTP status >= 400, or the response body is not
                valid JSON. For the latter two cases the exception carries
                the ``response`` object.
        """
        url = urljoin(self.base_url, f"api/v1/dags/{dag_id}/dagRuns")
        effective_timeout = timeout or self.TIMEOUT

        # Only the network call lives inside this try block. AirflowAPIError
        # is itself a RequestException, so raising it inside the block would
        # get it caught and re-wrapped by the generic handler below (losing
        # the attached response and duplicating the message prefix).
        try:
            response = self.session.post(
                url,
                json=config,
                headers=self._get_headers(),
                timeout=effective_timeout,
                verify=self.VERIFY_SSL,
            )
        except requests.exceptions.Timeout as e:
            raise AirflowAPIError(
                f"Request to Airflow API timed out after {effective_timeout} seconds"
            ) from e
        except requests.exceptions.ConnectionError as e:
            raise AirflowAPIError(f"Failed to connect to Airflow API: {e}") from e
        except requests.exceptions.RequestException as e:
            raise AirflowAPIError(f"Airflow API request failed: {e}") from e

        # Check for HTTP errors
        if response.status_code >= 400:
            error_msg = f"Airflow API request failed with status {response.status_code}"
            try:
                error_details = response.json()
            except ValueError:
                # Body is not JSON: fall back to the raw text.
                error_msg += f": {response.text}"
            else:
                # Guard against JSON bodies that are not objects (e.g. null).
                if isinstance(error_details, dict) and "detail" in error_details:
                    error_msg += f": {error_details['detail']}"
            raise AirflowAPIError(error_msg, response=response)

        # Parse response. Every JSON decode error raised by response.json()
        # is a ValueError subclass, regardless of the JSON backend in use.
        try:
            return response.json()
        except ValueError as e:
            raise AirflowAPIError(
                f"Failed to parse API response as JSON: {e}", response=response
            ) from e

    def get_dag_run_status(self, dag_id: str, dag_run_id: str) -> Dict[str, Any]:
        """Get the status of a DAG run.

        Args:
            dag_id: DAG identifier.
            dag_run_id: DAG run ID to query.

        Returns:
            Parsed JSON body of the API response.

        Raises:
            AirflowAPIError: If the request cannot be sent, the API answers
                with an HTTP status >= 400, or the response body is not
                valid JSON. For the latter two cases the exception carries
                the ``response`` object.
        """
        url = urljoin(self.base_url, f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}")

        # Keep the try block limited to the network call so the
        # AirflowAPIError raised below is not caught and wrapped a second time.
        try:
            response = self.session.get(
                url,
                headers=self._get_headers(),
                timeout=self.TIMEOUT,
                verify=self.VERIFY_SSL,
            )
        except requests.exceptions.RequestException as e:
            raise AirflowAPIError(f"Failed to get DAG run status: {e}") from e

        if response.status_code >= 400:
            raise AirflowAPIError(
                f"Failed to get DAG run status: HTTP {response.status_code}",
                response=response,
            )

        try:
            return response.json()
        except ValueError as e:
            raise AirflowAPIError(
                f"Failed to get DAG run status: {e}", response=response
            ) from e