# Source code for satorbis_kit.patch_generation

"""
Patch generation module using object-oriented design.

Follows SOLID principles and design patterns.
"""

import os
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import rasterio
from dask import compute, delayed
from dask.diagnostics import ProgressBar
from rasterio.enums import Resampling
from rasterio.transform import Affine
from rasterio.warp import reproject
from rasterio.windows import Window

from .backends import Backend, get_backend

# ============================================================================
# Configuration Classes (Single Responsibility)
# ============================================================================


@dataclass
class PatchConfig:
    """Configuration for patch generation.

    Attributes:
        input_tif: Path to the input GeoTIFF file.
        output_folder: Folder the patch files are written into.
        patch_size: Edge length of each (square) patch in pixels; must be > 0.
        overlap: Overlap between neighbouring patches in pixels; must be >= 0.
        downsample_factor: Downsampling factor in (0, 1]. ``1.0`` or ``None``
            both mean "no downsampling".
        driver: Raster driver name passed to ``rasterio.open`` when writing.
        rasterio_env: Optional keyword arguments for ``rasterio.Env``.

    Raises:
        ValueError: If ``patch_size``, ``overlap`` or ``downsample_factor``
            is outside its allowed range.
    """

    input_tif: str
    output_folder: str
    patch_size: int = 2048
    overlap: int = 0
    downsample_factor: Optional[float] = 1.0
    driver: str = "GTiff"
    rasterio_env: Optional[dict] = None

    def __post_init__(self):
        """Validate configuration."""
        if self.patch_size <= 0:
            raise ValueError("patch_size must be > 0")
        if self.overlap < 0:
            raise ValueError("overlap must be >= 0")
        # None is a legal value (the field is Optional and the rest of the
        # module treats it as "no downsampling"), so only range-check real
        # numbers; comparing None with an int would raise TypeError.
        if self.downsample_factor is not None and not (
            0 < self.downsample_factor <= 1.0
        ):
            raise ValueError("downsample_factor must be in (0, 1]")
# ============================================================================
# Utility Classes
# ============================================================================
class PatchProcessor:
    """Stateless array helpers used while producing patches (SRP)."""

    @staticmethod
    def pad_to_size(data: np.ndarray, target_h: int, target_w: int, dtype) -> np.ndarray:
        """Zero-pad ``data`` on the bottom/right up to ``(bands, target_h, target_w)``.

        Args:
            data: Array unpacked as ``(bands, height, width)``.
            target_h: Height of the returned array.
            target_w: Width of the returned array.
            dtype: dtype of the newly allocated array when padding is needed.

        Returns:
            ``data`` itself when it already has the requested height and
            width, otherwise a new zero-filled array with ``data`` copied
            into its top-left corner.
        """
        n_bands, rows, cols = data.shape
        if (rows, cols) == (target_h, target_w):
            return data

        padded = np.zeros((n_bands, target_h, target_w), dtype=dtype)
        padded[:, :rows, :cols] = data
        return padded

    @staticmethod
    def downsample_patch(
        patch: np.ndarray,
        src_transform: Affine,
        downsample_factor: float,
        resampling: Resampling = Resampling.bilinear,
    ) -> Tuple[np.ndarray, Affine]:
        """Resample ``patch`` to a coarser grid.

        Args:
            patch: Array unpacked as ``(bands, height, width)``.
            src_transform: Affine transform of ``patch``.
            downsample_factor: Factor in (0, 1]; ``None`` or ``1.0`` returns
                the inputs untouched.
            resampling: Resampling method handed to ``reproject``.

        Returns:
            Tuple of the resampled array and its affine transform.

        Raises:
            ValueError: If ``downsample_factor`` is <= 0 or > 1.
        """
        if downsample_factor is None or downsample_factor == 1.0:
            return patch, src_transform
        if downsample_factor > 1.0 or downsample_factor <= 0:
            raise ValueError("downsample_factor must be in (0, 1].")

        n_bands, rows, cols = patch.shape
        # Never collapse a dimension to zero pixels.
        dst_rows = max(1, int(round(rows * downsample_factor)))
        dst_cols = max(1, int(round(cols * downsample_factor)))

        # Pixel size grows by 1/downsample_factor; the origin stays put.
        pixel_growth = 1.0 / downsample_factor
        dst_transform = src_transform * Affine.scale(pixel_growth, pixel_growth)

        resampled = np.zeros((n_bands, dst_rows, dst_cols), dtype=patch.dtype)
        # Iterating the first axis yields per-band views, so reproject writes
        # straight into ``resampled``.
        for src_band, dst_band in zip(patch, resampled):
            reproject(
                source=src_band,
                destination=dst_band,
                src_transform=src_transform,
                src_crs=None,
                dst_transform=dst_transform,
                dst_crs=None,
                resampling=resampling,
            )
        return resampled, dst_transform
