Source code for kumoapi.source_table

from dataclasses import field
from typing import List, Optional, Union

from pydantic import Field, root_validator, validator
from pydantic.dataclasses import dataclass
from typing_extensions import Literal

from kumoapi.common import StrEnum
from kumoapi.data_source import DataSourceType
from kumoapi.typing import Dtype, Stype

TableName = str

# Source Table ================================================================


class FileType(StrEnum):
    r"""Supported file types for file-based source tables."""
    CSV = "CSV"
    PARQUET = "PARQUET"


class LLMType(StrEnum):
    r"""Supported LLM types."""
    # Use LLM embeddings as features
    FEATURE = "feature"


@dataclass(frozen=True)
class S3SourceTable:
    r"""A source table located on the Amazon S3 object store."""
    # We support two types of table file path:
    # 1. s3_path specifies the whole directory (prefix), ending with "/"
    # 2. s3_path specifies the full path of a single file, ending with file
    #    name suffix that must be one of ".csv" or ".parquet"
    s3_path: str

    # Internal: S3 connector ID, if we are working with a Kumo-owned named S3
    # connector:
    connector_id: Optional[str] = None
    source_table_name: Optional[TableName] = None

    # If not provided, then the `s3_path` must end in either `.csv` or
    # `.parquet`, and we will parse the file type from there. Please use the
    # `validated_file_type` property to access the parsed & validated file
    # type.
    file_type: Optional[FileType] = None

    data_source_type: Literal[DataSourceType.S3] = DataSourceType.S3

    @property
    def table(self) -> TableName:
        if self.s3_path == "":
            assert self.source_table_name is not None
            return self.source_table_name
        if self.s3_path.endswith('/'):
            return TableName(
                self.s3_path.rstrip('/').rsplit('/', maxsplit=1)[1])
        filename = self.s3_path.rsplit('/', maxsplit=1)[1]
        return TableName(filename.rsplit('.', maxsplit=1)[0])  # strip suffix
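    # Illustrative examples of the parsing above (paths are hypothetical):
    #   S3SourceTable(s3_path="s3://bucket/data/users.parquet").table
    #       -> "users"   (single file: the ".parquet" suffix is stripped)
    #   S3SourceTable(s3_path="s3://bucket/data/users/").table
    #       -> "users"   (directory/prefix: the trailing "/" is stripped)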


@dataclass(frozen=True)
class SnowflakeSourceTable:
    r"""A source table located in the Snowflake data warehouse."""
    snowflake_connector_id: str
    database: str
    schema_name: str
    table: TableName
    data_source_type: Literal[
        DataSourceType.SNOWFLAKE] = DataSourceType.SNOWFLAKE


@dataclass(frozen=True)
class DatabricksSourceTable:
    r"""A source table located in the Databricks data warehouse."""
    databricks_connector_id: str
    table: TableName
    data_source_type: Literal[
        DataSourceType.DATABRICKS] = DataSourceType.DATABRICKS


@dataclass(frozen=True)
class BigQuerySourceTable:
    r"""A source table loated in the BigQuery data warehouse."""
    bigquery_connector_id: str
    table_name: TableName
    project_id: str
    dataset_id: str
    data_source_type: Literal[
        DataSourceType.BIGQUERY] = DataSourceType.BIGQUERY


SourceTableType = Union[S3SourceTable, SnowflakeSourceTable,
                        DatabricksSourceTable, BigQuerySourceTable]
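# A minimal construction sketch for the union above (connector IDs and table
# names are illustrative); the `data_source_type` literal on each dataclass
# marks which variant a given payload represents:
#
#   S3SourceTable(s3_path="s3://bucket/events/")
#   SnowflakeSourceTable(snowflake_connector_id="sf_conn", database="PROD",
#                        schema_name="PUBLIC", table="EVENTS")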

# Method: Configuration =======================================================


@dataclass
class SourceTableConfigRequest:
    connector_id: Optional[str]
    table_name: str
    source_type: DataSourceType

    root_dir: Optional[str] = None
    file_type: Optional[FileType] = None


@dataclass
class SourceTableConfigResponse:
    source_table: SourceTableType = Field(discriminator='data_source_type')
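    # The discriminator lets pydantic pick the concrete SourceTableType member
    # from the `data_source_type` value when deserializing. A minimal sketch
    # (field values are illustrative):
    #   SourceTableConfigResponse(source_table={
    #       'data_source_type': DataSourceType.S3,
    #       's3_path': 's3://bucket/events/',
    #   }).source_table  # -> parsed as S3SourceTable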


# Method: List ================================================================


@dataclass
class SourceTableListRequest:
    # TODO(manan): enforce one-of connector ID or root_dir
    connector_id: Optional[str]
    source_type: DataSourceType

    # Only for object store-based connectors:
    root_dir: Optional[str] = None

    @root_validator()
    def _validate_connector_id(cls, values):
        if values['connector_id'] is None and values[
                'source_type'] != DataSourceType.S3:
            raise ValueError(
                "A 'None' connector ID is only supported for S3-backed "
                "tables. Please specify a connector ID to proceed.")
        return values
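    # For example (values are illustrative): omitting the connector ID is only
    # valid for S3-backed listings, which can rely on `root_dir` instead:
    #   SourceTableListRequest(connector_id=None,
    #                          source_type=DataSourceType.SNOWFLAKE)  # raises
    #   SourceTableListRequest(connector_id=None,
    #                          source_type=DataSourceType.S3,
    #                          root_dir='s3://bucket/data/')  # accepted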


@dataclass
class SourceTableListResponse:
    table_names: List[str]


# Method: Get Data ============================================================


@dataclass
class SourceColumn:
    r"""The metadata of a column in a source table. Note that a source column
    simply provides a view into the metadata of a source table. To modify
    metadata, please create a Kumo Table and adjust the table's data and
    semantic types.

    .. note::
        Semantic types are inferred based on data types only, and thus may
        not be accurate.

    Args:
        name (str): The name of the column.
        stype (Stype, optional): The semantic type of the column.
        dtype (Dtype): The data type of the column.
        is_primary (bool): Whether the column refers to a primary key.
    """
    name: str
    stype: Optional[Stype]  # Kumo-inferred.
    dtype: Dtype
    is_primary: bool


@dataclass
class S3SourceTableRequest:
    r"""A request to fetch a source table located on Amazon S3. This table
    can be located at either

        root_dir/table_name/*.(csv|parquet)
        root_dir/table_name.(csv|parquet)
    """
    s3_root_dir: str  # TODO(manan): rename to `root_dir`
    connector_id: Optional[str] = None
    table_names: Optional[List[str]] = None
    file_type: Optional[FileType] = None
    source_type: Literal[DataSourceType.S3] = DataSourceType.S3


@dataclass
class SnowflakeSourceTableRequest:
    connector_id: str
    table_names: Optional[List[str]] = None
    source_type: Literal[
        DataSourceType.SNOWFLAKE] = DataSourceType.SNOWFLAKE
    # TODO(siyang): We should move database and schema out of SF connector.
    # database: Optional[str] = None
    # schema: Optional[str] = None


@dataclass
class BigQuerySourceTableRequest:
    connector_id: str
    table_names: Optional[List[str]] = None
    # Discriminator:
    source_type: Literal[DataSourceType.BIGQUERY] = DataSourceType.BIGQUERY


@dataclass
class DatabricksSourceTableRequest:
    connector_id: str
    table_names: Optional[List[str]] = None
    # Discriminator:
    source_type: Literal[
        DataSourceType.DATABRICKS] = DataSourceType.DATABRICKS


@dataclass
class SourceTableDataRequest:
    # Table request (metadata needed to fetch a table from the connector):
    source_table_request: Union[
        S3SourceTableRequest,
        BigQuerySourceTableRequest,
        DatabricksSourceTableRequest,
        SnowflakeSourceTableRequest,
    ] = Field(discriminator='source_type')

    # Whether to fetch and include sample rows in the response:
    sample_rows: int = 0

    @validator('sample_rows')
    def _validate_sample_rows(cls, v: int):
        if v > 1000:
            raise ValueError('sample_rows cannot be greater than 1000.')
        if v < 0:
            raise ValueError('sample_rows cannot be negative.')
        return v


@dataclass
class SourceTableDataResponse:
    table_name: TableName
    cols: List[SourceColumn] = field(default_factory=list)
    # Serialized (json) data of sample rows dataframe, if requested:
    # TODO(siyang,manan): figure out the ser/de protocol for pandas dataframe
    sample_rows: Optional[str] = None


# Other =======================================================================


@dataclass
class TableStats:
    r"""Minimal statistics of a :class:`SourceTable`.

    Args:
        size_bytes (int): The size of the table in bytes.
        num_rows (int): The number of rows in the table.
    """
    size_bytes: int
    num_rows: int
    # TODO(siyang): add a flag to indicate if stats are exact or approx?


@dataclass
class LLMRequest:
    source_table_type: SourceTableType
    template: str
    model: str
    model_api_key: str
    output_dir: str
    output_column_name: str
    output_table_name: str
    dimensions: Optional[int] = None
    llm_type: LLMType = LLMType.FEATURE


@dataclass
class LLMResponse:
    job_id: str
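# A minimal usage sketch for fetching table data (bucket and table names are
# illustrative); `sample_rows` outside [0, 1000] is rejected by the validator
# in SourceTableDataRequest:
#
#   request = SourceTableDataRequest(
#       source_table_request=S3SourceTableRequest(
#           s3_root_dir='s3://bucket/data/',
#           table_names=['users'],
#       ),
#       sample_rows=5,
#   )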