# Source code for satorbis_kit.vector_operation.wherobots_vector_data_ingestion

"""
Wherobots Vector Data Ingestion Operations.

This module submits a Wherobots Cloud job to ingest vector data into the
managed vector catalog, optionally streams logs, and waits for completion.
It also performs a client-side S3 size check before submitting the job.
"""

from datetime import datetime
from typing import List, Optional, Tuple

from .wherobots_config import (
    DEFAULT_API_KEY,
    DEFAULT_REGION,
    DEFAULT_RUNTIME,
    DEFAULT_SCRIPT_BASE_URI,
    DEFAULT_VECTOR_DATA_INGESTION_SCRIPT,
)
from .wherobots_status import get_job_logs, get_job_status, submit_job

# Lower-cased file suffixes that count toward the client-side S3 size check.
SHP_EXTENSIONS = {".cpg", ".dbf", ".prj", ".shp", ".shx"}
# Size cap for those components, in MB of 1024 * 1024 bytes (i.e. 10 GiB).
DEFAULT_MAX_TOTAL_SIZE_MB = 10_240
# The only database an ingestion job may write into.
ALLOWED_DATABASE = "vector_catalog"
# Internal allowlist of destination tables; leaving it empty permits any
# table inside ALLOWED_DATABASE.
ALLOWED_TABLES: set[str] = set()


def _parse_s3_uri(s3_uri: str) -> Tuple[str, str]:
    """
    Split an S3 URI into bucket and prefix.

    Args:
        s3_uri: S3 URI in the form s3://bucket/prefix

    Returns:
        Tuple of (bucket, prefix) where prefix ends with '/' when present.
    """
    if not s3_uri.startswith("s3://"):
        raise ValueError(f"Expected S3 URI, got: {s3_uri}")
    path = s3_uri[5:]
    parts = path.split("/", 1)
    bucket = parts[0]
    prefix = parts[1] if len(parts) > 1 else ""
    if prefix and not prefix.endswith("/"):
        prefix += "/"
    return bucket, prefix


def _check_s3_size_limit(s3_path: str, max_total_size_mb: int) -> None:
    """
    Check total shapefile size under an S3 prefix.

    Only objects whose lower-cased suffix is in ``SHP_EXTENSIONS`` are
    counted. The listing uses no delimiter, so keys in nested "subfolders"
    under the prefix are included. The check is best-effort: if boto3 is
    not installed or the listing fails, a warning is printed and the check
    is skipped.

    Args:
        s3_path: S3 prefix containing shapefile components.
        max_total_size_mb: Maximum allowed size (MB of 1024 * 1024 bytes)
            for shapefile components.

    Raises:
        ValueError: If ``s3_path`` is not a valid S3 URI, or the total size
            exceeds the limit.
    """
    try:
        import boto3
    except ImportError:
        print(
            "Warning: boto3 not available. Skipping S3 size validation.",
        )
        return

    # Parsed outside the try below so a malformed URI is reported, not skipped.
    bucket, prefix = _parse_s3_uri(s3_path)
    max_bytes = max_total_size_mb * 1024 * 1024
    total_bytes = 0
    try:
        client = boto3.client("s3")
        paginator = client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                # Take the suffix of the last path segment only, so dots in
                # "directory" names never yield a bogus extension.
                name = obj["Key"].rsplit("/", 1)[-1]
                extension = name[name.rfind("."):].lower() if "." in name else ""
                if extension in SHP_EXTENSIONS:
                    total_bytes += obj["Size"]
    except Exception as exc:  # best-effort: never block submission on listing errors
        # Surface the cause so credential/permission problems are diagnosable.
        print(
            "Warning: unable to list S3 objects for size validation "
            f"({type(exc).__name__}: {exc}). "
            "Set AWS credentials with ListBucket permission to enable size checks.",
        )
        return

    if total_bytes > max_bytes:
        raise ValueError(
            f"S3 shapefile size {total_bytes} bytes exceeds {max_bytes} bytes limit."
        )


# Job statuses (upper-cased) treated as successful completion.
_SUCCESS_STATUSES = frozenset({"SUCCEEDED", "SUCCESS", "COMPLETED", "DONE"})
# Any of these ends the wait loop; non-success ones raise RuntimeError.
_TERMINAL_STATUSES = _SUCCESS_STATUSES | {"FAILED", "CANCELED", "CANCELLED", "ERROR"}
# Server-side timeout passed to submit_job (4 hours).
_JOB_TIMEOUT_SECONDS = 14400


def _build_ingestion_script_args(
    s3_path: str,
    database: str,
    table: str,
    partition_column: Optional[str],
    unique_columns,
    column_renames,
    zorder_columns,
    format_version,
    geohash_precision: int,
) -> List[str]:
    """Build the command-line arguments passed to the ingestion script."""
    args = [
        "--s3-path", s3_path,
        "--database", database,
        "--table", table,
        # Unpacking accepts any iterable (list, tuple, ...) of column names.
        "--unique-columns", *unique_columns,
    ]
    if partition_column:
        args += ["--partition-column", partition_column]
    if column_renames:
        args += ["--column-renames", *column_renames]
    args += ["--format-version", str(format_version)]
    # Exactly one partitioning strategy is always requested: a falsy
    # partition_column (None or "") falls back to geohash partitioning.
    if not partition_column:
        args += ["--geohash-precision", str(geohash_precision)]
    if zorder_columns:
        args += ["--zorder-columns", *zorder_columns]
    return args


