"""Utilities for uploading and downloading objects in cloud storage."""
from __future__ import annotations
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
from urllib.parse import urlparse
from obstore import exceptions as ob_exceptions # type: ignore[import-not-found]
from obstore.store import ( # type: ignore[import-not-found]
AzureStore,
ObjectStore,
S3Store,
)
from .exceptions import StorageError
# A location on the local filesystem, given either as a string or a ``Path``.
LocalPath = Union[str, Path]
# One upload job: a ``(local_path, remote_path)`` pair.
UploadSpec = Tuple[LocalPath, str]
class CloudObjectStore:
    """Generic helper for interacting with AWS S3 or Azure Blob via obstore.

    An instance pairs an already-configured obstore store with the addressing
    details (bucket, or storage account + container) needed to render fully
    qualified URLs. Remote paths are resolved beneath ``base_prefix``.

    Use :meth:`from_aws` or :meth:`from_azure` to build an instance; the
    constructor itself performs no validation.
    """

    def __init__(
        self,
        provider: str,
        store: ObjectStore,
        *,
        bucket: Optional[str] = None,
        account_name: Optional[str] = None,
        container: Optional[str] = None,
        base_prefix: str = "",
    ) -> None:
        """Store the backend handle and addressing details.

        Args:
            provider: Backend identifier; the helpers recognise ``"aws"`` and
                ``"azure"``.
            store: The obstore store used for all remote operations.
            bucket: S3 bucket name (AWS only).
            account_name: Storage account name (Azure only).
            container: Blob container name (Azure only).
            base_prefix: Key prefix placed in front of every remote path.
                Leading and trailing slashes are removed.
        """
        self.provider = provider
        self._store = store
        self._bucket = bucket
        self._account_name = account_name
        self._container = container
        self._base_prefix = base_prefix.strip("/")

    # ------------------------------------------------------------------
    # Factory helpers
    # ------------------------------------------------------------------
    @classmethod
    def from_aws(
        cls,
        *,
        bucket: str,
        region: Optional[str] = None,
        access_key_id: Optional[str] = None,
        secret_access_key: Optional[str] = None,
        session_token: Optional[str] = None,
        base_prefix: str = "",
    ) -> "CloudObjectStore":
        """Create a store backed by an S3 bucket.

        Args:
            bucket: Name of the S3 bucket (required).
            region: Optional AWS region.
            access_key_id: Optional access key id.
            secret_access_key: Optional secret access key.
            session_token: Optional session token.
            base_prefix: Key prefix placed in front of every remote path.

        Returns:
            A ``CloudObjectStore`` whose provider is ``"aws"``.

        Raises:
            StorageError: If ``bucket`` is empty.
        """
        if not bucket:
            raise StorageError("bucket name is required to configure AWS storage")
        candidates = {
            "region": region,
            "access_key_id": access_key_id,
            "secret_access_key": secret_access_key,
            "session_token": session_token,
        }
        # Forward only the options the caller actually supplied; unset or
        # empty values are left for the store to resolve on its own.
        store_kwargs = {name: value for name, value in candidates.items() if value}
        store = S3Store.from_url(f"s3://{bucket}", **store_kwargs)
        return cls(
            provider="aws",
            store=store,
            bucket=bucket,
            base_prefix=base_prefix,
        )

    @classmethod
    def from_azure(
        cls,
        *,
        account_name: str,
        container: str,
        account_key: Optional[str] = None,
        sas_token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        tenant_id: Optional[str] = None,
        base_prefix: str = "",
    ) -> "CloudObjectStore":
        """Create a store backed by an Azure Blob container.

        Args:
            account_name: Storage account name (required).
            container: Blob container name (required).
            account_key: Optional account key.
            sas_token: Optional SAS token.
            client_id: Optional service-principal client id.
            client_secret: Optional service-principal client secret.
            tenant_id: Optional tenant id.
            base_prefix: Key prefix placed in front of every remote path.

        Returns:
            A ``CloudObjectStore`` whose provider is ``"azure"``.

        Raises:
            StorageError: If ``account_name`` or ``container`` is empty.
        """
        if not account_name:
            raise StorageError("account_name is required to configure Azure storage")
        if not container:
            raise StorageError("container is required to configure Azure storage")
        optional_settings = {
            "account_key": account_key,
            "sas_token": sas_token,
            "client_id": client_id,
            "client_secret": client_secret,
            "tenant_id": tenant_id,
        }
        azure_kwargs = {"account_name": account_name}
        # Only credentials that were actually provided are forwarded.
        azure_kwargs.update(
            {name: value for name, value in optional_settings.items() if value}
        )
        store = AzureStore.from_url(f"az://{container}", **azure_kwargs)
        return cls(
            provider="azure",
            store=store,
            account_name=account_name,
            container=container,
            base_prefix=base_prefix,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def upload_file(
        self,
        local_path: LocalPath,
        *,
        remote_path: str,
        overwrite: bool = False,
    ) -> str:
        """Upload a single file to cloud storage at the given remote path.

        Args:
            local_path: File on the local filesystem to upload.
            remote_path: Destination path, resolved beneath the base prefix.
            overwrite: When ``False`` (default) an existing remote object is
                never replaced. The existence check and the write are two
                separate store calls, so this guard is not atomic.

        Returns:
            The fully qualified URL of the uploaded object.

        Raises:
            StorageError: If the local file is missing, the remote object
                already exists and ``overwrite`` is false, or the upload
                itself fails.
        """
        path = Path(local_path)
        if not path.is_file():
            raise StorageError(f"Local file not found: {path}")
        key = self._build_object_key(remote_path)
        if not overwrite and self._object_exists(key):
            raise StorageError(
                f"Remote object already exists at {self._build_remote_url(key)}. "
                "Pass overwrite=True to replace it."
            )
        try:
            with path.open("rb") as stream:
                self._store.put(key, stream)
        except Exception as exc:  # pragma: no cover - obstore surfaces real error
            raise StorageError(
                f"Failed to upload '{path.name}' to cloud storage: {exc}"
            ) from exc
        return self._build_remote_url(key)

    def upload_files(
        self,
        uploads: Sequence[UploadSpec],
        *,
        overwrite: bool = False,
    ) -> List[str]:
        """Upload multiple files. Each entry is a `(local_path, remote_path)` tuple.

        Files are uploaded one after another in the given order. The first
        failure propagates as ``StorageError`` and the remaining entries are
        not attempted.

        Returns:
            The remote URLs, in the same order as ``uploads``.
        """
        return [
            self.upload_file(local_path, remote_path=remote_path, overwrite=overwrite)
            for local_path, remote_path in uploads
        ]

    def download_file(self, remote_reference: str, destination: LocalPath) -> Path:
        """Download a file given either a remote path or full URL.

        Args:
            remote_reference: A path relative to the base prefix, or a full
                ``s3://`` / ``az://`` URL such as the one returned by
                :meth:`upload_file` or :meth:`build_url`.
            destination: Local file to write. Missing parent directories are
                created.

        Returns:
            The destination as a ``Path``.

        Raises:
            StorageError: If the reference is invalid or the object cannot be
                fetched or read.
        """
        if self._looks_like_url(remote_reference):
            key = self._extract_key_from_url(remote_reference)
        else:
            key = self._build_object_key(remote_reference)
        try:
            # Fetch and read the body inside one guarded section, so a failure
            # while reading the payload is reported as StorageError as well.
            payload = bytes(self._store.get(key).bytes())
        except Exception as exc:  # pragma: no cover - obstore surfaces real error
            raise StorageError(f"Failed to download '{remote_reference}': {exc}") from exc
        dest_path = Path(destination)
        dest_path.parent.mkdir(parents=True, exist_ok=True)
        dest_path.write_bytes(payload)
        return dest_path

    def build_url(self, remote_path: str) -> str:
        """Return the fully qualified URL for the given remote path."""
        return self._build_remote_url(self._build_object_key(remote_path))

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_object_key(self, remote_path: str) -> str:
        """Resolve ``remote_path`` beneath the base prefix.

        Leading and trailing slashes are ignored.

        Raises:
            StorageError: If ``remote_path`` is empty or consists only of
                slashes; such a path would otherwise resolve to the bare
                prefix (or an empty key) rather than to an object.
        """
        normalized = remote_path.strip("/") if remote_path else ""
        if not normalized:
            raise StorageError("remote_path cannot be empty")
        if self._base_prefix:
            return f"{self._base_prefix}/{normalized}"
        return normalized

    def _build_remote_url(self, key: str) -> str:
        """Render ``key`` as an ``s3://`` or ``az://`` URL for this store.

        Raises:
            StorageError: If the provider is unknown or its addressing
                details (bucket, or account and container) are missing.
        """
        if self.provider == "aws" and self._bucket:
            return f"s3://{self._bucket}/{key}"
        if self.provider == "azure" and self._account_name and self._container:
            return f"az://{self._account_name}/{self._container}/{key}"
        raise StorageError("Unable to determine remote URL for the configured provider")

    def _object_exists(self, key: str) -> bool:
        """Return whether an object is stored under ``key``.

        Raises:
            StorageError: If the lookup fails for any reason other than the
                object being absent.
        """
        try:
            self._store.head(key)
        except (ob_exceptions.NotFoundError, FileNotFoundError):
            return False
        except Exception as exc:  # pragma: no cover
            raise StorageError(f"Failed to check existence of '{key}': {exc}") from exc
        return True

    def _looks_like_url(self, reference: str) -> bool:
        """Return whether ``reference`` starts with a supported URL scheme."""
        return reference.startswith(("s3://", "s3a://", "s3n://", "az://"))

    def _resolve_url_key(self, url_key: str) -> str:
        """Map a key taken from a URL onto this store's key space.

        URLs produced by this class already contain the base prefix, so a key
        that lies under the prefix is used verbatim; prepending the prefix
        again would point at ``prefix/prefix/...`` and break the round trip
        ``download_file(build_url(path))``. Any other key is scoped beneath
        the base prefix, exactly like a relative remote path.
        """
        normalized = url_key.strip("/")
        prefix = self._base_prefix
        if prefix and normalized.startswith(f"{prefix}/"):
            return normalized
        return self._build_object_key(normalized)

    def _extract_key_from_url(self, remote_url: str) -> str:
        """Validate ``remote_url`` against this store and return its object key.

        Raises:
            StorageError: If the scheme, bucket, account or container does
                not match the configured store, if the URL carries no object
                key, or if the provider is unsupported.
        """
        parsed = urlparse(remote_url)
        if self.provider == "aws":
            if parsed.scheme not in ("s3", "s3a", "s3n"):
                raise StorageError("Expected an s3:// URL for AWS downloads")
            if self._bucket and parsed.netloc != self._bucket:
                raise StorageError(
                    f"URL bucket '{parsed.netloc}' does not match "
                    f"configured bucket '{self._bucket}'"
                )
            key = parsed.path.lstrip("/")
            if not key:
                raise StorageError("S3 URL does not contain an object key")
            return self._resolve_url_key(key)
        if self.provider == "azure":
            if parsed.scheme != "az":
                raise StorageError("Expected an az:// URL for Azure downloads")
            if self._account_name and parsed.netloc != self._account_name:
                raise StorageError(
                    f"URL account '{parsed.netloc}' does not match "
                    f"configured account '{self._account_name}'"
                )
            # Layout is az://<account>/<container>/<key>: the first path
            # segment names the container, the remainder is the object key.
            container, _, key = parsed.path.lstrip("/").partition("/")
            if container != self._container:
                raise StorageError("URL container does not match configured container")
            if not key:
                raise StorageError("URL does not contain an object key")
            return self._resolve_url_key(key)
        raise StorageError(f"Unsupported provider '{self.provider}'")