class PatchWriter:
    """Utility class for writing patches to disk (SRP)."""

    def __init__(self, config: PatchConfig):
        """Initialize with configuration.

        Args:
            config: Patch generation settings (patch size, downsample factor,
                output folder, driver, optional rasterio environment).
        """
        self.config = config
        self.processor = PatchProcessor()

    def write_patch(self, patch_id: int, x: int, y: int, src_path: str) -> str:
        """Read, process, and write a single patch.

        Args:
            patch_id: Number used in the output file name.
            x: Column offset of the patch's top-left corner in the source.
            y: Row offset of the patch's top-left corner in the source.
            src_path: Path of the source raster.

        Returns:
            Path of the written patch file.
        """
        env_options = self.config.rasterio_env
        if not env_options:
            return self._read_and_write(src_path, patch_id, x, y)
        # Run the whole read/write inside the user-supplied rasterio.Env.
        with rasterio.Env(**env_options):
            return self._read_and_write(src_path, patch_id, x, y)

    def _read_and_write(self, src_path: str, patch_id: int, x: int, y: int) -> str:
        """Read one window, pad/downsample it, and write it as its own raster.

        Args:
            src_path: Path of the source raster.
            patch_id: Number used in the output file name.
            x: Column offset of the window in the source.
            y: Row offset of the window in the source.

        Returns:
            Path of the written patch file
            (``<output_folder>/patch_<patch_id>.tif``).
        """
        size = self.config.patch_size
        factor = self.config.downsample_factor

        with rasterio.open(src_path) as src:
            # Clip the window at the right/bottom edge of the source raster.
            # PatchConfig guarantees patch_size > 0, so no fallback is needed.
            w = min(size, src.width - x)
            h = min(size, src.height - y)
            window = Window(x, y, w, h)
            transform = src.window_transform(window)
            patch_data = src.read(window=window)
            src_crs = src.crs

        # Edge windows are smaller than patch_size: zero-pad them so every
        # patch has the same shape before any downsampling.
        # NOTE(review): the source's nodata value is not carried over and the
        # padding is plain zeros - confirm that is intended.
        if w != size or h != size:
            patch_data = self.processor.pad_to_size(
                patch_data, size, size, dtype=patch_data.dtype
            )

        if factor is not None and factor != 1.0:
            out_arr, out_transform = self.processor.downsample_patch(
                patch_data, transform, factor
            )
        else:
            out_arr, out_transform = patch_data, transform

        patch_filename = os.path.join(self.config.output_folder, f"patch_{patch_id}.tif")
        with rasterio.open(
            patch_filename,
            "w",
            driver=self.config.driver,
            height=out_arr.shape[1],
            width=out_arr.shape[2],
            count=out_arr.shape[0],
            dtype=out_arr.dtype,
            crs=src_crs,
            transform=out_transform,
        ) as dst:
            dst.write(out_arr)

        return patch_filename
# ============================================================================
# Main Patch Generator Class (Open/Closed Principle)
# ============================================================================
class PatchGenerator:
    """
    Generate patches from a GeoTIFF file.

    Design notes:
    - Encapsulation: the tiling and scheduling logic lives in this class
    - Single Responsibility: only concerned with patch generation
    - Dependency Injection: an optional Backend decides how tasks run
    """

    def __init__(self, config: PatchConfig, backend: Optional[Backend] = None):
        """Store the configuration and backend and build the patch writer.

        Args:
            config: Patch generation settings.
            backend: Execution backend; ``None`` falls back to plain Dask.
        """
        self.config = config
        self.backend = backend
        self.writer = PatchWriter(config)

    def _create_patch_tasks(self, width: int, height: int):
        """Build one lazy write task per patch origin, row by row.

        Args:
            width: Source raster width in pixels.
            height: Source raster height in pixels.

        Returns:
            List of delayed tasks; patch ids count up from 0 in row-major
            order.
        """
        patch_size = self.config.patch_size
        overlap = self.config.overlap
        # Consecutive origins are patch_size apart, minus the overlap; the
        # stride is clamped to at least one pixel.
        stride = patch_size if overlap == 0 else max(1, patch_size - overlap)

        def schedule(idx, col, row):
            # Use the backend's task wrapper when available, else Dask's
            # delayed (backward compatibility).
            wrap = delayed if self.backend is None else self.backend.delayed
            return wrap(self.writer.write_patch)(
                patch_id=idx, x=col, y=row, src_path=self.config.input_tif
            )

        origins = [
            (col, row)
            for row in range(0, height, stride)
            for col in range(0, width, stride)
        ]
        return [schedule(idx, col, row) for idx, (col, row) in enumerate(origins)]

    def _execute_tasks(self, tasks, show_progress: bool = True):
        """Run the tasks through the backend, or through Dask directly.

        Args:
            tasks: Tasks produced by ``_create_patch_tasks``.
            show_progress: Whether to show a progress bar.

        Returns:
            Whatever the backend's ``compute`` (or ``dask.compute``) returns.
        """
        if self.backend is not None:
            return self.backend.compute(tasks, show_progress=show_progress)

        # No backend: fall back to direct Dask.
        if not show_progress:
            return compute(*tasks)
        with ProgressBar():
            return compute(*tasks)

    def _raster_size(self):
        """Return ``(width, height)`` of the configured input raster."""
        with rasterio.open(self.config.input_tif) as src:
            return src.width, src.height

    def generate(self, show_progress: bool = True) -> int:
        """
        Generate patches from the input GeoTIFF.

        Args:
            show_progress: Whether to show progress bar

        Returns:
            Number of patches created
        """
        os.makedirs(self.config.output_folder, exist_ok=True)

        # Only open a rasterio.Env when the caller supplied options for it.
        env_options = self.config.rasterio_env
        if env_options:
            with rasterio.Env(**env_options):
                width, height = self._raster_size()
        else:
            width, height = self._raster_size()

        tasks = self._create_patch_tasks(width, height)
        self._execute_tasks(tasks, show_progress=show_progress)
        return len(tasks)