def _drain_job_logs(api_key: str, run_id, cursor, page_size: int):
    """
    Print every log page currently available from ``cursor``.

    Pages are fetched until one is empty, has no ``next_page``, or does not
    advance the cursor. Returns the cursor to resume from on the next poll.

    NOTE(review): when ``next_page`` is None the cursor is not advanced, so
    the same page is requested again on the next poll; confirm with the
    Wherobots logs API whether that can reprint lines.
    """
    while True:
        page = get_job_logs(
            api_key=api_key,
            run_id=run_id,
            cursor=cursor,
            size=page_size,
        )
        items = page.get("items") or []
        for item in items:
            message = item.get("message") or item.get("log") or item.get("text") or ""
            if message and str(message).strip():
                print(message)
        next_cursor = page.get("next_page")
        if next_cursor is None:
            return cursor
        if not items or next_cursor == cursor:
            return next_cursor
        cursor = next_cursor


def vector_data_ingestion(
    s3_path: str,
    database: str,
    table: str,
    partition_column: Optional[str],
    unique_columns: List[str],
    region: Optional[str] = None,
    column_renames: Optional[List[str]] = None,
    zorder_columns: Optional[List[str]] = None,
    format_version: str = "3",
    geohash_precision: int = 2,
    wait_for_completion: bool = False,
    poll_interval: int = 20,
    log_page_size: int = 200,
    job_name_prefix: str = "vector-data-ingestion",
) -> dict:
    """
    Submit a vector data ingestion job to Wherobots Cloud.

    The destination is validated against ``ALLOWED_DATABASE`` and
    ``ALLOWED_TABLES``, then a best-effort client-side size check runs on
    the shapefile components under ``s3_path`` before the job is submitted.

    Args:
        s3_path: S3 prefix containing shapefile components.
        database: Destination database name. Must be ``vector_catalog``.
        table: Destination table name within ``vector_catalog``.
        partition_column: Column to partition by. If None or empty, geohash
            partitioning is used (``--geohash-precision`` is passed).
        unique_columns: Columns used as the MERGE key (any iterable of str).
        region: Wherobots region override (defaults to configured region).
        column_renames: Optional column renames in key=value format.
        zorder_columns: Optional columns for Z-order rewrite.
        format_version: Iceberg table format version (passed through ``str``).
        geohash_precision: Precision for geohash partitioning.
        wait_for_completion: If True, stream logs and wait for completion.
        poll_interval: Seconds to sleep between status/log polls.
        log_page_size: Log page size per API call.
        job_name_prefix: Prefix for the Wherobots job name; a local
            ``%Y%m%d-%H%M%S`` timestamp is appended.

    Returns:
        Response dictionary from the Wherobots submit API. It is returned
        right away unless ``wait_for_completion`` is True and it has an ``id``.

    Raises:
        ValueError: If the database or table is not allowed, or the S3 size
            check rejects ``s3_path``.
        RuntimeError: If waiting and the job ends in a non-success status.
    """
    import time

    if database != ALLOWED_DATABASE:
        raise ValueError(
            f"Database '{database}' is not allowed. Only '{ALLOWED_DATABASE}' is permitted."
        )
    if ALLOWED_TABLES and table not in ALLOWED_TABLES:
        raise ValueError(
            f"Table '{table}' is not allowed. "
            f"Allowed tables: {', '.join(sorted(ALLOWED_TABLES))}"
        )

    api_key = DEFAULT_API_KEY
    script_uri = f"{DEFAULT_SCRIPT_BASE_URI.rstrip('/')}/{DEFAULT_VECTOR_DATA_INGESTION_SCRIPT}"
    script_args = _build_ingestion_script_args(
        s3_path=s3_path,
        database=database,
        table=table,
        partition_column=partition_column,
        unique_columns=unique_columns,
        column_renames=column_renames,
        zorder_columns=zorder_columns,
        format_version=format_version,
        geohash_precision=geohash_precision,
    )

    # Fail fast on oversized input before spending a cloud job on it.
    _check_s3_size_limit(s3_path, DEFAULT_MAX_TOTAL_SIZE_MB)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    result = submit_job(
        api_key=api_key,
        region=region or DEFAULT_REGION,
        script_uri=script_uri,
        script_args=script_args,
        runtime=DEFAULT_RUNTIME,
        name=f"{job_name_prefix}-{timestamp}",
        timeout_seconds=_JOB_TIMEOUT_SECONDS,
    )
    if not wait_for_completion:
        return result

    run_id = result.get("id")
    if not run_id:
        return result

    cursor = 0
    last_status: Optional[str] = None
    while True:
        # Read the status *before* draining logs so the iteration that sees
        # a terminal status still prints the job's final log lines.
        status_payload = get_job_status(api_key=api_key, run_id=run_id)
        cursor = _drain_job_logs(api_key, run_id, cursor, log_page_size)
        status = str(status_payload.get("status") or status_payload.get("state") or "").upper()
        if status and status != last_status:
            print(f"[status] {status}")
            last_status = status
        if status in _TERMINAL_STATUSES:
            break
        time.sleep(poll_interval)

    if status not in _SUCCESS_STATUSES:
        raise RuntimeError(f"Wherobots job failed with status: {status}")
    return result