"""
Patch generation module using object-oriented design.
Follows SOLID principles and design patterns.
"""
import os
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import rasterio
from dask import compute, delayed
from dask.diagnostics import ProgressBar
from rasterio.enums import Resampling
from rasterio.transform import Affine
from rasterio.warp import reproject
from rasterio.windows import Window

from .backends import Backend, get_backend

# ============================================================================
# Configuration Classes (Single Responsibility)
# ============================================================================
@dataclass
class PatchConfig:
    """Configuration for patch generation.

    All fields are validated in ``__post_init__``; invalid values raise
    ``ValueError``.
    """

    # Path of the source raster that patches are cut from.
    input_tif: str
    # Directory that receives the written patch files.
    output_folder: str
    # Edge length, in pixels, of each square patch (before downsampling).
    patch_size: int = 2048
    # Number of pixels shared between neighbouring patches.
    overlap: int = 0
    # Scale factor in (0, 1]; ``None`` or ``1.0`` means "do not downsample".
    downsample_factor: Optional[float] = 1.0
    # rasterio/GDAL driver name used when writing patches.
    driver: str = "GTiff"
    # Optional keyword arguments forwarded to ``rasterio.Env``.
    rasterio_env: Optional[dict] = None

    def __post_init__(self):
        """Validate configuration.

        Raises:
            ValueError: if ``patch_size`` is not positive, ``overlap`` is
                negative, or ``downsample_factor`` is neither ``None`` nor
                within (0, 1].
        """
        if self.patch_size <= 0:
            raise ValueError("patch_size must be > 0")
        if self.overlap < 0:
            raise ValueError("overlap must be >= 0")
        # ``None`` is a legitimate "no downsampling" value (the type is
        # Optional[float]); comparing it with numbers would raise TypeError,
        # so only range-check actual numbers.
        if self.downsample_factor is not None and not (0 < self.downsample_factor <= 1.0):
            raise ValueError("downsample_factor must be in (0, 1]")
# ============================================================================
# Utility Classes
# ============================================================================
class PatchProcessor:
    """Stateless array helpers used while building patches (SRP).

    Both operations are static methods, so the class acts purely as a
    namespace and never holds state.
    """

    @staticmethod
    def pad_to_size(data: np.ndarray, target_h: int, target_w: int, dtype) -> np.ndarray:
        """Zero-pad ``data`` on the bottom/right to ``(bands, target_h, target_w)``.

        Args:
            data: Array indexed as ``(bands, rows, cols)``.
            target_h: Required number of rows.
            target_w: Required number of columns.
            dtype: dtype of the new array allocated when padding is needed.

        Returns:
            ``data`` itself when it already has the target height and width;
            otherwise a new zero-filled array with ``data`` copied into its
            top-left corner.
        """
        n_bands, rows, cols = data.shape
        if (rows, cols) == (target_h, target_w):
            # Already the right size: return the input as-is (no copy).
            return data
        # NOTE(review): assumes rows <= target_h and cols <= target_w; a larger
        # input does not fit the buffer and numpy raises on the assignment.
        padded = np.zeros((n_bands, target_h, target_w), dtype=dtype)
        padded[:, :rows, :cols] = data
        return padded

    @staticmethod
    def downsample_patch(
        patch: np.ndarray,
        src_transform: Affine,
        downsample_factor: float,
        resampling: Resampling = Resampling.bilinear,
    ) -> Tuple[np.ndarray, Affine]:
        """Resample ``patch`` onto a coarser grid scaled by ``downsample_factor``.

        Args:
            patch: Array indexed as ``(bands, rows, cols)``.
            src_transform: Affine transform of ``patch``.
            downsample_factor: Fraction of the original size to keep, in
                (0, 1]. ``None`` or ``1.0`` returns the inputs unchanged.
            resampling: rasterio resampling method applied to every band.

        Returns:
            ``(array, transform)``: the resampled array (same dtype and band
            count as ``patch``, each spatial side scaled and rounded, minimum
            1) and a transform whose pixel size is the source pixel size
            divided by ``downsample_factor``.

        Raises:
            ValueError: if ``downsample_factor`` is <= 0 or > 1.
        """
        if downsample_factor is None or downsample_factor == 1.0:
            return patch, src_transform
        if downsample_factor <= 0 or downsample_factor > 1.0:
            raise ValueError("downsample_factor must be in (0, 1].")

        n_bands, rows, cols = patch.shape
        dst_rows = max(1, int(round(rows * downsample_factor)))
        dst_cols = max(1, int(round(cols * downsample_factor)))
        resampled = np.zeros((n_bands, dst_rows, dst_cols), dtype=patch.dtype)

        # Scaling the transform by 1/factor enlarges each destination pixel,
        # so the same ground footprint is covered by fewer pixels.
        inv = 1.0 / downsample_factor
        dst_transform = src_transform * Affine.scale(inv, inv)

        # Iterating the leading axis yields per-band views, so reproject
        # writes directly into ``resampled``.
        for src_band, dst_band in zip(patch, resampled):
            reproject(
                source=src_band,
                destination=dst_band,
                src_transform=src_transform,
                src_crs=None,
                dst_transform=dst_transform,
                dst_crs=None,
                resampling=resampling,
            )
        return resampled, dst_transform
