Source code for satorbis_kit.clients.airflow
"""Airflow API client for workflow orchestration."""
import json
from typing import Any, Dict, Optional
from urllib.parse import urljoin

import requests
[docs]
class AirflowAPIError(requests.exceptions.RequestException):
    """Error raised for Airflow API related failures.

    ``AirflowClient`` raises this for transport problems (timeouts and
    connection errors), for HTTP responses with a status code of 400 or
    above, and for response bodies that cannot be parsed as JSON.

    Subclasses ``requests.exceptions.RequestException`` so callers can either
    catch this specific type or treat it as a generic requests error.
    """
[docs]
class AirflowClient:
    """Client for communicating with Airflow API.

    This class handles all HTTP communication with Airflow API for submitting
    and tracking DAG runs. It can be used across different modules in the package.

    The client owns a ``requests.Session``; call :meth:`close` (or use the
    client as a context manager) to release it when finished.

    Attributes:
        base_url: Base URL for Airflow API (always stored with a trailing "/")
        username: Username for basic authentication
        password: Password for basic authentication
        session: Requests session with authentication configured
    """

    TIMEOUT = 30  # Internal default timeout in seconds
    VERIFY_SSL = True  # Internal default SSL verification

    def __init__(
        self,
        base_url: str,
        username: str,
        password: str,
    ):
        """Initialize Airflow client.

        Args:
            base_url: Base URL for Airflow API (required)
            username: Username for basic auth (required)
            password: Password for basic auth (required)
        """
        # urljoin() replaces the last path segment of a base that lacks a
        # trailing slash, so normalise the base URL to always end with "/".
        self.base_url = base_url if base_url.endswith("/") else base_url + "/"
        self.username = username
        self.password = password

        # Setup session with auth
        self.session = requests.Session()
        self.session.auth = (self.username, self.password)

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self) -> "AirflowClient":
        """Return the client itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the session on leaving a ``with`` block; never suppresses errors."""
        self.close()

    def _get_headers(self) -> Dict[str, str]:
        """Get request headers.

        Returns:
            Dictionary of HTTP headers
        """
        return {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    @staticmethod
    def _error_detail(response: "requests.Response") -> str:
        """Build the message suffix describing an error response body.

        Args:
            response: Response whose body should be inspected

        Returns:
            ``": <detail>"`` when the body is a JSON object with a "detail"
            key, ``": <raw text>"`` when the body is not valid JSON, and an
            empty string otherwise.
        """
        try:
            body = response.json()
        except ValueError:
            # Every JSON decode error raised by response.json() is a ValueError.
            return f": {response.text}"
        # The body may be any JSON value; only an object can carry "detail".
        if isinstance(body, dict) and "detail" in body:
            return f": {body['detail']}"
        return ""

    def _raise_for_error_status(
        self, response: "requests.Response", message: str
    ) -> None:
        """Raise ``AirflowAPIError`` if the response has an HTTP error status.

        Args:
            response: Response to check
            message: Message prefix; the server-provided detail is appended

        Raises:
            AirflowAPIError: If the status code is 400 or above. The response
                is attached to the exception as ``response``.
        """
        if response.status_code >= 400:
            raise AirflowAPIError(
                f"{message}{self._error_detail(response)}", response=response
            )

    def trigger_dag_run(
        self,
        dag_id: str,
        config: Dict[str, Any],
        timeout: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Trigger a DAG run with configuration.

        Args:
            dag_id: DAG identifier
            config: Configuration dictionary to pass to the DAG (already formatted)
            timeout: Optional timeout override in seconds. ``None`` (or any
                other falsy value) selects the class default ``TIMEOUT``.

        Returns:
            Dictionary with DAG run response including dag_run_id

        Raises:
            AirflowAPIError: If API request fails
        """
        url = urljoin(self.base_url, f"api/v1/dags/{dag_id}/dagRuns")
        effective_timeout = timeout or self.TIMEOUT

        # Only the HTTP call sits inside the try block. AirflowAPIError is
        # itself a RequestException, so raising it inside this block would be
        # caught below and wrapped a second time, losing the response.
        try:
            response = self.session.post(
                url,
                # config is already formatted with "conf" wrapper from
                # IngestionConfig.to_payload(), so it is sent unchanged.
                json=config,
                headers=self._get_headers(),
                timeout=effective_timeout,
                verify=self.VERIFY_SSL,
            )
        except requests.exceptions.Timeout as e:
            raise AirflowAPIError(
                f"Request to Airflow API timed out after {effective_timeout} seconds"
            ) from e
        except requests.exceptions.ConnectionError as e:
            raise AirflowAPIError(f"Failed to connect to Airflow API: {e}") from e
        except requests.exceptions.RequestException as e:
            raise AirflowAPIError(f"Airflow API request failed: {e}") from e

        self._raise_for_error_status(
            response,
            f"Airflow API request failed with status {response.status_code}",
        )

        try:
            return response.json()
        except ValueError as e:
            raise AirflowAPIError(
                f"Failed to parse API response as JSON: {e}", response=response
            ) from e

    def get_dag_run_status(self, dag_id: str, dag_run_id: str) -> Dict[str, Any]:
        """Get status of a DAG run.

        Args:
            dag_id: DAG identifier
            dag_run_id: DAG run ID to query

        Returns:
            Dictionary with DAG run status information

        Raises:
            AirflowAPIError: If API request fails
        """
        url = urljoin(self.base_url, f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}")

        # Only the HTTP call is guarded here; see trigger_dag_run for why
        # AirflowAPIError must not be raised inside this try block.
        try:
            response = self.session.get(
                url,
                headers=self._get_headers(),
                timeout=self.TIMEOUT,
                verify=self.VERIFY_SSL,
            )
        except requests.exceptions.RequestException as e:
            raise AirflowAPIError(f"Failed to get DAG run status: {e}") from e

        self._raise_for_error_status(
            response,
            f"Failed to get DAG run status: HTTP {response.status_code}",
        )

        try:
            return response.json()
        except ValueError as e:
            # Wrapped explicitly so a plain ValueError from response.json()
            # never escapes this method.
            raise AirflowAPIError(
                f"Failed to get DAG run status: {e}", response=response
            ) from e