# Source code for kumoai.connector.s3_connector
import logging
from typing import List, Optional
from kumoapi.data_source import DataSourceType
from kumoapi.source_table import (
S3SourceTableRequest,
SourceTableConfigRequest,
SourceTableConfigResponse,
SourceTableListRequest,
)
from typing_extensions import Self, override
from kumoai import global_state
from kumoai.connector import Connector
from kumoai.connector.source_table import SourceTable
from kumoai.exceptions import HTTPException
logger = logging.getLogger(__name__)
class S3Connector(Connector):
    r"""Defines a connector to a table stored as a file (or partitioned
    set of files) on the Amazon `S3 <https://aws.amazon.com/s3/>`__ object
    store. Any table behind an S3 bucket accessible by the shared external IAM
    role can be accessed through this connector.

    .. code-block:: python

        import kumoai
        connector = kumoai.S3Connector(root_dir="s3://...")  # an S3 path.

        # List all tables:
        print(connector.table_names())  # Returns: ['articles', 'customers', 'users']

        # Check whether a table is present:
        assert "articles" in connector

        # Fetch a source table (both approaches are equivalent):
        source_table = connector["articles"]
        source_table = connector.table("articles")

    Args:
        root_dir: The root directory of this connector. If provided, the root
            directory is used as a prefix for tables in this connector. If not
            provided, all tables must be specified by their full S3 paths.

    Raises:
        :class:`ValueError`: if ``root_dir`` starts with ``s3://`` while
            ``global_state.is_spcs`` is set.
    """  # noqa

    def __init__(self, root_dir: Optional[str] = None) -> None:
        if root_dir is not None:
            # Remove trailing / to be consistent with boto s3
            root_dir = root_dir.rstrip('/')
        self.root_dir = root_dir
        if global_state.is_spcs and root_dir is not None \
                and root_dir.startswith('s3://'):
            raise ValueError(
                "S3 connectors are not supported when running Kumo in "
                "Snowpark container services. Please use a Snowflake "
                "connector instead.")

    @override
    @property
    def name(self) -> str:
        r"""Not supported by :class:`S3Connector`; returns an internal
        specifier.
        """
        return "s3_connector"

    @override
    @property
    def source_type(self) -> DataSourceType:
        r"""The data source type of this connector; always
        :obj:`DataSourceType.S3`.
        """
        return DataSourceType.S3

    @override
    def _source_table_request(
        self,
        table_names: List[str],
    ) -> S3SourceTableRequest:
        r"""Builds the :class:`S3SourceTableRequest` for ``table_names``.

        If the connector has a root directory, the names are forwarded as
        given. Otherwise every name is parsed as a full S3 path; all paths
        must share the root directory of the first one, and only their
        object names are forwarded.

        Args:
            table_names: Table names (or full S3 paths when the connector has
                no root directory). The list is not modified.

        Raises:
            :class:`ValueError`: if no root directory is set and
                ``table_names`` is empty, a path is not a valid S3 URI, or
                the paths do not share a single root directory.
        """
        root_dir = self.root_dir
        # Work on a copy so that the caller's list is never rewritten (not
        # even partially, should validation fail half-way through).
        resolved_names = list(table_names)
        if not root_dir:
            if not resolved_names:
                raise ValueError(
                    "At least one table path is required to infer the root "
                    "directory of a connector without a root directory.")
            # Handle None root directories (table name is a path):
            root_dir = S3URI(resolved_names[0]).validate().root_dir
            object_names = []
            for path in resolved_names:
                uri = S3URI(path)
                if uri.root_dir != root_dir:
                    # TODO(manan): fix
                    raise ValueError(
                        f"Please ensure that all of your tables are behind "
                        f"the same root directory ({root_dir}).")
                object_names.append(uri.object_name)
            resolved_names = object_names
        # TODO(manan): file type?
        return S3SourceTableRequest(s3_root_dir=root_dir,
                                    table_names=resolved_names)

    def has_table(self, name: str) -> bool:
        r"""Returns :obj:`True` if the table exists in this connector,
        :obj:`False` otherwise.

        Args:
            name: The name of the table on S3. If :obj:`root_dir` is provided,
                the path will be specified as :obj:`root_dir/name`. If
                :obj:`root_dir` is not provided, the name should be the full
                path (e.g. starting with ``s3://``).

        Raises:
            :class:`ValueError`: if no root directory is set and ``name`` is
                not a valid S3 URI.
        """
        # TODO(manan): this is silly. Just write a quick endpoint to check the
        # validity of an individual path:
        root_dir = self.root_dir
        table_name = None
        if not root_dir:
            # Remove trailing / to be consistent with boto s3
            name = name.rstrip('/')
            # Handle None root directories (table name is a path):
            table_path = S3URI(name).validate()
            root_dir = table_path.root_dir
            table_name = table_path.object_name
        try:
            table_names = global_state.client.source_table_api.list_tables(
                SourceTableListRequest(
                    connector_id=None, root_dir=root_dir,
                    source_type=DataSourceType.S3)).table_names
        except HTTPException as e:
            # A failed listing is reported as "table not present" rather than
            # propagated to the caller.
            logger.warning(
                "Could not fetch tables from connector %s, due to exception "
                "%s.", self, e)
            return False
        # Match either the name as given or, for full paths, its object name.
        return (name in table_names) or (table_name in table_names)

    @override
    def table(self, name: str) -> SourceTable:
        r"""Returns a :class:`~kumoai.connector.SourceTable` object
        corresponding to a source table on Amazon S3.

        Args:
            name: The name of the table on S3. If :obj:`root_dir` is provided,
                the path will be specified as :obj:`root_dir/name`. If
                :obj:`root_dir` is not provided, the name should be the full
                path (e.g. starting with ``s3://``).

        Raises:
            :class:`ValueError`: if ``name`` does not exist in the backing
                connector.
        """
        # NOTE only overridden for documentation purposes.
        return super().table(name)

    @override
    def _list_tables(self) -> List[str]:
        r"""Lists the tables found under the connector's root directory.

        Raises:
            :class:`ValueError`: if the connector has no root directory.
        """
        # Treat an empty root directory like a missing one, matching the
        # falsy checks used by the other methods and by ``__repr__``.
        if not self.root_dir:
            raise ValueError(
                "Listing tables without a specified root directory is not "
                "supported. Please specify a root directory to continue; "
                "alternatively, please access individual tables with their "
                "full S3 paths.")
        req = SourceTableListRequest(connector_id=None, root_dir=self.root_dir,
                                     source_type=DataSourceType.S3)
        # NOTE(review): `has_table` reads `.table_names` off the result of
        # this same API call, whereas the result is returned directly here.
        # Confirm which of the two matches the `list_tables` return type.
        return global_state.client.source_table_api.list_tables(req)

    @override
    def _get_table_config(self, table_name: str) -> SourceTableConfigResponse:
        r"""Fetches the configuration of ``table_name`` from the source table
        API.

        Args:
            table_name: The table name, or the full S3 path when the
                connector has no root directory.

        Raises:
            :class:`ValueError`: if no root directory is set and
                ``table_name`` is not a valid S3 URI.
        """
        root_dir = self.root_dir
        if not root_dir:
            # Handle None root directories (table name is a path):
            table_path = S3URI(table_name).validate()
            root_dir = table_path.root_dir
            table_name = table_path.object_name
        req = SourceTableConfigRequest(connector_id=None, root_dir=root_dir,
                                       table_name=table_name,
                                       source_type=self.source_type)
        return global_state.client.source_table_api.get_table_config(req)

    # Class properties ########################################################

    def __repr__(self) -> str:
        root_dir_name = f"\"{self.root_dir}\"" if self.root_dir else "None"
        return f'{self.__class__.__name__}(root_dir={root_dir_name})'
class S3URI:
r"""A utility class to parse and navigate S3 URIs."""
def __init__(self, uri: str):
self.uri: str = uri
if uri.endswith('/'): # remove trailing slash
self.uri = uri[:-1]
@property
def is_valid(self) -> bool:
# TODO(zeyuan): For SPCS, the path can be a local filesystem path
# For train/pred table.
if global_state.is_spcs:
return True
# TODO(manan): implement more checks...
return self.uri.startswith("s3://")
def validate(self) -> Self:
if not self.is_valid:
raise ValueError(f"Path {self.uri} is not a valid S3 URI.")
return self
@property
def root_dir(self) -> str:
self.validate()
return self.uri.rsplit('/', 1)[0]
@property
def object_name(self) -> str:
self.validate()
return self.uri.rsplit('/', 1)[1]
# Class properties ########################################################
def __repr__(self) -> str:
return (f'{self.__class__.__name__}('
f'uri={self.uri}, valid={self.is_valid})')