Source code for kumoai.pquery.training_table

from __future__ import annotations

import asyncio
import logging
import os
import time
from concurrent.futures import Future
from functools import cached_property
from typing import List, Optional, Tuple, Union

import pandas as pd
from kumoapi.common import JobStatus
from kumoapi.jobs import (
    ArtifactExportRequest,
    CustomTrainingTable,
    GenerateTrainTableJobResource,
    GenerateTrainTableRequest,
    JobStatusReport,
    SourceTableType,
    TrainingTableOutputConfig,
    TrainingTableSpec,
    WriteMode,
)
from kumoapi.source_table import S3SourceTable
from tqdm.auto import tqdm
from typing_extensions import Self, override

from kumoai import global_state
from kumoai.artifact_export import (
    ArtifactExportJob,
    ArtifactExportResult,
    TrainingTableExportConfig,
)
from kumoai.client.jobs import (
    GenerateTrainTableJobAPI,
    GenerateTrainTableJobID,
)
from kumoai.connector import S3Connector, SourceTable
from kumoai.databricks import to_db_table_name
from kumoai.formatting import pretty_print_error_details
from kumoai.futures import KumoProgressFuture, create_future
from kumoai.jobs import JobInterface

logger = logging.getLogger(__name__)

_DEFAULT_INTERVAL_S = 20


class TrainingTable:
    r"""A training table in the Kumo platform. A training table can be
    initialized from the job ID of a completed training table generation
    job.

    .. code-block:: python

        import kumoai

        # Create a training table from a training table generation job.
        # Note that the job ID passed here must be in a completed state:
        training_table = kumoai.TrainingTable("gen-traintable-job-...")

        # Read the training table as a Pandas DataFrame:
        training_df = training_table.data_df()

        # Get URLs to download the training table:
        training_download_urls = training_table.data_urls()

        # Add a weight column to the training table (see
        # `kumo-sdk.examples.datasets.weighted_train_table.py` for a more
        # detailed example):
        # 1. Export the training table:
        connector = kumoai.S3Connector("s3_path")
        training_table.export(TrainingTableExportConfig(
            output_types={'training_table'},
            output_connector=connector,
            output_table_name="<any_name>"))

        # 2. Assume the weight column was added to the training table, and
        # that it was saved to the same S3 path as "<mod_table>":
        training_table.update(SourceTable("<mod_table>", connector),
                              TrainingTableSpec(weight_col="weight"))

    Args:
        job_id: ID of the training table generation job which generated
            this training table.
    """
    def __init__(self, job_id: GenerateTrainTableJobID):
        self.job_id = job_id
        status = _get_status(job_id).status
        self._custom_train_table: Optional[CustomTrainingTable] = None
        if status != JobStatus.DONE:
            raise ValueError(
                f"Job {job_id} is not yet complete (status: {status}). If "
                f"you would like to create a future (waiting for training "
                f"table generation success), please use `TrainingTableJob`.")

    def data_urls(self) -> List[str]:
        r"""Returns a list of URLs that can be used to view generated
        training table data. The list will contain more than one element if
        the table is partitioned; paths will be relative to the location of
        the Kumo data plane.
        """
        api: GenerateTrainTableJobAPI = (
            global_state.client.generate_train_table_job_api)
        return api._get_table_data(self.job_id, presigned=True,
                                   raw_path=True)

    def data_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated training data.

        .. warning::
            This method will load the full training table into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a
            machine with limited resources, please use
            :meth:`~kumoai.pquery.TrainingTable.data_urls` instead to
            download the data and perform analysis per-partition.
        """
        urls = self.data_urls()
        if global_state.is_spcs:
            from kumoai.spcs import _parquet_dataset_to_df

            # TODO(dm): return type hint is wrong
            return _parquet_dataset_to_df(urls)
        try:
            return pd.concat([pd.read_parquet(x) for x in urls])
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data "
                f"paths {urls}. Please construct the DataFrame manually."
            ) from e

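    # Illustrative sketch (not part of the class API): for large tables, the
    # per-partition pattern recommended in the warning above looks roughly
    # like this, assuming `table` is a completed `TrainingTable`:
    #
    #     for url in table.data_urls():
    #         partition_df = pd.read_parquet(url)  # one partition at a time
    #         ...  # aggregate or analyze the partition here
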
    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(job_id={self.job_id})'

    def _to_s3_api_source_table(
            self, source_table: SourceTable) -> S3SourceTable:
        assert isinstance(source_table.connector, S3Connector)
        source_type = source_table._to_api_source_table()
        root_dir: str = source_table.connector.root_dir  # type: ignore
        if root_dir[-1] != os.sep:
            root_dir = root_dir + os.sep
        return S3SourceTable(
            s3_path=root_dir,
            source_table_name=source_table.name,
            file_type=source_type.file_type,
        )

    def export(
        self,
        output_config: TrainingTableExportConfig,
        non_blocking: bool = True,
    ) -> Union[ArtifactExportJob, ArtifactExportResult]:
        r"""Exports the training table to the connector specified in the
        output config. Use the exported table to add a weight column, then
        use :meth:`update` to update the training table.

        Args:
            output_config: The output configuration to write the training
                table.
            non_blocking: If ``True``, the method returns a future object
                `ArtifactExportJob` representing the export job. If
                ``False``, the method blocks until the export job is
                complete and returns an `ArtifactExportResult`.
        """
        assert output_config.output_connector is not None
        assert output_config.output_types == {'training_table'}
        output_table_name = to_db_table_name(output_config.output_table_name)
        assert output_table_name is not None

        s3_path = None
        connector_id = None
        table_name = ""
        write_mode = WriteMode.OVERWRITE
        if isinstance(output_config.output_connector, S3Connector):
            assert output_config.output_connector.root_dir is not None
            s3_path = output_config.output_connector.root_dir
            s3_path = os.path.join(s3_path, output_table_name)
        else:
            connector_id = output_config.output_connector.name
            table_name = output_table_name
        if output_config.connector_specific_config:
            write_mode = output_config.connector_specific_config.write_mode

        api = global_state.client.artifact_export_api
        output_config = TrainingTableOutputConfig(
            s3_path=s3_path,
            connector_id=connector_id,
            table_name=table_name,
            write_mode=write_mode,
        )
        request = ArtifactExportRequest(job_id=self.job_id,
                                        training_table_output=output_config)
        job_id = api.create(request)
        if non_blocking:
            return ArtifactExportJob(job_id)
        return ArtifactExportJob(job_id).attach()

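    # Illustrative sketch (placeholder names): exporting the training table
    # to S3 in blocking mode, so the call returns an `ArtifactExportResult`
    # once the export finishes. The bucket path and table name below are
    # assumptions for illustration only:
    #
    #     result = training_table.export(
    #         TrainingTableExportConfig(
    #             output_types={'training_table'},
    #             output_connector=S3Connector("s3://bucket/prefix"),
    #             output_table_name="train_table_export"),
    #         non_blocking=False)
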
    def validate_custom_table(
        self,
        source_table_type: SourceTableType,
        train_table_mod: TrainingTableSpec,
    ) -> None:
        r"""Validates the modified training table.

        Args:
            source_table_type: The source table to be used as the modified
                training table.
            train_table_mod: The modification specification.

        Raises:
            ValueError: If the modified training table is invalid.
        """
        api: GenerateTrainTableJobAPI = (
            global_state.client.generate_train_table_job_api)
        response = api.validate_custom_train_table(self.job_id,
                                                   source_table_type,
                                                   train_table_mod)
        if not response.ok:
            raise ValueError("Invalid weighted train table",
                             response.error_message)

    def update(
        self,
        source_table: SourceTable,
        train_table_mod: TrainingTableSpec,
        validate: bool = True,
    ) -> Self:
        r"""Sets the `source_table` as the modified training table.

        .. note::
            The only allowed modification is the addition of a weight
            column; any other modification may lead to unintended errors
            while training. Furthermore, negative/NA weight values are not
            supported.

        The custom training table is ingested during `trainer.fit()` and is
        used as the training table.

        Args:
            source_table: The source table to be used as the modified
                training table.
            train_table_mod: The modification specification.
            validate: Whether to validate the modified training table. This
                can be slow for large tables.
        """
        if isinstance(source_table.connector, S3Connector):
            # Special handling for S3, as `source_table._to_api_source_table`
            # concatenates the root directory and file name, but the backend
            # expects these to be separate:
            source_table_type = self._to_s3_api_source_table(source_table)
        else:
            source_table_type = source_table._to_api_source_table()
        if validate:
            self.validate_custom_table(source_table_type, train_table_mod)
        self._custom_train_table = CustomTrainingTable(
            source_table=source_table_type, table_mod_spec=train_table_mod,
            validated=validate)
        return self

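    # Illustrative sketch (placeholder names): attaching a weight-modified
    # copy of the training table that was previously exported to S3.
    # Skipping validation (`validate=False`) avoids a potentially slow check
    # on very large tables:
    #
    #     training_table.update(
    #         SourceTable("<mod_table>", connector),
    #         TrainingTableSpec(weight_col="weight"),
    #         validate=False)
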
# Training Table Future #######################################################
class TrainingTableJob(JobInterface[GenerateTrainTableJobID,
                                    GenerateTrainTableRequest,
                                    GenerateTrainTableJobResource],
                       KumoProgressFuture[TrainingTable]):
    r"""A representation of an ongoing training table generation job in the
    Kumo platform.

    .. code-block:: python

        import kumoai

        # See `PredictiveQuery` documentation:
        pquery = kumoai.PredictiveQuery(...)

        # If a training table is generated in nonblocking mode, the response
        # will be of type `TrainingTableJob`:
        training_table_job = pquery.generate_training_table(
            non_blocking=True)

        # You can also construct a `TrainingTableJob` from a job ID, e.g.
        # one that is present in the Kumo Jobs page:
        training_table_job = kumoai.TrainingTableJob("trainingjob-...")

        # Get the status of the job:
        print(training_table_job.status())

        # Attach to the job, and poll progress updates:
        training_table_job.attach()

        # Cancel the job:
        training_table_job.cancel()

        # Wait for the job to complete, and return a `TrainingTable`:
        training_table_job.result()

    Args:
        job_id: ID of the training table generation job.
    """
    @override
    @staticmethod
    def _api() -> GenerateTrainTableJobAPI:
        return global_state.client.generate_train_table_job_api

    def __init__(
        self,
        job_id: GenerateTrainTableJobID,
    ) -> None:
        self.job_id = job_id

    @cached_property
    def _fut(self) -> Future[TrainingTable]:
        return create_future(_poll(self.job_id))

    @override
    @property
    def id(self) -> GenerateTrainTableJobID:
        r"""The unique ID of this training table generation process."""
        return self.job_id

    @override
    def result(self) -> TrainingTable:
        return self._fut.result()

    @override
    def future(self) -> Future[TrainingTable]:
        return self._fut

    @override
    def status(self) -> JobStatusReport:
        r"""Returns the status of a running training table generation
        job."""
        return _get_status(self.job_id)

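    # Illustrative sketch: a simple manual poll loop built on `status()`,
    # mirroring what the module-level `_poll` coroutine below does
    # asynchronously. `job` is assumed to be a `TrainingTableJob`:
    #
    #     while not job.status().status.is_terminal:
    #         time.sleep(_DEFAULT_INTERVAL_S)
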
    @override
    def _attach_internal(self, interval_s: float = 20.0) -> TrainingTable:
        assert interval_s >= 4.0
        print(f"Attaching to training table generation job {self.job_id}. "
              f"Tracking this job in the Kumo UI is coming soon. To detach "
              f"from this job, please enter Ctrl+C (the job will continue "
              f"to run, and you can re-attach anytime).")

        def _get_progress() -> Optional[Tuple[int, int]]:
            progress = self._api().get_progress(self.job_id)
            if len(progress) == 0:
                return None
            expected_iter = progress['num_expected_iterations']
            completed_iter = progress['num_finished_iterations']
            return (expected_iter, completed_iter)

        # Print an initial status message:
        print("Training table generation is in progress. If your task is "
              "temporal, progress per timeframe will be loaded shortly.")

        # Wait for either timeframes to become available, or for the job to
        # reach a terminal state:
        progress = _get_progress()
        while progress is None:
            # Not a temporal task, and it's done:
            if self.status().status.is_terminal:
                return self.result()
            time.sleep(interval_s)
            progress = _get_progress()

        # Timeframes are available; track them with a progress bar:
        total, prog = progress
        pbar = tqdm(total=total, unit="timeframe",
                    desc="Generating Training Table")
        pbar.update(prog)
        while not self.done():
            progress = _get_progress()
            assert progress is not None
            total, prog = progress
            pbar.reset(total)
            pbar.update(prog)
            time.sleep(interval_s)
        pbar.update(pbar.total - pbar.n)
        pbar.close()

        # Future is done:
        return self.result()

    def cancel(self) -> None:
        r"""Cancels a running training table generation job, and raises an
        error if cancellation failed.
        """
        return self._api().cancel(self.job_id)

    @override
    def load_config(self) -> GenerateTrainTableRequest:
        r"""Loads the full configuration for this training table generation
        job.

        Returns:
            GenerateTrainTableRequest: Complete configuration, including the
                plan, `pquery_id`, `graph_snapshot_id`, etc.
        """
        return self._api().get_config(self.job_id)


def _get_status(job_id: str) -> JobStatusReport:
    api = global_state.client.generate_train_table_job_api
    resource: GenerateTrainTableJobResource = api.get(job_id)
    return resource.job_status_report


async def _poll(job_id: str) -> TrainingTable:
    # TODO(manan): make asynchronous natively with aiohttp:
    status = _get_status(job_id).status
    while not status.is_terminal:
        await asyncio.sleep(_DEFAULT_INTERVAL_S)
        status = _get_status(job_id).status
    if status != JobStatus.DONE:
        api = global_state.client.generate_train_table_job_api
        error_details = api.get_job_error(job_id)
        error_str = pretty_print_error_details(error_details)
        raise RuntimeError(
            f"Training table generation for job {job_id} failed with "
            f"job status {status}. Encountered the following error(s): "
            f"{error_str}")
    return TrainingTable(job_id)
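

# Illustrative sketch: the typical end-to-end flow these helpers support,
# assuming a training table generation job ID from the Kumo Jobs page:
#
#     job = TrainingTableJob("gen-traintable-job-...")
#     job.attach()           # optional: stream progress updates
#     table = job.result()   # block until the TrainingTable is ready
#     df = table.data_df()   # load the result as a DataFrame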