class PatchWriter:
    """Utility class for writing patches to disk (SRP).

    Each ``write_patch`` call opens the source raster, reads one window,
    zero-pads edge windows to full size, optionally downsamples, and writes
    the result as a standalone raster file.
    """

    def __init__(self, config: PatchConfig):
        """Initialize with configuration."""
        self.config = config
        self.processor = PatchProcessor()

    def write_patch(self, patch_id: int, x: int, y: int, src_path: str) -> str:
        """Read, process, and write a single patch.

        Args:
            patch_id: Index used in the output filename ``patch_<id>.tif``.
            x: Column offset (pixels) of the patch's top-left corner.
            y: Row offset (pixels) of the patch's top-left corner.
            src_path: Path of the raster to read from.

        Returns:
            Path of the written patch file.
        """
        # Only enter a rasterio.Env when options were supplied; otherwise keep
        # whatever environment is already active.
        env_options = self.config.rasterio_env
        env = rasterio.Env(**env_options) if env_options else nullcontext()
        with env:
            return self._read_and_write(src_path, patch_id, x, y)

    def _read_and_write(self, src_path: str, patch_id: int, x: int, y: int) -> str:
        """Read the window at ``(x, y)``, pad/downsample it and write it out."""
        size = self.config.patch_size  # guaranteed > 0 by PatchConfig

        with rasterio.open(src_path) as src:
            # Clip the window to the raster so edge patches stay in bounds.
            w = min(size, src.width - x)
            h = min(size, src.height - y)
            window = Window(x, y, w, h)
            transform = src.window_transform(window)
            patch_data = src.read(window=window)
            src_crs = src.crs
            # Carry the source nodata value over so written patches keep it.
            src_nodata = src.nodata

        # Edge windows are smaller than ``size``; pad (zeros, bottom/right) so
        # every patch has the same shape.
        if w != size or h != size:
            patch_data = self.processor.pad_to_size(
                patch_data, size, size, dtype=patch_data.dtype
            )

        factor = self.config.downsample_factor
        if factor is not None and factor != 1.0:
            out_arr, out_transform = self.processor.downsample_patch(
                patch_data, transform, factor
            )
        else:
            out_arr, out_transform = patch_data, transform

        patch_filename = os.path.join(self.config.output_folder, f"patch_{patch_id}.tif")
        bands, out_h, out_w = out_arr.shape
        with rasterio.open(
            patch_filename,
            "w",
            driver=self.config.driver,
            height=out_h,
            width=out_w,
            count=bands,
            dtype=out_arr.dtype,
            crs=src_crs,
            transform=out_transform,
            nodata=src_nodata,
        ) as dst:
            dst.write(out_arr)
        return patch_filename