# ============================================================================
# Factory Functions (Backward Compatibility)
# ============================================================================
def create_patches_dask(
    input_tif: str,
    output_folder: str,
    patch_size: int = 2048,
    overlap: int = 0,
    downsample_factor: Optional[float] = 1.0,
    show_progress: bool = True,
    rasterio_env: Optional[dict] = None,
    backend: Optional[Backend] = None,
) -> int:
    """
    Cut a GeoTIFF into patches with a caller-supplied backend.

    Thin functional wrapper: builds a PatchConfig, hands it to a
    PatchGenerator together with ``backend``, and runs it.

    Args:
        input_tif: Path to input GeoTIFF file
        output_folder: Path to output folder for patches
        patch_size: Size of each patch in pixels
        overlap: Overlap between patches in pixels
        downsample_factor: Downsampling factor (0-1, 1.0 = no downsampling)
        show_progress: Show progress bar
        rasterio_env: Rasterio environment settings
        backend: Backend instance (if None, uses Dask directly)

    Returns:
        Number of patches created
    """
    settings = dict(
        input_tif=input_tif,
        output_folder=output_folder,
        patch_size=patch_size,
        overlap=overlap,
        downsample_factor=downsample_factor,
        rasterio_env=rasterio_env,
    )
    patch_config = PatchConfig(**settings)
    return PatchGenerator(patch_config, backend=backend).generate(
        show_progress=show_progress
    )
def create_patches(
    input_tif: str,
    output_folder: str,
    patch_size: int = 2048,
    overlap: int = 0,
    downsample_factor: Optional[float] = 1.0,
    rasterio_env: Optional[dict] = None,
    backend: str = "dask",
    scheduler_address: Optional[str] = None,
    n_workers: int = 4,
    threads_per_worker: int = 1,
    processes: bool = True,
    show_progress: bool = True,
) -> int:
    """
    Cut a GeoTIFF into patches, owning the backend for the duration of the run.

    The backend is obtained from ``get_backend`` and used as a context
    manager, so it is only held while the patches are generated.

    Args:
        input_tif: Path to input GeoTIFF file
        output_folder: Path to output folder for patches
        patch_size: Size of each patch in pixels
        overlap: Overlap between patches in pixels
        downsample_factor: Downsampling factor (0-1, 1.0 = no downsampling)
        rasterio_env: Rasterio environment settings
        backend: Backend name ('dask' or 'sedona')
        scheduler_address: Dask scheduler address (for 'dask' backend)
        n_workers: Number of workers (for 'dask' backend)
        threads_per_worker: Threads per worker (for 'dask' backend)
        processes: Use processes instead of threads (for 'dask' backend)
        show_progress: Show progress bar

    Returns:
        Number of patches created
    """
    # Build (and thereby validate) the configuration before any backend
    # is requested.
    patch_config = PatchConfig(
        input_tif=input_tif,
        output_folder=output_folder,
        patch_size=patch_size,
        overlap=overlap,
        downsample_factor=downsample_factor,
        rasterio_env=rasterio_env,
    )

    backend_options = {
        "name": backend,
        "scheduler_address": scheduler_address,
        "n_workers": n_workers,
        "threads_per_worker": threads_per_worker,
        "processes": processes,
    }
    with get_backend(**backend_options) as active_backend:
        runner = PatchGenerator(patch_config, backend=active_backend)
        return runner.generate(show_progress=show_progress)