from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, Optional, Protocol
from dask import compute as dask_compute
from dask import delayed as dask_delayed
from dask.diagnostics import ProgressBar
from .dask_utils import dask_client
# ============================================================================
# Backend Protocol (Strategy Pattern)
# ============================================================================
class ComputeStrategy(Protocol):
    """Structural interface for a compute strategy (Strategy Pattern).

    Being a ``typing.Protocol``, any object that provides ``delayed``,
    ``compute`` and ``close`` with these signatures satisfies it for static
    type checking; inheriting from this class is not required.
    """

    def delayed(self, fn) -> Any:
        """Mark ``fn`` as a delayed task for parallel execution."""

    def compute(self, tasks: Iterable, show_progress: bool = True) -> Any:
        """Run the delayed ``tasks`` and return their results."""

    def close(self) -> None:
        """Release any resources held by the strategy."""
# ============================================================================
# Backend Abstract Base Class
# ============================================================================
class Backend(ABC):
    """
    Abstract base for compute backends (e.g. Dask, Sedona).

    Subclasses must implement ``delayed``, ``compute`` and ``close``.
    Every backend is also a context manager: leaving a ``with`` block
    calls ``close()``.

    Design notes (SOLID):
    - Single Responsibility: each backend wraps exactly one compute framework
    - Open/Closed: new frameworks are added by subclassing, not by editing
    - Liskov Substitution: any backend can stand in for any other
    - Interface Segregation: the interface is small and focused
    - Dependency Inversion: callers depend on this abstraction, not on a
      concrete framework
    """

    @abstractmethod
    def delayed(self, fn):
        """Mark ``fn`` as a delayed task for parallel execution."""

    @abstractmethod
    def compute(self, tasks: Iterable, show_progress: bool = True):
        """Execute the delayed ``tasks`` and return their results."""

    @abstractmethod
    def close(self):
        """Release any resources held by the backend."""

    def __enter__(self):
        """Enter the runtime context; the backend itself is what ``as`` binds."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the runtime context by closing the backend.

        Returns ``None`` (falsy), so an exception raised inside the ``with``
        body is propagated, not suppressed.
        """
        self.close()
# ============================================================================
# Dask Backend Implementation
# ============================================================================
class DaskBackend(Backend):
    """Backend that runs work through Dask's ``delayed``/``compute`` API."""

    def __init__(self, client) -> None:
        # Client handle closed by ``close()``; ``None`` is tolerated.
        self.client = client

    def delayed(self, fn):
        """Wrap ``fn`` with ``dask.delayed`` to build a lazy Dask task."""
        return dask_delayed(fn)

    def compute(self, tasks: Iterable, show_progress: bool = True):
        """Evaluate every task in ``tasks`` with a single ``dask.compute`` call.

        Args:
            tasks: Iterable of Dask delayed objects; unpacked as positional
                arguments to ``dask.compute``.
            show_progress: If true, run the call inside
                ``dask.diagnostics.ProgressBar``.

        Returns:
            Whatever ``dask.compute`` returns for those arguments.
        """
        if not show_progress:
            return dask_compute(*tasks)
        # NOTE(review): dask.diagnostics.ProgressBar only reports for the local
        # (threaded / multiprocessing / synchronous) schedulers. When a
        # distributed Client is active it likely prints nothing — confirm, and
        # consider distributed's own progress reporting instead.
        with ProgressBar():
            return dask_compute(*tasks)

    def close(self):
        """Close the wrapped Dask client, if there is one."""
        client = self.client
        if client is None:
            return
        client.close()
# ============================================================================
# Sedona Backend Implementation (placeholder)
# ============================================================================
class SedonaBackend(Backend):
    """Placeholder backend intended for Apache Sedona geospatial processing.

    Only ``close`` does real work; ``delayed`` and ``compute`` always raise
    ``NotImplementedError``.
    """

    def __init__(self, spark_session) -> None:
        # Spark session stopped by ``close()``; ``None`` is tolerated.
        self.spark_session = spark_session

    def delayed(self, fn):
        """Intended to wrap ``fn`` as a Sedona task; not implemented yet."""
        # TODO: Implement Sedona task creation
        raise NotImplementedError("Sedona backend not yet implemented")

    def compute(self, tasks: Iterable, show_progress: bool = True):
        """Intended to execute Sedona tasks; not implemented yet."""
        # TODO: Implement Sedona task execution
        raise NotImplementedError("Sedona backend not yet implemented")

    def close(self):
        """Stop the Spark session, if one was supplied."""
        session = self.spark_session
        if session is not None:
            session.stop()
# ============================================================================
# Configuration Dataclasses (Single Responsibility)
# ============================================================================
@dataclass()
class DaskConfig:
    """Settings for the Dask backend (one object per concern).

    Attributes:
        scheduler_address: Dask scheduler address, or ``None`` if not given.
        n_workers: Number of workers.
        threads_per_worker: Threads per worker.
        processes: Use processes instead of threads.
    """

    scheduler_address: Optional[str] = None
    n_workers: int = 4
    threads_per_worker: int = 1
    processes: bool = True
@dataclass()
class SedonaConfig:
    """Settings for the (not yet implemented) Sedona/Spark backend.

    Attributes:
        spark_master: Spark master URL; defaults to ``"local[*]"``.
        app_name: Spark application name.
    """

    spark_master: str = "local[*]"
    app_name: str = "SatorbisKit"
