Source code for kumoai.pquery.training_table

from __future__ import annotations

import asyncio
import logging
import os
import time
from concurrent.futures import Future
from typing import List, Mapping, Optional, Tuple, Union

import pandas as pd
from kumoapi.common import JobStatus
from kumoapi.jobs import (
    ArtifactExportRequest,
    CustomTrainingTable,
    GenerateTrainTableJobResource,
    JobStatusReport,
    SourceTableType,
    TrainingTableOutputConfig,
    TrainingTableSpec,
    WriteMode,
)
from kumoapi.source_table import S3SourceTable
from tqdm.auto import tqdm
from typing_extensions import Self, override

from kumoai import global_state
from kumoai.artifact_export import (
    ArtifactExportJob,
    ArtifactExportResult,
    TrainingTableExportConfig,
)
from kumoai.client.jobs import (
    GenerateTrainTableJobAPI,
    GenerateTrainTableJobID,
)
from kumoai.connector import S3Connector, SourceTable
from kumoai.databricks import to_db_table_name
from kumoai.formatting import pretty_print_error_details
from kumoai.futures import KumoProgressFuture, create_future

logger = logging.getLogger(__name__)

_DEFAULT_INTERVAL_S = 20


[docs]class TrainingTable: r"""A training table in the Kumo platform. A training table can be initialized from a job ID of a completed training table generation job. .. code-block:: python import kumoai # Create a Training Table from a training table generation job. Note # that the job ID passed here must be in a completed state: training_table = kumoai.TrainingTable("gen-traintable-job-...") # Read the training table as a Pandas DataFrame: training_df = training_table.data_df() # Get URLs to download the training table: training_download_urls = training_table.data_urls() # Add weight column to the training table: # see `kumo-sdk.examples.datasets.weighted_train_table.py` # for a more detailed example # 1. Export train table connector = kumo.S3Connector("s3_path") training_table.export(TrainingTableExportConfig( output_types={'training_table'}, output_connector=connector, output_table_name="<any_name>")) # 2. Assume the weight column was added to the train table # and it was saved to the same S3 path as "<mod_name>" training_table.update(SourceTable("<mod_table>", connector), TrainingTableSpec(weight_col="weight")) Args: job_id: ID of the training table generation job which generated this training table. """
[docs] def __init__(self, job_id: GenerateTrainTableJobID): self.job_id = job_id status = _get_status(job_id).status self._custom_train_table: Optional[CustomTrainingTable] = None if status != JobStatus.DONE: raise ValueError( f"Job {job_id} is not yet complete (status: {status}). If you " f"would like to create a future (waiting for training table " f"generation success), please use `TrainingTableJob`.")
[docs] def data_urls(self) -> List[str]: r"""Returns a list of URLs that can be used to view generated training table data. The list will contain more than one element if the table is partitioned; paths will be relative to the location of the Kumo data plane. """ api: GenerateTrainTableJobAPI = ( global_state.client.generate_train_table_job_api) return api._get_table_data(self.job_id, presigned=True, raw_path=True)
[docs] def data_df(self) -> pd.DataFrame: r"""Returns a :class:`~pandas.DataFrame` object representing the generated training data. .. warning:: This method will load the full training table into memory as a :class:`~pandas.DataFrame` object. If you are working on a machine with limited resources, please use :meth:`~kumoai.pquery.TrainingTable.data_urls` instead to download the data and perform analysis per-partition. """ urls = self.data_urls() if global_state.is_spcs: from kumoai.spcs import _parquet_dataset_to_df # TODO(dm): return type hint is wrong return _parquet_dataset_to_df(self.data_urls()) try: return pd.concat([pd.read_parquet(x) for x in urls]) except Exception as e: raise ValueError( f"Could not create a Pandas DataFrame object from data paths " f"{urls}. Please construct the DataFrame manually.") from e
def __repr__(self) -> str: return f'{self.__class__.__name__}(job_id={self.job_id})' def _to_s3_api_source_table(self, source_table: SourceTable) -> S3SourceTable: assert isinstance(source_table.connector, S3Connector) source_type = source_table._to_api_source_table() root_dir: str = source_table.connector.root_dir # type: ignore if root_dir[-1] != os.sep: root_dir = root_dir + os.sep return S3SourceTable( s3_path=root_dir, source_table_name=source_table.name, file_type=source_type.file_type, )
[docs] def export( self, output_config: TrainingTableExportConfig, non_blocking: bool = True, ) -> Union[ArtifactExportJob, ArtifactExportResult]: r"""Export the training table to the connector. specified in the output config. Use the exported table to add a weight column then use `update` to update the training table. Args: output_config: The output configuration to write the training table. non_blocking: If ``True``, the method will return a future object `ArtifactExportJob` representing the export job. If ``False``, the method will block until the export job is complete and return `ArtifactExportResult`. """ assert output_config.output_connector is not None assert output_config.output_types == {'training_table'} output_table_name = to_db_table_name(output_config.output_table_name) assert output_table_name is not None s3_path = None connector_id = None table_name = "" write_mode = WriteMode.OVERWRITE if isinstance(output_config.output_connector, S3Connector): assert output_config.output_connector.root_dir is not None s3_path = output_config.output_connector.root_dir s3_path = os.path.join(s3_path, output_table_name) else: connector_id = output_config.output_connector.name table_name = output_table_name if output_config.connector_specific_config: write_mode = output_config.connector_specific_config.write_mode api = global_state.client.artifact_export_api output_config = TrainingTableOutputConfig( s3_path=s3_path, connector_id=connector_id, table_name=table_name, write_mode=write_mode, ) request = ArtifactExportRequest(job_id=self.job_id, training_table_output=output_config) job_id = api.create(request) if non_blocking: return ArtifactExportJob(job_id) return ArtifactExportJob(job_id).attach()
[docs] def validate_custom_table(self, source_table_type: SourceTableType, train_table_mod: TrainingTableSpec) -> None: r"""Validates the modified training table. Args: source_table (SourceTable): The source table to be used as the modified training table. train_table_mod (TrainTableSpec): The modification specification. Raises: ValueError: If the modified training table is invalid. """ api: GenerateTrainTableJobAPI = ( global_state.client.generate_train_table_job_api) response = api.validate_custom_train_table(self.job_id, source_table_type, train_table_mod) if not response.ok: raise ValueError("Invalid weighted train table", response.error_message)
[docs] def update( self, source_table: SourceTable, train_table_mod: TrainingTableSpec, validate: bool = True, ) -> Self: r"""Sets the `source_table` as the modified training table. .. note:: The only allowed modification is the addition of weight column Any other modification might lead to unintentded ERRORS downstream. The custom training table is ingested during trainer.fit() and is used as the training table. Args: source_table (SourceTable): The source table to be used as the modified training table. table_mod_spec (TrainTableSpec): The modification specification. validate (bool): Whether to validate the modified training table. This can be slow for large tables. """ if isinstance(source_table.connector, S3Connector): # Special handling for s3 as `source_table._to_api_source_table` # concatenates root_dir and file name. But the backend expects # these to be separate. source_table_type = self._to_s3_api_source_table(source_table) else: source_table_type = source_table._to_api_source_table() if validate: self.validate_custom_table(source_table_type, train_table_mod) self._custom_train_table = CustomTrainingTable( source_table=source_table_type, table_mod_spec=train_table_mod, validated=validate) return self
# Training Table Future #######################################################
[docs]class TrainingTableJob(KumoProgressFuture[TrainingTable]): r"""A representation of an ongoing training table generation job in the Kumo platform. .. code-block:: python import kumoai # See `PredictiveQuery` documentation: pquery = kumoai.PredictiveQuery(...) # If a training table is generated in nonblocking mode, the response # will be of type `TrainingTableJob`: training_table_job = pquery.generate_training_table(non_blocking=True) # You can also construct a `TrainingTableJob` from a job ID, e.g. # one that is present in the Kumo Jobs page: training_table_job = kumoai.TrainingTableJob("trainingjob-...") # Get the status of the job: print(training_table_job.status()) # Attach to the job, and poll progress updates: training_table_job.attach() # Cancel the job: training_table_job.cancel() # Wait for the job to complete, and return a `TrainingTable`: training_table_job.result() Args: job_id: ID of the training table generation job. """
[docs] def __init__( self, job_id: GenerateTrainTableJobID, ) -> None: self.job_id = job_id # A training table holds a reference to the future that tracks the # execution of its generation. self._fut: Future[TrainingTable] = create_future(_poll(job_id))
@property def id(self) -> GenerateTrainTableJobID: r"""The unique ID of this training table generation process.""" return self.job_id
[docs] @override def result(self) -> TrainingTable: return self._fut.result()
[docs] @override def future(self) -> Future[TrainingTable]: return self._fut
[docs] def status(self) -> JobStatusReport: r"""Returns the status of a running training table generation job.""" return _get_status(self.job_id)
@override def _attach_internal(self, interval_s: float = 20.0) -> TrainingTable: assert interval_s >= 4.0 print(f"Attaching to training table generation job {self.job_id}. " f"Tracking this job in the Kumo UI is coming soon. To detach " f"from this job, please enter Ctrl+C (the job will continue to " f"run, and you can re-attach anytime).") api = global_state.client.generate_train_table_job_api def _get_progress() -> Optional[Tuple[int, int]]: progress = api.get_progress(self.job_id) if len(progress) == 0: return None expected_iter = progress['num_expected_iterations'] completed_iter = progress['num_finished_iterations'] return (expected_iter, completed_iter) # Print progress bar: print("Training table generation is in progress. If your task is " "temporal, progress per timeframe will be loaded shortly.") # Wait for either timeframes to become available, or the job is done: progress = _get_progress() while progress is None: progress = _get_progress() # Not a temporal task, and it's done: if self.status().status.is_terminal: return self.result() time.sleep(interval_s) # Wait for timeframes to become available: progress = _get_progress() assert progress is not None total, prog = progress pbar = tqdm(total=total, unit="timeframe", desc="Generating Training Table") pbar.update(pbar.n - prog) while not self.done(): progress = _get_progress() assert progress is not None total, prog = progress pbar.reset(total) pbar.update(prog) time.sleep(interval_s) pbar.update(pbar.total) pbar.close() # Future is done: return self.result()
[docs] def delete_tags(self, tags: List[str]) -> bool: r"""Removes the tags from the job. Args: tags (List[str]): The tags to remove. """ api = global_state.client.generate_train_table_job_api return api.delete_tags(self.job_id, tags)
[docs] def update_tags(self, tags: Mapping[str, Optional[str]]) -> bool: r"""Updates the tags of the job. Args: tags (Mapping[str, Optional[str]]): The tags to update. Note that the value 'none' will remove the tag. If the tag is not present, it will be added. """ api = global_state.client.generate_train_table_job_api return api.update_tags(self.job_id, tags)
[docs] def cancel(self) -> None: r"""Cancels a running training table generation job, and raises an error if cancellation failed. """ api = global_state.client.generate_train_table_job_api return api.cancel(self.job_id)
def _get_status(job_id: str) -> JobStatusReport: api = global_state.client.generate_train_table_job_api resource: GenerateTrainTableJobResource = api.get(job_id) return resource.job_status_report async def _poll(job_id: str) -> TrainingTable: # TODO(manan): make asynchronous natively with aiohttp: status = _get_status(job_id).status while not status.is_terminal: await asyncio.sleep(_DEFAULT_INTERVAL_S) status = _get_status(job_id).status if status != JobStatus.DONE: api = global_state.client.generate_train_table_job_api error_details = api.get_job_error(job_id) error_str = pretty_print_error_details(error_details) raise RuntimeError( f"Training table generation for job {job_id} failed with " f"job status {status}. Encountered below error(s):" f'{error_str}') return TrainingTable(job_id)