# ============================================================================
# Main Patch Generator Class (Open/Closed Principle)
# ============================================================================
class PatchGenerator:
    """
    Main class for generating patches from GeoTIFF files.

    The input raster is tiled on a regular grid whose stride is
    ``patch_size - overlap`` (never below 1). Each grid cell becomes one lazy
    task that calls ``PatchWriter.write_patch``. Tasks run through the injected
    ``Backend`` when one is given, otherwise directly through Dask.
    """

    def __init__(self, config: PatchConfig, backend: Optional[Backend] = None):
        """Store the configuration and backend, and build the patch writer."""
        self.config = config
        self.backend = backend
        self.writer = PatchWriter(config)

    def _create_patch_tasks(self, width: int, height: int):
        """Build one delayed ``write_patch`` task per grid origin (row-major)."""
        cfg = self.config
        if cfg.overlap == 0:
            stride = cfg.patch_size
        else:
            stride = max(1, cfg.patch_size - cfg.overlap)

        # Prefer the injected backend's delayed wrapper; fall back to plain
        # Dask for backward compatibility.
        wrap = delayed if self.backend is None else self.backend.delayed

        origins = [
            (col, row)
            for row in range(0, height, stride)
            for col in range(0, width, stride)
        ]
        return [
            wrap(self.writer.write_patch)(
                patch_id=idx, x=col, y=row, src_path=cfg.input_tif
            )
            for idx, (col, row) in enumerate(origins)
        ]

    def _execute_tasks(self, tasks, show_progress: bool = True):
        """Run ``tasks`` on the backend, or on Dask when no backend is set."""
        if self.backend is not None:
            return self.backend.compute(tasks, show_progress=show_progress)
        if not show_progress:
            return compute(*tasks)
        with ProgressBar():
            return compute(*tasks)

    def _read_dimensions(self) -> Tuple[int, int]:
        """Return ``(width, height)`` of the input raster, honouring ``rasterio_env``."""

        def probe() -> Tuple[int, int]:
            with rasterio.open(self.config.input_tif) as src:
                return src.width, src.height

        env_options = self.config.rasterio_env
        if not env_options:
            return probe()
        with rasterio.Env(**env_options):
            return probe()

    def generate(self, show_progress: bool = True) -> int:
        """
        Generate patches from the input GeoTIFF.

        Args:
            show_progress: Whether to show progress bar

        Returns:
            Number of patches created
        """
        os.makedirs(self.config.output_folder, exist_ok=True)
        width, height = self._read_dimensions()
        tasks = self._create_patch_tasks(width, height)
        self._execute_tasks(tasks, show_progress=show_progress)
        return len(tasks)
# ============================================================================
# Factory Functions (Backward Compatibility)
# ============================================================================
def create_patches_dask(
    input_tif: str,
    output_folder: str,
    patch_size: int = 2048,
    overlap: int = 0,
    downsample_factor: Optional[float] = 1.0,
    show_progress: bool = True,
    rasterio_env: Optional[dict] = None,
    backend: Optional[Backend] = None,
) -> int:
    """
    Create patches with an optional, caller-managed backend.

    Args:
        input_tif: Path to input GeoTIFF file
        output_folder: Path to output folder for patches
        patch_size: Size of each patch in pixels
        overlap: Overlap between patches in pixels
        downsample_factor: Downsampling factor (0-1, 1.0 = no downsampling)
        show_progress: Show progress bar
        rasterio_env: Rasterio environment settings
        backend: Backend instance (if None, uses Dask directly)

    Returns:
        Number of patches created
    """
    generator = PatchGenerator(
        PatchConfig(
            input_tif=input_tif,
            output_folder=output_folder,
            patch_size=patch_size,
            overlap=overlap,
            downsample_factor=downsample_factor,
            rasterio_env=rasterio_env,
        ),
        backend=backend,
    )
    return generator.generate(show_progress=show_progress)
def create_patches(
    input_tif: str,
    output_folder: str,
    patch_size: int = 2048,
    overlap: int = 0,
    downsample_factor: Optional[float] = 1.0,
    rasterio_env: Optional[dict] = None,
    backend: str = "dask",
    scheduler_address: Optional[str] = None,
    n_workers: int = 4,
    threads_per_worker: int = 1,
    processes: bool = True,
    show_progress: bool = True,
) -> int:
    """
    Create patches, starting and shutting down the named backend internally.

    Args:
        input_tif: Path to input GeoTIFF file
        output_folder: Path to output folder for patches
        patch_size: Size of each patch in pixels
        overlap: Overlap between patches in pixels
        downsample_factor: Downsampling factor (0-1, 1.0 = no downsampling)
        rasterio_env: Rasterio environment settings
        backend: Backend name ('dask' or 'sedona')
        scheduler_address: Dask scheduler address (for 'dask' backend)
        n_workers: Number of workers (for 'dask' backend)
        threads_per_worker: Threads per worker (for 'dask' backend)
        processes: Use processes instead of threads (for 'dask' backend)
        show_progress: Show progress bar

    Returns:
        Number of patches created
    """
    # Validate the configuration first so bad arguments fail before any
    # backend is started.
    config = PatchConfig(
        input_tif=input_tif,
        output_folder=output_folder,
        patch_size=patch_size,
        overlap=overlap,
        downsample_factor=downsample_factor,
        rasterio_env=rasterio_env,
    )
    backend_options = {
        "scheduler_address": scheduler_address,
        "n_workers": n_workers,
        "threads_per_worker": threads_per_worker,
        "processes": processes,
    }
    # The context manager owns the backend lifecycle (start-up and teardown).
    with get_backend(name=backend, **backend_options) as active_backend:
        return PatchGenerator(config, backend=active_backend).generate(
            show_progress=show_progress
        )