# ============================================================================
# Backend Factory (Factory Pattern + Strategy Pattern)
# ============================================================================
class BackendFactory:
    """
    Factory and registry for compute backends.

    Combines the Factory pattern (constructing backends) with a Singleton:
    every ``BackendFactory()`` call returns the same instance, so the
    name -> class registry is shared by all callers.
    """

    # The single shared instance; created lazily by ``__new__``.
    _instance = None
    # Registry of backend name -> Backend subclass. It lives on the class, so
    # ``register_backend`` mutates state seen through every reference.
    _backends: Dict[str, type] = {
        "dask": DaskBackend,
        "sedona": SedonaBackend,
    }

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register_backend(self, name: str, backend_class: type) -> None:
        """
        Register a backend type under ``name``, replacing any existing entry.

        Args:
            name: Backend name
            backend_class: Class that subclasses Backend

        Raises:
            TypeError: If backend_class is not a class that subclasses Backend
        """
        # issubclass() raises its own, unrelated TypeError for non-class
        # arguments (e.g. an instance); check for a class first so every
        # invalid input gets the same descriptive message.
        if not (isinstance(backend_class, type) and issubclass(backend_class, Backend)):
            raise TypeError(f"{backend_class} must subclass Backend")
        self._backends[name] = backend_class

    def list_backends(self) -> list[str]:
        """Return the registered backend names in registration order."""
        return list(self._backends.keys())

    def get_backend_class(self, name: str) -> type:
        """
        Look up a backend class by name.

        Args:
            name: Backend name

        Returns:
            The registered Backend subclass

        Raises:
            ValueError: If backend name is unknown
        """
        if name not in self._backends:
            raise ValueError(
                f"Unknown backend '{name}'. Available backends: {list(self._backends.keys())}"
            )
        return self._backends[name]

    def create_dask_backend(self, client) -> DaskBackend:
        """Create a DaskBackend wrapping ``client``."""
        return DaskBackend(client)

    def create_sedona_backend(self, spark_session) -> SedonaBackend:
        """Create a SedonaBackend wrapping ``spark_session``."""
        return SedonaBackend(spark_session)

    @classmethod
    def get_instance(cls) -> "BackendFactory":
        """Return the singleton instance (same object as ``cls()``)."""
        # __new__ already implements create-once semantics, so delegate to it
        # rather than duplicating the check here.
        return cls()
# ============================================================================
# Backend Manager
# ============================================================================
@contextmanager
def get_backend(
    name: str = "dask",
    *,
    scheduler_address: Optional[str] = None,
    n_workers: int = 4,
    threads_per_worker: int = 1,
    processes: bool = True,
    spark_master: Optional[str] = None,
    **kwargs,
) -> Iterator[Backend]:
    """
    Context manager that creates a compute backend, yields it, then closes it.

    The built-in names 'dask' and 'sedona' always use their dedicated
    construction paths below. Any other name registered through
    ``BackendFactory.register_backend`` is constructed as
    ``backend_class(**kwargs)``.

    Args:
        name: Backend name ('dask', 'sedona', or a custom registered name)
        scheduler_address: Dask scheduler address (for 'dask' backend)
        n_workers: Number of workers (for 'dask' backend)
        threads_per_worker: Threads per worker (for 'dask' backend)
        processes: Use processes instead of threads (for 'dask' backend)
        spark_master: Spark master URL (for 'sedona' backend)
        **kwargs: Keyword arguments for the constructor of a custom
            registered backend; ignored by the 'dask' and 'sedona' backends.

    Yields:
        Backend instance; ``backend.close()`` is called when the context exits.

    Raises:
        ValueError: If ``name`` is not a registered backend
        NotImplementedError: If the 'sedona' backend is requested
    """
    factory = BackendFactory.get_instance()
    backend: Optional[Backend] = None
    try:
        if name == "dask":
            # Create the Dask backend using the factory and config.
            config = DaskConfig(
                scheduler_address=scheduler_address,
                n_workers=n_workers,
                threads_per_worker=threads_per_worker,
                processes=processes,
            )
            # NOTE(review): the `finally` below calls backend.close() (i.e.
            # client.close()) after dask_client's context has already exited;
            # if dask_client also closes the client, this is a second close —
            # confirm it is harmless for the client type dask_client yields.
            with dask_client(
                scheduler_address=config.scheduler_address,
                n_workers=config.n_workers,
                threads_per_worker=config.threads_per_worker,
                processes=config.processes,
            ) as client:
                backend = factory.create_dask_backend(client)
                yield backend
        elif name == "sedona":
            # Placeholder: the config is assembled for a future Spark/Sedona
            # session but is not used yet.
            _config = SedonaConfig(spark_master=spark_master or "local[*]")
            # TODO: Implement Sedona/Spark session creation
            raise NotImplementedError(
                "Sedona backend requires Spark session setup. "
                "Please install pyspark and sedona-python to use this backend."
            )
        elif name in factory.list_backends():
            # A custom backend added via BackendFactory.register_backend.
            # Previously such names were listed as "Available" in the error
            # below yet still rejected; construct them from the registry.
            backend_class = factory.get_backend_class(name)
            backend = backend_class(**kwargs)
            yield backend
        else:
            available = factory.list_backends()
            raise ValueError(f"Unknown backend: {name}. Available: {available}")
    finally:
        # Only a backend that was actually constructed needs closing.
        if backend is not None:
            backend.close()