Source code for kumoai.connector.source_table

import asyncio
import concurrent.futures
import logging
from io import StringIO
from typing import (
    TYPE_CHECKING,
    Dict,
    List,
    Literal,
    Optional,
    Union,
    overload,
)

import pandas as pd
from kumoapi.jobs import JobStatus
from kumoapi.source_table import (
    DataSourceType,
    LLMRequest,
    SourceColumn,
    SourceTableDataResponse,
    SourceTableType,
)
from kumoapi.table import TableDefinition
from typing_extensions import override

from kumoai import global_state
from kumoai.client.jobs import LLMJobId
from kumoai.exceptions import HTTPException
from kumoai.futures import KumoFuture, create_future

if TYPE_CHECKING:
    from kumoai.connector import Connector

logger = logging.getLogger(__name__)

_DEFAULT_INTERVAL_S = 20.0


class SourceTable:
    r"""A source table is a reference to a table stored behind a backing
    :class:`~kumoai.connector.Connector`. It can be used to examine basic
    information about raw data connected to Kumo, including a sample of the
    table's rows, basic statistics, and column data type information.

    Once you are ready to use a table as part of a
    :class:`~kumoai.graph.Graph`, you may create a
    :class:`~kumoai.graph.Table` object from this source table, which
    includes additional specifying information (including column semantic
    types and column constraint information).

    Args:
        name: The name of this table in the backing connector.
        connector: The connector containing this table.

    .. note::
        Source tables can also be augmented with large language models to
        introduce contextual embeddings for language features. To do so,
        please consult :meth:`~kumoai.connector.SourceTable.add_llm`.

    Example:
        >>> import kumoai
        >>> connector = kumoai.S3Connector(root_dir='s3://...')  # doctest: +SKIP # noqa: E501
        >>> articles_src = connector['articles']  # doctest: +SKIP
        >>> articles_src = kumoai.SourceTable('articles', connector)  # doctest: +SKIP # noqa: E501
    """

    def __init__(self, name: str, connector: 'Connector') -> None:
        # TODO(manan): existence check, if not too expensive?
        self.name = name
        self.connector = connector

    # Metadata ################################################################

    @property
    def column_dict(self) -> Dict[str, SourceColumn]:
        r"""Returns the names of the columns in this table along with their
        :class:`SourceColumn` information.
        """
        return {col.name: col for col in self.columns}

    @property
    def columns(self) -> List[SourceColumn]:
        r"""Returns a list of the :class:`SourceColumn` metadata of the
        columns in this table.
        """
        resp: SourceTableDataResponse = self.connector._get_table_data(
            table_names=[self.name], sample_rows=0)[0]
        return resp.cols
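
    # A minimal usage sketch of the metadata accessors above (the connector
    # and table names are hypothetical): `columns` fetches column metadata
    # from the backing connector, and `column_dict` re-keys the same metadata
    # by column name.
    #
    #   >>> articles_src = connector['articles']        # doctest: +SKIP
    #   >>> [col.name for col in articles_src.columns]  # doctest: +SKIP
    #   >>> articles_src.column_dict['article_id']      # doctest: +SKIP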

    # Data Access #############################################################

    def head(self, num_rows: int = 5) -> pd.DataFrame:
        r"""Returns the first :obj:`num_rows` rows of this source table by
        reading data from the backing connector.

        Args:
            num_rows: The number of rows to select. If :obj:`num_rows` is
                larger than the number of available rows, all rows will be
                returned.

        Returns:
            The first :obj:`num_rows` rows of the source table as a
            :class:`~pandas.DataFrame`.
        """
        num_rows = int(num_rows)
        if num_rows <= 0:
            raise ValueError(
                f"'num_rows' must be an integer greater than 0; got "
                f"{num_rows}")
        try:
            resp = self.connector._get_table_data([self.name], num_rows)[0]
            # TODO(manan, siyang): consider returning `bytes` instead of
            # `json`
            assert resp.sample_rows is not None
            return pd.read_json(StringIO(resp.sample_rows), orient='table')
        except TypeError as e:
            raise RuntimeError(
                f"Cannot read head of table {self.name}. Please restart "
                f"the kernel and try again.") from e
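
    # A minimal sketch of `head` (the table name is hypothetical). If the
    # table has fewer than `num_rows` rows, all available rows are returned;
    # `num_rows <= 0` raises a `ValueError`:
    #
    #   >>> df = connector['articles'].head(num_rows=3)  # doctest: +SKIP
    #   >>> len(df) <= 3                                 # doctest: +SKIP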

    # Language Models #########################################################

    @overload
    def add_llm(
        self,
        model: str,
        api_key: str,
        template: str,
        output_dir: str,
        output_column_name: str,
        output_table_name: str,
        dimensions: Optional[int],
    ) -> 'SourceTable':
        pass

    @overload
    def add_llm(
        self,
        model: str,
        api_key: str,
        template: str,
        output_dir: str,
        output_column_name: str,
        output_table_name: str,
        dimensions: Optional[int],
        *,
        non_blocking: Literal[False],
    ) -> 'SourceTable':
        pass

    @overload
    def add_llm(
        self,
        model: str,
        api_key: str,
        template: str,
        output_dir: str,
        output_column_name: str,
        output_table_name: str,
        dimensions: Optional[int],
        *,
        non_blocking: Literal[True],
    ) -> 'LLMSourceTableFuture':
        pass

    @overload
    def add_llm(
        self,
        model: str,
        api_key: str,
        template: str,
        output_dir: str,
        output_column_name: str,
        output_table_name: str,
        dimensions: Optional[int] = None,
        *,
        non_blocking: bool,
    ) -> Union['SourceTable', 'LLMSourceTableFuture']:
        pass

    def add_llm(
        self,
        model: str,
        api_key: str,
        template: str,
        output_dir: str,
        output_column_name: str,
        output_table_name: str,
        dimensions: Optional[int] = None,
        *,
        non_blocking: bool = False,
    ) -> Union['SourceTable', 'LLMSourceTableFuture']:
        r"""Experimental method that returns a new source table including a
        column computed via an LLM, such as OpenAI's embedding models.
        Please refer to the example script for more details.

        .. note::
            LLM embedding currently only works for a :obj:`SourceTable`
            backed by S3.

        .. note::
            Your :obj:`api_key` is encrypted as soon as it is received, and
            is only decrypted just before calling the OpenAI text embedding
            service.

        .. note::
            Please keep track of the token usage in the `OpenAI Dashboard
            <https://platform.openai.com/usage/activity>`_. If the number of
            tokens in the data exceeds the limit, the backend will raise an
            error and no result will be produced.

        .. warning::
            This method only supports text embedding for data with fewer
            than ~6 million tokens. The number of tokens is estimated
            following `this guide
            <https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them>`_.

        .. warning::
            This method is still experimental. Please consult with your Kumo
            POC before using it.

        Args:
            model: The LLM model name, *e.g.*, OpenAI's
                :obj:`"text-embedding-3-small"`.
            api_key: The API key used to call the LLM service.
            template: A template string passed to the LLM. For example,
                :obj:`"{A1} and {A2}"` will fuse columns :obj:`A1` and
                :obj:`A2` into a single string.
            output_dir: The S3 directory to store the output.
            output_column_name: The output column name for the LLM.
            output_table_name: The output table name.
            dimensions: The desired LLM embedding dimension.
            non_blocking: Whether to make this call non-blocking.

        Example:
            >>> import kumoai
            >>> connector = kumoai.S3Connector(root_dir='s3://...')  # doctest: +SKIP # noqa: E501
            >>> articles_src = connector['articles']  # doctest: +SKIP
            >>> articles_src_future = connector["articles"].add_llm(  # doctest: +SKIP # noqa: E501
            ...     model="text-embedding-3-small",
            ...     api_key=YOUR_OPENAI_API_KEY,
            ...     template=("The product {prod_name} in the {section_name} "
            ...               "section is categorized as {product_type_name} "
            ...               "and has following description: {detail_desc}"),
            ...     output_dir=YOUR_OUTPUT_DIR,
            ...     output_column_name="embedding_column",
            ...     output_table_name="articles_emb",
            ...     dimensions=256,
            ...     non_blocking=True,
            ... )
            >>> articles_src_future.status()  # doctest: +SKIP
            >>> articles_src_future.cancel()  # doctest: +SKIP
            >>> articles_src = articles_src_future.result()  # doctest: +SKIP
        """  # noqa
        if global_state.is_spcs:
            raise NotImplementedError("add_llm is not available on Snowflake")

        source_table_type = self._to_api_source_table()
        req = LLMRequest(
            source_table_type=source_table_type,
            template=template,
            model=model,
            model_api_key=api_key,
            output_dir=output_dir,
            output_column_name=output_column_name,
            output_table_name=output_table_name,
            dimensions=dimensions,
        )
        api = global_state.client.llm_job_api
        resp: LLMJobId = api.create(req)
        logger.info(f"LLMJobId: {resp}")
        source_table_future = LLMSourceTableFuture(resp, output_table_name,
                                                   output_dir)
        if non_blocking:
            return source_table_future

        # TODO (zecheng): Add attach for text embedding
        return source_table_future.result()
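
    # The `template` argument above uses `{column_name}` placeholders that
    # are filled per row before the text is embedded. A local sketch of the
    # same substitution (the column names and values are hypothetical):
    #
    #   >>> template = "The product {prod_name}: {detail_desc}"
    #   >>> row = {'prod_name': 'Mug', 'detail_desc': 'A ceramic mug.'}
    #   >>> template.format(**row)
    #   'The product Mug: A ceramic mug.'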

    # Persistence #############################################################

    def _to_api_source_table(self) -> SourceTableType:
        r"""Casts this source table as an object of type
        :obj:`SourceTableType` for use with the public REST API.
        """
        # TODO(manan): this is stupid, and is necessary because the
        # s3_validate method in Kumo core does not properly return
        # directories. So, we have to explicitly handle this ourselves here...
        try:
            return self.connector._get_table_config(self.name).source_table
        except HTTPException:
            name = self.name.rsplit('.', maxsplit=1)[0]
            out = self.connector._get_table_config(name).source_table
            self.name = name
            return out

    @staticmethod
    def _from_api_table_definition(
            table_definition: TableDefinition) -> 'SourceTable':
        r"""Constructs a :class:`SourceTable` from a
        :class:`kumoapi.table.TableDefinition`.
        """
        from kumoai.connector import (
            BigQueryConnector,
            DatabricksConnector,
            FileUploadConnector,
            S3Connector,
            SnowflakeConnector,
        )
        from kumoai.connector.s3_connector import S3URI

        source_type = table_definition.source_table.data_source_type
        connector: Connector
        if source_type == DataSourceType.S3:
            connector_id = table_definition.source_table.connector_id
            if connector_id in {
                    'parquet_upload_connector',
                    'csv_upload_connector',
            }:
                # File upload:
                connector = FileUploadConnector(
                    file_type=('parquet' if connector_id
                               == 'parquet_upload_connector' else 'csv'))
                table_name = table_definition.source_table.source_table_name
            else:
                if connector_id is not None:
                    connector = S3Connector(root_dir=None,
                                            _connector_id=connector_id)
                    table_name = (
                        table_definition.source_table.source_table_name)
                else:
                    table_path = S3URI(table_definition.source_table.s3_path)
                    connector = S3Connector(root_dir=table_path.root_dir)
                    # Strip the suffix, since Kumo always takes care of that:
                    table_name = table_path.object_name.rsplit(
                        '.', maxsplit=1)[0]
        elif source_type == DataSourceType.SNOWFLAKE:
            connector_api = global_state.client.connector_api
            connector_resp = connector_api.get(
                table_definition.source_table.snowflake_connector_id)
            assert connector_resp is not None
            connector_cfg = connector_resp.config
            connector = SnowflakeConnector(
                name=connector_cfg.name,
                account=connector_cfg.account,
                warehouse=connector_cfg.warehouse,
                database=connector_cfg.database,
                schema_name=connector_cfg.schema_name,
                credentials=None,  # Should be in env; do not load from DB.
                _bypass_creation=True,
            )
            table_name = table_definition.source_table.table
        elif source_type == DataSourceType.DATABRICKS:
            connector_api = global_state.client.connector_api
            connector_resp = connector_api.get(
                table_definition.source_table.databricks_connector_id)
            assert connector_resp is not None
            connector_cfg = connector_resp.config
            connector = DatabricksConnector(
                name=connector_cfg.name,
                host=connector_cfg.host,
                cluster_id=connector_cfg.cluster_id,
                warehouse_id=connector_cfg.warehouse_id,
                catalog=connector_cfg.catalog,
                credentials=None,  # Should be in env; do not load from DB.
                _bypass_creation=True,
            )
            table_name = table_definition.source_table.table
        elif source_type == DataSourceType.BIGQUERY:
            connector_api = global_state.client.connector_api
            connector_resp = connector_api.get(
                table_definition.source_table.bigquery_connector_id)
            assert connector_resp is not None
            connector_cfg = connector_resp.config
            connector = BigQueryConnector(
                name=connector_cfg.name,
                project_id=connector_cfg.project_id,
                dataset_id=connector_cfg.dataset_id,
                credentials=None,  # Should be in env; do not load from DB.
                _bypass_creation=True,
            )
            table_name = table_definition.source_table.table_name
        else:
            raise NotImplementedError()

        return SourceTable(table_name, connector)

    # Class properties ########################################################

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(name={self.name})'


class SourceTableFuture(KumoFuture[SourceTable]):
    r"""A representation of an on-going :class:`SourceTable` generation
    process.
    """

    def __init__(
        self,
        job_id: LLMJobId,
        table_name: str,
        output_dir: str,
    ) -> None:
        self.job_id = job_id
        self._fut: concurrent.futures.Future = create_future(
            _poll(job_id, table_name, output_dir))

    @override
    def result(self) -> SourceTable:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[SourceTable]':
        return self._fut

    def status(self) -> JobStatus:
        return _get_status(self.job_id)
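

# A minimal sketch of working with a `SourceTableFuture` (the job ID is
# hypothetical): `status()` queries the backing job synchronously, while
# `result()` blocks until the job reaches a terminal state and the output
# table can be read back as a `SourceTable`.
#
#   >>> fut = SourceTableFuture(job_id, 'articles_emb', 's3://out')  # doctest: +SKIP # noqa: E501
#   >>> fut.status()        # doctest: +SKIP
#   >>> src = fut.result()  # doctest: +SKIP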


class LLMSourceTableFuture(SourceTableFuture):
    r"""A representation of an on-going :class:`SourceTable` generation
    process for an LLM. This class inherits from :class:`SourceTableFuture`,
    with some functions specific to LLM jobs.
    """

    def __init__(
        self,
        job_id: LLMJobId,
        table_name: str,
        output_dir: str,
    ) -> None:
        super().__init__(job_id, table_name, output_dir)

    def cancel(self) -> JobStatus:
        r"""Cancel the LLM job."""
        api = global_state.client.llm_job_api
        return api.cancel(self.job_id)


def _get_status(job_id: str) -> JobStatus:
    api = global_state.client.llm_job_api
    resource: JobStatus = api.get(job_id)
    return resource


async def _poll(job_id: str, table_name: str, output_dir: str) -> SourceTable:
    status = _get_status(job_id)
    while not status.is_terminal:
        await asyncio.sleep(_DEFAULT_INTERVAL_S)
        status = _get_status(job_id)

    if status != JobStatus.DONE:
        raise RuntimeError(f"LLM job {job_id} failed with "
                           f"job status {status}.")

    from kumoai.connector import S3Connector
    connector = S3Connector(root_dir=output_dir)
    return connector[table_name]
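

# The helpers above implement the future's polling loop: `_poll` re-checks the
# job status every `_DEFAULT_INTERVAL_S` seconds until the job is terminal,
# then materializes the output as a `SourceTable` read from an `S3Connector`
# rooted at `output_dir`. A sketch of how `SourceTableFuture.__init__` wires
# this coroutine into a concurrent future (the job ID and path are
# hypothetical):
#
#   >>> fut = create_future(_poll('llm-job-1', 'articles_emb', 's3://out'))  # doctest: +SKIP # noqa: E501
#   >>> fut.result()  # Blocks until the job status is DONE.  # doctest: +SKIP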