Source code for kumoai.connector.s3_connector

import logging
from typing import List, Optional

from kumoapi.data_source import DataSourceType
from kumoapi.source_table import (
    S3SourceTableRequest,
    SourceTableConfigRequest,
    SourceTableConfigResponse,
    SourceTableListRequest,
)
from typing_extensions import Self, override

from kumoai import global_state
from kumoai.connector import Connector
from kumoai.connector.source_table import SourceTable
from kumoai.exceptions import HTTPException

# Module-level logger; `S3Connector.has_table` uses it to warn when the table
# listing request fails with an HTTPException.
logger = logging.getLogger(__name__)


class S3Connector(Connector):
    r"""Defines a connector to a table stored as a file (or partitioned
    set of files) on the Amazon `S3 <https://aws.amazon.com/s3/>`__ object
    store. Any table behind an S3 bucket accessible by the shared external
    IAM role can be accessed through this connector.

    .. code-block:: python

        import kumoai
        connector = kumoai.S3Connector(root_dir="s3://...")  # an S3 path.

        # List all tables:
        print(connector.table_names())  # Returns: ['articles', 'customers', 'users']

        # Check whether a table is present:
        assert "articles" in connector

        # Fetch a source table (both approaches are equivalent):
        source_table = connector["articles"]
        source_table = connector.table("articles")

    Args:
        root_dir: The root directory of this connector. If provided, the root
            directory is used as a prefix for tables in this connector. If
            not provided, all tables must be specified by their full S3
            paths.
    """  # noqa

    def __init__(self, root_dir: Optional[str] = None) -> None:
        if root_dir is not None:
            # Remove trailing / to be consistent with boto s3
            root_dir = root_dir.rstrip('/')
        self.root_dir: Optional[str] = root_dir

        # S3 paths are rejected when running inside Snowpark container
        # services; users must go through a Snowflake connector instead.
        if global_state.is_spcs and root_dir is not None \
                and root_dir.startswith('s3://'):
            raise ValueError(
                "S3 connectors are not supported when running Kumo in "
                "Snowpark container services. Please use a Snowflake "
                "connector instead.")

    @override
    @property
    def name(self) -> str:
        r"""Not supported by :class:`S3Connector`; returns an internal
        specifier.
        """
        return "s3_connector"

    @override
    @property
    def source_type(self) -> DataSourceType:
        r"""The data source type of this connector
        (:obj:`DataSourceType.S3`).
        """
        return DataSourceType.S3

    @override
    def _source_table_request(
        self,
        table_names: List[str],
    ) -> S3SourceTableRequest:
        r"""Builds the request used to fetch ``table_names`` from S3.

        If this connector has a root directory, ``table_names`` are sent
        as-is, relative to it. Otherwise every entry must be a full S3 path;
        all paths must share the same parent directory, which becomes the
        request's root directory, and only the final path component of each
        entry is sent as its table name.

        The caller's ``table_names`` list is never modified.

        Args:
            table_names: The table names (or full S3 paths, if this connector
                has no root directory) to request.

        Raises:
            ValueError: If this connector has no root directory and
                ``table_names`` is empty, contains a path that is not a valid
                S3 URI, or contains paths with different parent directories.
        """
        root_dir = self.root_dir
        # Work on a copy so the caller's list is left untouched.
        request_names = list(table_names)
        if not root_dir:
            # Handle None root directories (table name is a path):
            if not request_names:
                raise ValueError(
                    "At least one full S3 table path is required when the "
                    "connector has no root directory.")
            root_dir = S3URI(request_names[0]).validate().root_dir
            object_names = []
            for path in request_names:
                uri = S3URI(path)
                if uri.root_dir != root_dir:
                    # TODO(manan): fix
                    raise ValueError(
                        f"Please ensure that all of your tables are behind "
                        f"the same root directory ({root_dir}).")
                object_names.append(uri.object_name)
            request_names = object_names

        # TODO(manan): file type?
        return S3SourceTableRequest(s3_root_dir=root_dir,
                                    table_names=request_names)

    def has_table(self, name: str) -> bool:
        r"""Returns :obj:`True` if the table exists in this connector,
        :obj:`False` otherwise.

        Args:
            name: The name of the table on S3. If :obj:`root_dir` is provided,
                the path will be specified as :obj:`root_dir/name`. If
                :obj:`root_dir` is not provided, the name should be the full
                path (e.g. starting with ``s3://``).

        Raises:
            ValueError: If :obj:`root_dir` is not provided and ``name`` is
                not a valid S3 URI.
        """
        # TODO(manan): this is silly. Just write a quick endpoint to check the
        # validity of an individual path:
        root_dir = self.root_dir
        table_name = None
        if not root_dir:
            # Remove trailing / to be consistent with boto s3
            name = name.rstrip('/')
            # Handle None root directories (table name is a path):
            table_path = S3URI(name).validate()
            root_dir = table_path.root_dir
            table_name = table_path.object_name

        try:
            response = global_state.client.source_table_api.list_tables(
                SourceTableListRequest(connector_id=None, root_dir=root_dir,
                                       source_type=DataSourceType.S3))
        except HTTPException as e:
            # Best-effort check: a failed listing is reported as "absent".
            logger.warning(
                "Could not fetch tables from connector %s, due to exception "
                "%s.", self, e)
            return False

        table_names = response.table_names
        # Match either the name as given or, for full paths, its object name.
        return (name in table_names) or (table_name in table_names)

    @override
    def table(self, name: str) -> SourceTable:
        r"""Returns a :class:`~kumoai.connector.SourceTable` object
        corresponding to a source table on Amazon S3.

        Args:
            name: The name of the table on S3. If :obj:`root_dir` is provided,
                the path will be specified as :obj:`root_dir/name`. If
                :obj:`root_dir` is not provided, the name should be the full
                path (e.g. starting with ``s3://``).

        Raises:
            :class:`ValueError`: if ``name`` does not exist in the backing
                connector.
        """
        # NOTE only overridden for documentation purposes.
        return super().table(name)

    @override
    def _list_tables(self) -> List[str]:
        r"""Lists the tables under :obj:`root_dir`.

        Raises:
            ValueError: If this connector has no root directory.
        """
        # Falsy (not just None) for consistency with the other methods,
        # which all treat an empty root directory as "no root directory".
        if not self.root_dir:
            raise ValueError(
                "Listing tables without a specified root directory is not "
                "supported. Please specify a root directory to continue; "
                "alternatively, please access individual tables with their "
                "full S3 paths.")
        req = SourceTableListRequest(connector_id=None,
                                     root_dir=self.root_dir,
                                     source_type=DataSourceType.S3)
        # NOTE(review): `has_table` reads `.table_names` off the response of
        # this same API call, while this method returns the response as-is
        # despite the `List[str]` annotation; confirm the return type of
        # `source_table_api.list_tables`.
        return global_state.client.source_table_api.list_tables(req)

    @override
    def _get_table_config(self, table_name: str) -> SourceTableConfigResponse:
        r"""Fetches the source table configuration for ``table_name``.

        If this connector has no root directory, ``table_name`` must be a
        full S3 path; it is split into its parent directory and object name.

        Raises:
            ValueError: If this connector has no root directory and
                ``table_name`` is not a valid S3 URI.
        """
        root_dir = self.root_dir
        if not root_dir:
            # Handle None root directories (table name is a path):
            table_path = S3URI(table_name).validate()
            root_dir = table_path.root_dir
            table_name = table_path.object_name

        req = SourceTableConfigRequest(connector_id=None,
                                       root_dir=root_dir,
                                       table_name=table_name,
                                       source_type=self.source_type)
        return global_state.client.source_table_api.get_table_config(req)

    # Class properties ########################################################

    def __repr__(self) -> str:
        root_dir_name = f"\"{self.root_dir}\"" if self.root_dir else "None"
        return f'{self.__class__.__name__}(root_dir={root_dir_name})'
class S3URI:
    r"""A utility class to parse and navigate S3 URIs.

    Args:
        uri: An S3 URI such as ``s3://bucket/dir/table``. All trailing
            slashes are removed, so ``s3://bucket/dir/table/`` refers to the
            same object.
    """
    def __init__(self, uri: str):
        # Strip *all* trailing slashes (as S3Connector does for its root
        # directory); stripping only one would turn e.g.
        # ``s3://bucket/dir//`` into an empty object name.
        self.uri: str = uri.rstrip('/')

    @property
    def is_valid(self) -> bool:
        r"""Whether :obj:`uri` is accepted as a path. Outside of Snowpark
        container services this only checks for the ``s3://`` prefix; inside
        them every path is accepted.
        """
        # TODO(zeyuan): For SPCS, the path can be a local filesystem path
        # For train/pred table.
        if global_state.is_spcs:
            return True
        # TODO(manan): implement more checks...
        return self.uri.startswith("s3://")

    def validate(self) -> Self:
        r"""Returns ``self`` if :obj:`is_valid`, so calls can be chained.

        Raises:
            ValueError: If the URI is not valid.
        """
        if not self.is_valid:
            raise ValueError(f"Path {self.uri} is not a valid S3 URI.")
        return self

    @property
    def root_dir(self) -> str:
        r"""Everything before the last ``/`` of the URI (its parent
        directory).

        Raises:
            ValueError: If the URI is not valid.
        """
        self.validate()
        return self.uri.rsplit('/', 1)[0]

    @property
    def object_name(self) -> str:
        r"""The last ``/``-separated component of the URI.

        Raises:
            ValueError: If the URI is not valid.
        """
        self.validate()
        # NOTE(review): a path containing no '/' (only possible when
        # `is_valid` accepts everything, i.e. under SPCS) makes `rsplit`
        # return a single element, so this raises IndexError.
        return self.uri.rsplit('/', 1)[1]

    # Class properties ########################################################

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'uri={self.uri}, valid={self.is_valid})')