# Source code for kumoai.trainer.job

import asyncio
import concurrent
import time
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union
from urllib.parse import urlparse, urlunparse

import pandas as pd
from kumoapi.common import JobStatus
from kumoapi.jobs import (
    ArtifactExportRequest,
    AutoTrainerProgress,
    BaselineEvaluationMetrics,
    BatchPredictionJobSummary,
    JobStatusReport,
    ModelEvaluationMetrics,
    PredictionProgress,
)
from kumoapi.model_plan import ModelPlan
from tqdm.auto import tqdm
from typing_extensions import override

from kumoai import global_state
from kumoai.client.jobs import (
    BaselineJobID,
    BaselineJobResource,
    BatchPredictionJobAPI,
    BatchPredictionJobID,
    BatchPredictionJobResource,
    TrainingJobAPI,
    TrainingJobID,
    TrainingJobResource,
)
from kumoai.connector import Connector
from kumoai.futures import KumoFuture, create_future
from kumoai.trainer.util import (
    build_prediction_output_config,
    validate_output_arguments,
)

if TYPE_CHECKING:
    from kumoai.pquery import (
        PredictionTable,
        PredictionTableJob,
        PredictiveQuery,
        TrainingTable,
        TrainingTableJob,
    )


class BaselineJobResult:
    r"""Handle to the outcome of a finished baseline job.

    Instances are usually produced by awaiting a
    :class:`~kumoai.trainer.BaselineJob` through its
    :meth:`~kumoai.trainer.BaselineJob.result` method, which only resolves
    once the job has finished successfully. The object can also be built
    directly from a job ID; in that case the job's status is only checked
    when completed-only data (*e.g.* metrics) is requested.

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_future = kumoai.BaselineJob("...")  # doctest: +SKIP
        >>> job = job_future.result()  # doctest: +SKIP
    """
    def __init__(self, job_id: BaselineJobID) -> None:
        self.job_id = job_id

        # Filled lazily, and only ever with a resource whose status is DONE:
        self._job_resource: Optional[BaselineJobResource] = None

    def metrics(self) -> Dict[str, BaselineEvaluationMetrics]:
        r"""Returns the evaluation metrics of this completed baseline job.

        Raises:
            RuntimeError: If the job cannot be fetched, or has not reached a
                terminal status yet.
            ValueError: If the job reached a terminal status other than
                ``DONE``.
        """
        finished = self._get_job_resource(require_completed=True)
        return finished.result.eval_metrics

    def _get_job_resource(self,
                          require_completed: bool) -> BaselineJobResource:
        r"""Fetches the job resource, caching it once it is known to be done.

        With ``require_completed=False`` the freshly fetched resource is
        returned without any status check (and without being cached).
        """
        cached = self._job_resource
        if cached:
            return cached

        try:
            fetched: BaselineJobResource = (
                global_state.client.baseline_job_api.get(self.job_id))
        except Exception as e:
            raise RuntimeError(
                f"Baseline job {self.job_id} was not found in the Kumo "
                f"database. Please contact Kumo for further assistance. "
            ) from e

        if require_completed:
            job_status = fetched.job_status_report.status
            if not job_status.is_terminal:
                raise RuntimeError(
                    f"Baseline job {self.job_id} has not yet completed. "
                    f"Please create a `BaselineJob` class and await its "
                    f"completion before attempting to view metrics.")

            if job_status != JobStatus.DONE:
                # Should never happen, the future will not resolve:
                raise ValueError(
                    f"Baseline job {self.job_id} completed with status "
                    f"{job_status}, and was therefore unable to produce "
                    f"metrics. Please re-train the job until it "
                    f"successfully completes.")

            # Only a successfully completed resource is safe to cache:
            self._job_resource = fetched

        return fetched


class TrainingJobResult:
    r"""Represents a completed training job.

    A :class:`TrainingJobResult` object can either be obtained by creating a
    :class:`~kumoai.trainer.TrainingJob` object and calling the
    :meth:`~kumoai.trainer.TrainingJob.result` method to await the job's
    completion, or by directly creating the object. The former approach is
    recommended, as it includes verification that the job finished
    successfully.

    .. code-block:: python

        import kumoai

        training_job = kumoai.TrainingJob("trainingjob-...")

        # Wait for a training job's completion, and get its result:
        training_job_result = training_job.result()

        # Alternatively, create the result directly, but be sure that the job
        # is completed:
        training_job_result = kumoai.TrainingJobResult("trainingjob-...")

        # Get associated objects:
        pquery = training_job_result.predictive_query
        training_table = training_job_result.training_table

        # Get holdout data:
        holdout_df = training_job_result.holdout_df()

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_future = kumoai.TrainingJob("...")  # doctest: +SKIP
        >>> job = job_future.result()  # doctest: +SKIP
    """
    def __init__(self, job_id: TrainingJobID) -> None:
        self.job_id = job_id

        # Filled lazily, and only ever with a resource whose status is DONE:
        self._job_resource: Optional[TrainingJobResource] = None

    @property
    def id(self) -> TrainingJobID:
        r"""The unique ID of this training job."""
        return self.job_id

    @property
    def model_plan(self) -> ModelPlan:
        r"""Returns the model plan associated with this training job."""
        resource = self._get_job_resource(require_completed=False)
        return resource.config.model_plan

    @property
    def training_table(self) -> Union['TrainingTableJob', 'TrainingTable']:
        r"""Returns the training table associated with this training job,
        either as a :class:`~kumoai.pquery.TrainingTable` or a
        :class:`~kumoai.pquery.TrainingTableJob` depending on the status of
        the training table generation job.

        Raises:
            RuntimeError: If the job resource carries no training table
                generation job ID.
        """
        # Imported lazily; at module level the name is only available under
        # `TYPE_CHECKING`.
        from kumoai.pquery import TrainingTableJob

        resource = self._get_job_resource(require_completed=False)
        training_table_job_id = resource.config.train_table_job_id
        if training_table_job_id is None:
            raise RuntimeError(
                f"Unable to access the training table generation job ID for "
                f"job {self.job_id}. Did this job fail before generating its "
                f"training table?")

        table_job = TrainingTableJob(training_table_job_id)
        is_done = table_job.status().status == JobStatus.DONE
        return table_job.result() if is_done else table_job

    @property
    def predictive_query(self) -> 'PredictiveQuery':
        r"""Returns the :class:`~kumoai.pquery.PredictiveQuery` object that
        defined the training table for this training job.
        """
        from kumoai.pquery import PredictiveQuery
        return PredictiveQuery.load_from_training_job(self.job_id)

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of this training
        job.
        """
        resource = self._get_job_resource(require_completed=False)
        return _rewrite_tracking_url(resource.job_status_report.tracking_url)

    def metrics(self) -> ModelEvaluationMetrics:
        r"""Returns the metrics associated with this completed training job,
        or raises an exception if metrics cannot be obtained.
        """
        finished = self._get_job_resource(require_completed=True)
        return finished.result.eval_metrics

    def holdout_url(self) -> str:
        r"""Returns a URL for downloading or reading the holdout dataset.

        If Kumo is deployed as a SaaS application, the returned URL will be a
        presigned AWS S3 URL with a TTL of 1 hour. If Kumo is deployed as a
        Snowpark Container Services application, the returned URL will be a
        Snowflake stage path that can be directly accessed within a Snowflake
        worksheet.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        return api.holdout_data_url(self.job_id, presigned=True)

    def holdout_df(self) -> pd.DataFrame:
        r"""Reads the holdout dataset (parquet file) as pandas DataFrame.

        .. note::
            Please note that this function may be memory-intensive, depending
            on the size of your holdout dataframe. Please exercise caution.
        """
        holdout_location = self.holdout_url()
        return pd.read_parquet(holdout_location)

    def _get_job_resource(self,
                          require_completed: bool) -> TrainingJobResource:
        r"""Fetches the job resource, caching it once it is known to be done.

        With ``require_completed=False`` the freshly fetched resource is
        returned without any status check (and without being cached).
        """
        cached = self._job_resource
        if cached:
            return cached

        try:
            fetched: TrainingJobResource = (
                global_state.client.training_job_api.get(self.job_id))
        except Exception as e:
            raise RuntimeError(
                f"Training job {self.job_id} was not found in the Kumo "
                f"database. Please contact Kumo for further assistance. "
            ) from e

        if require_completed:
            job_status = fetched.job_status_report.status
            if not job_status.is_terminal:
                raise RuntimeError(
                    f"Training job {self.job_id} has not yet completed. "
                    f"Please create a `TrainingJob` class and await its "
                    f"completion before attempting to view metrics.")

            if job_status != JobStatus.DONE:
                # Should never happen, the future will not resolve:
                raise ValueError(
                    f"Training job {self.job_id} completed with status "
                    f"{job_status}, and was therefore unable to produce "
                    f"metrics. Please re-train the job until it "
                    f"successfully completes.")

            # Only a successfully completed resource is safe to cache:
            self._job_resource = fetched

        return fetched
class BatchPredictionJobResult:
    r"""Represents a completed batch prediction job.

    A :class:`BatchPredictionJobResult` object can either be obtained by
    creating a :class:`~kumoai.trainer.BatchPredictionJob` object and calling
    the :meth:`~kumoai.trainer.BatchPredictionJob.result` method to await the
    job's completion, or by directly creating the object. The former approach
    is recommended, as it includes verification that the job finished
    successfully.

    .. code-block:: python

        import kumoai

        prediction_job = kumoai.BatchPredictionJob("bp-job-...")

        # Wait for a batch prediction job's completion, and get its result:
        prediction_job_result = prediction_job.result()

        # Alternatively, create the result directly, but be sure that the job
        # is completed:
        prediction_job_result = kumoai.BatchPredictionJobResult("bp-job-...")

        # Get associated objects:
        prediction_table = prediction_job_result.prediction_table

        # Get prediction data (in-memory):
        predictions_df = prediction_job_result.predictions_df()

        # Export prediction data to any output connector:
        prediction_job_result.export(
            output_type = ...,
            output_connector = ...,
            output_table_name = ...,
        )
    """  # noqa: E501

    def __init__(self, job_id: BatchPredictionJobID) -> None:
        self.job_id = job_id

        # Filled lazily, and only ever with a resource whose status is DONE:
        self._job_resource: Optional[BatchPredictionJobResource] = None

    @property
    def id(self) -> BatchPredictionJobID:
        r"""The unique ID of this batch prediction job."""
        return self.job_id

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of this batch
        prediction job.
        """
        tracking_url = self._get_job_resource(
            require_completed=False).job_status_report.tracking_url
        return _rewrite_tracking_url(tracking_url)

    def summary(self) -> BatchPredictionJobSummary:
        r"""Returns summary statistics associated with the batch prediction
        job's output, or raises an exception if summary statistics cannot be
        obtained.
        """
        return self._get_job_resource(require_completed=True).result

    @property
    def prediction_table(
            self) -> Union['PredictionTableJob', 'PredictionTable']:
        r"""Returns the prediction table associated with this prediction job,
        either as a :class:`~kumoai.pquery.PredictionTable` or a
        :class:`~kumoai.pquery.PredictionTableJob` depending on the status
        of the prediction table generation job.

        Raises:
            RuntimeError: If the job resource carries no prediction table
                generation job ID.
        """
        # Imported lazily; at module level the name is only available under
        # `TYPE_CHECKING`.
        from kumoai.pquery import PredictionTableJob

        prediction_table_job_id = self._get_job_resource(
            require_completed=False).pred_table_job_id
        if prediction_table_job_id is None:
            raise RuntimeError(
                f"Unable to access the prediction table generation job ID for "
                f"job {self.job_id}. Did this job fail before generating its "
                f"prediction table, or use a custom prediction table?")

        fut = PredictionTableJob(prediction_table_job_id)
        if fut.status().status == JobStatus.DONE:
            return fut.result()
        return fut

    def export(
        self,
        output_type: str,
        output_connector: Connector,
        output_table_name: Optional[Union[str, Tuple]] = None,
    ) -> str:
        r"""Export the prediction output or the embedding output to the
        specific output location.

        Args:
            output_type: The type of output that should be exported by the
                job. Can be either ``'predictions'`` or ``'embeddings'``.
            output_connector: The output data source that Kumo should write
                batch predictions to.
            output_table_name: The name of the table in the output data source
                that Kumo should write batch predictions to. In the case of
                a Databricks connector, this should be a tuple of two strings:
                the schema name and the output prediction table name.

        Returns:
            str: The artifact export job id.
        """
        # The validator takes a collection of output types, hence the set:
        validate_output_arguments({output_type}, output_connector,
                                  output_table_name)
        prediction_output_config = build_prediction_output_config(
            output_type, output_connector, output_table_name)

        api = global_state.client.artifact_export_api
        request = ArtifactExportRequest(
            job_id=self.id, prediction_output=prediction_output_config)
        return api.create(request)

    def predictions_urls(self) -> List[str]:
        r"""Returns a list of URLs for downloading or reading the
        predictions.

        If Kumo is deployed as a SaaS application, the returned URLs will be
        presigned AWS S3 URLs. If Kumo is deployed as a Snowpark Container
        Services application, the returned URLs will be Snowflake stage paths
        that can be directly accessed within a Snowflake worksheet.
        """
        api: BatchPredictionJobAPI = (
            global_state.client.batch_prediction_job_api)
        return api.get_batch_predictions_url(self.job_id)

    def predictions_df(self) -> pd.DataFrame:
        r"""Returns a :class:`~pandas.DataFrame` object representing the
        generated predictions.

        .. warning::

            This method will load the full prediction output into memory as a
            :class:`~pandas.DataFrame` object. If you are working on a machine
            with limited resources, please use
            :meth:`~kumoai.trainer.BatchPredictionJobResult.predictions_urls`
            instead to download the data and perform analysis per-partition.

        Raises:
            ValueError: If the data behind :meth:`predictions_urls` cannot be
                read and concatenated into a single DataFrame.
        """
        urls = self.predictions_urls()
        try:
            return pd.concat([pd.read_parquet(x) for x in urls])
        except Exception as e:
            raise ValueError(
                f"Could not create a Pandas DataFrame object from data paths "
                f"{urls}. Please construct the DataFrame manually.") from e

    def _get_job_resource(
            self, require_completed: bool) -> BatchPredictionJobResource:
        r"""Fetches the job resource, caching it once it is known to be done.

        With ``require_completed=False`` the freshly fetched resource is
        returned without any status check (and without being cached).
        """
        if self._job_resource:
            return self._job_resource

        try:
            api = global_state.client.batch_prediction_job_api
            resource: BatchPredictionJobResource = api.get(self.job_id)
        except Exception as e:
            raise RuntimeError(
                f"Batch prediction job {self.job_id} was not found in the "
                f"Kumo database. Please contact Kumo for further assistance. "
            ) from e

        if not require_completed:
            return resource

        status = resource.job_status_report.status
        if not status.is_terminal:
            raise RuntimeError(
                f"Batch prediction job {self.job_id} has not yet completed. "
                f"Please create a `BatchPredictionJob` class and await "
                "its completion before attempting to view stats.")

        if status != JobStatus.DONE:
            # Surface the validation details (if any) in the error message:
            validation_resp = resource.job_status_report.validation_response
            validation_message = ""
            if validation_resp:
                validation_message = validation_resp.message()
            if len(validation_message) > 0:
                validation_message = f"Details: {validation_message}"

            raise ValueError(
                f"Batch prediction job {self.job_id} completed with status "
                f"{status}, and was therefore unable to produce metrics. "
                f"{validation_message}")

        # Only a successfully completed resource is safe to cache:
        self._job_resource = resource
        return resource
# Training Job Future #########################################################
class TrainingJob(KumoFuture[TrainingJobResult]):
    r"""Represents an in-progress training job.

    A :class:`TrainingJob` object can either be created as the result of
    :meth:`~kumoai.trainer.Trainer.fit` with ``non_blocking=True``, or
    directly with a training job ID (*e.g.* of a job created asynchronously in
    a different environment).

    .. code-block:: python

        import kumoai

        # See `Trainer` documentation:
        trainer = kumoai.Trainer(...)

        # If a Trainer is `fit` in nonblocking mode, the response will be of
        # type `TrainingJob`:
        training_job = trainer.fit(..., non_blocking=True)

        # You can also construct a `TrainingJob` from a job ID, e.g. one that
        # is present in the Kumo Jobs page:
        training_job = kumoai.TrainingJob("trainingjob-...")

        # Get the status of the job:
        print(training_job.status())

        # Attach to the job, and poll progress updates:
        training_job.attach()
        # Training: 70%|█████████ | [300s<90s, trial=4, train_loss=1.056, val_loss=0.682, val_mae=35.709, val_mse=7906.239, val_rmse=88.917

        # Cancel the job:
        training_job.cancel()

        # Wait for the job to complete, and return a `TrainingJobResult`:
        training_job.result()

    Args:
        job_id: The training job ID to await completion of.
    """  # noqa

    def __init__(self, job_id: TrainingJobID) -> None:
        self.job_id = job_id
        # Polling is scheduled immediately through `create_future`:
        self._fut: concurrent.futures.Future = create_future(
            _poll_training(job_id))

    @property
    def id(self) -> TrainingJobID:
        r"""The unique ID of this training job."""
        return self.job_id

    @override
    def result(self) -> TrainingJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[TrainingJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        return _rewrite_tracking_url(self.status().tracking_url)

    def attach(self) -> TrainingJobResult:
        r"""Allows a user to attach to a running training job, and view its
        progress inline.

        The method first reports each stage the job passes through; once the
        job reaches the ``"Training"`` stage and an estimated training time is
        available, it switches to a progress bar that is refreshed every two
        seconds until the job's future resolves.

        Returns:
            The result of the job's future (which raises if the job did not
            complete successfully).

        Example:
            >>> job_future = kumoai.TrainingJob(job_id="...")  # doctest: +SKIP
            >>> job_future.attach()  # doctest: +SKIP
            Attaching to training job <id>. To track this job...
            Training: 70%|█████████ | [300s<90s, trial=4, train_loss=1.056, val_loss=0.682, val_mae=35.709, val_mse=7906.239, val_rmse=88.917
        """  # noqa
        print(f"Attaching to training job {self.job_id}. To track this job in "
              f"the Kumo UI, please visit {self.tracking_url}. To detach from "
              f"this job, please enter Ctrl+C: the job will continue to run, "
              f"and you can re-attach anytime by calling the `attach()` "
              f"method on the `TrainingJob` object. For example: "
              f"kumoai.TrainingJob(\"{self.job_id}\").attach()")

        # TODO(manan): this is not perfect, the `asyncio.sleep` in the poller
        # may cause a "DONE" status to be printed for up to 2*`timeout` seconds
        # before the future resolves. That's probably fine:
        if self.done():
            return self.result()

        # For every non-training stage, just show the stage and status:
        print("Waiting for job to start.")
        current_status = JobStatus.NOT_STARTED
        while current_status == JobStatus.NOT_STARTED:
            report = self.status()
            current_status = report.status
            current_stage = report.event_log[-1].stage_name
            time.sleep(2)

        prev_stage = current_stage
        print(f"Current stage: {current_stage}. In progress...", end="",
              flush=True)
        while not self.done():
            # Print status of stage:
            if current_stage != prev_stage:
                print(" Done.")
                print(f"Current stage: {current_stage}. In progress...",
                      end="", flush=True)

            # Switch to the progress bar once a time estimate exists:
            if current_stage == "Training":
                _time = self.progress().estimated_training_time
                if _time:
                    break

            time.sleep(2)
            report = self.status()
            prev_stage = current_stage
            current_stage = report.event_log[-1].stage_name

        # We are not on Training:
        if self.done():
            return self.result()

        # We are training: print a progress bar
        progress = self.progress()
        bar_format = '{desc}: {percentage:3.0f}%|{bar}|{unit} '
        total = int(progress.estimated_training_time)
        elapsed = int(progress.elapsed_training_time)
        pbar = tqdm(desc="Training", unit="% done", bar_format=bar_format,
                    total=total, dynamic_ncols=True)
        pbar.update(elapsed)

        while not self.done():
            progress = self.progress()
            trial_no = min(progress.completed_trials + 1,
                           progress.total_trials)

            # Prefer the current trial's metrics, falling back to the
            # previous trial when the current one has not reported yet:
            trial_metrics = {}
            for offset in (1, 2):
                trial_key = f'{max(trial_no - offset, 0)}'
                if trial_key in progress.trial_progress:
                    trial_metrics = progress.trial_progress[trial_key].metrics
                    break

            # If we don't have metrics, wait until we do. Sleep before
            # retrying so that we do not busy-loop against the API:
            if len(trial_metrics) == 0:
                time.sleep(2)
                continue

            # Show all metrics:
            # TODO(manan): only show tune metric, trial, epoch, and loss:
            last_epoch_metrics = trial_metrics[sorted(
                trial_metrics.keys())[-1]]
            train_metrics_s = ", ".join([
                f"{key_name}={key_val:.3f}" for key_name, key_val in
                last_epoch_metrics.train_metrics.items()
            ])
            val_metrics_s = ", ".join([
                f"{key_name}={key_val:.3f}" for key_name, key_val in
                last_epoch_metrics.validation_metrics.items()
            ])

            # Update numbers:
            delta = int(progress.elapsed_training_time - pbar.n)
            total = int(progress.estimated_training_time)
            pbar.update(delta)
            pbar.total = total
            # Keep the bar from overflowing when the estimate was too low:
            if pbar.n > pbar.total:
                pbar.total = pbar.n

            # NOTE we use `unit` here as a hack, instead of `set_postfix`,
            # since `tqdm` defaults to adding a comma before the postfix
            # (https://github.com/tqdm/tqdm/issues/712)
            pbar.unit = (f"[{pbar.n}s<{pbar.total-pbar.n}s, trial={trial_no}, "
                         f"{train_metrics_s}, {val_metrics_s}]")
            pbar.refresh()
            time.sleep(2)

        # Fill the remainder of the bar before closing it:
        pbar.update(pbar.total - pbar.n)
        pbar.close()

        # Future is done:
        return self.result()

    def progress(self) -> AutoTrainerProgress:
        r"""Returns the progress of an ongoing or completed training job."""
        api = global_state.client.training_job_api
        return api.get_progress(self.job_id)

    def status(self) -> JobStatusReport:
        r"""Returns the status of a running training job."""
        return _get_training_status(self.job_id)

    def cancel(self) -> bool:
        r"""Cancels a running training job, and returns ``True`` if
        cancellation succeeded.

        Example:
            >>> job_future = kumoai.TrainingJob(job_id="...")  # doctest: +SKIP
            >>> job_future.cancel()  # doctest: +SKIP
        """  # noqa
        api = global_state.client.training_job_api
        return api.cancel(self.job_id).is_cancelled

    def delete_tags(self, tags: List[str]) -> bool:
        r"""Removes the tags from the job.

        Args:
            tags (List[str]): The tags to remove.
        """
        api = global_state.client.training_job_api
        return api.delete_tags(self.job_id, tags)

    def update_tags(self, tags: Mapping[str, Optional[str]]) -> bool:
        r"""Updates the tags of the job.

        Args:
            tags (Mapping[str, Optional[str]]): The tags to update.
                Note that the value 'none' will remove the tag. If the tag is
                not present, it will be added.
        """
        api = global_state.client.training_job_api
        return api.update_tags(self.job_id, tags)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(job_id={self.job_id})'
def _get_training_status(job_id: str) -> JobStatusReport:
    r"""Fetches the training job resource and returns its status report."""
    training_api = global_state.client.training_job_api
    fetched: TrainingJobResource = training_api.get(job_id)
    return fetched.job_status_report


async def _poll_training(job_id: str) -> TrainingJobResult:
    r"""Polls the training job every 10 seconds until its status is terminal.

    Raises:
        RuntimeError: If the job ends in any status other than ``DONE``; the
            message includes the validation details when they are available.
    """
    # TODO(manan): make asynchronous natively with aiohttp:
    while True:
        status = _get_training_status(job_id).status
        if status.is_terminal:
            break
        await asyncio.sleep(10)

    # TODO(manan, siyang): improve
    if status != JobStatus.DONE:
        failed_resource = global_state.client.training_job_api.get(job_id)
        validation_resp = failed_resource.job_status_report.validation_response
        validation_message = (validation_resp.message()
                              if validation_resp else "")
        if len(validation_message) > 0:
            validation_message = f"Details: {validation_message}"
        raise RuntimeError(f"Training job {job_id} completed with job status "
                           f"{status}. {validation_message}")

    # TODO(manan): improve
    return TrainingJobResult(job_id=job_id)


# Batch Prediction Job Future #################################################
class BatchPredictionJob(KumoFuture[BatchPredictionJobResult]):
    r"""Represents an in-progress batch prediction job.

    A :class:`BatchPredictionJob` object can either be created as the
    result of :meth:`~kumoai.trainer.Trainer.predict` with
    ``non_blocking=True``, or directly with a batch prediction job ID (*e.g.*
    of a job created asynchronously in a different environment).

    .. code-block:: python

        import kumoai

        # See `Trainer` documentation:
        trainer = kumoai.Trainer(...)

        # If a Trainer `predict` is called in nonblocking mode, the response
        # will be of type `BatchPredictionJob`:
        prediction_job = trainer.predict(..., non_blocking=True)

        # You can also construct a `BatchPredictionJob` from a job ID, e.g. one
        # that is present in the Kumo Jobs page:
        prediction_job = kumoai.BatchPredictionJob("bp-job-...")

        # Get the status of the job:
        print(prediction_job.status())

        # Attach to the job, and poll progress updates:
        prediction_job.attach()
        # Attaching to batch prediction job <id>. To track this job...
        # Predicting (job_id=..., start=..., elapsed=..., status=...). Stage: ...

        # Cancel the job:
        prediction_job.cancel()

        # Wait for the job to complete, and return a `BatchPredictionJobResult`:
        prediction_job.result()

    Args:
        job_id: The batch prediction job ID to await completion of.
    """  # noqa

    def __init__(self, job_id: BatchPredictionJobID) -> None:
        self.job_id = job_id
        # Polling is scheduled immediately through `create_future`:
        self._fut: concurrent.futures.Future = create_future(
            _poll_batch_prediction(job_id))

    @property
    def id(self) -> BatchPredictionJobID:
        r"""The unique ID of this batch prediction job."""
        return self.job_id

    @override
    def result(self) -> BatchPredictionJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[BatchPredictionJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI that can be used to
        monitor the status of an ongoing or completed job.
        """
        return _rewrite_tracking_url(self.status().tracking_url)

    def attach(self) -> BatchPredictionJobResult:
        r"""Allows a user to attach to a running batch prediction job, and view
        its progress inline.

        The method first reports each stage the job passes through; once the
        job reaches the ``"Predicting"`` stage and an estimated prediction
        time is available, it switches to a percentage progress bar that is
        refreshed every two seconds until the job's future resolves.

        Returns:
            The result of the job's future (which raises if the job did not
            complete successfully).
        """
        print(f"Attaching to batch prediction job {self.job_id}. To track "
              f"this job in the Kumo UI, please visit {self.tracking_url}. To "
              f"detach from this job, please enter Ctrl+C (the job will "
              f"continue to run, and you can re-attach anytime).")

        # TODO(manan): this is not perfect, the `asyncio.sleep` in the poller
        # may cause a "DONE" status to be printed for up to 2*`timeout` seconds
        # before the future resolves. That's probably fine:
        if self.done():
            return self.result()

        print("Waiting for job to start.")
        current_status = JobStatus.NOT_STARTED
        while current_status == JobStatus.NOT_STARTED:
            report = self.status()
            current_status = report.status
            current_stage = report.event_log[-1].stage_name
            time.sleep(2)

        prev_stage = current_stage
        print(f"Current stage: {current_stage}. In progress...", end="",
              flush=True)
        while not self.done():
            # Print status of stage:
            if current_stage != prev_stage:
                print(" Done.")
                print(f"Current stage: {current_stage}. In progress...",
                      end="", flush=True)

            # Switch to the progress bar once a time estimate exists. The
            # progress report may not be available yet, so guard for `None`
            # exactly as the progress-bar loop below does:
            if current_stage == "Predicting":
                progress = self.progress()
                if progress is not None and progress.estimated_prediction_time:
                    break

            time.sleep(2)
            report = self.status()
            prev_stage = current_stage
            current_stage = report.event_log[-1].stage_name

        # We are not on Batch Prediction:
        if self.done():
            return self.result()

        # We are predicting: print a progress bar (in percent, 0 to 100)
        bar_format = '{desc}: {percentage:3.0f}%|{bar} '
        elapsed = 0  # Iterations already accounted for in the bar.
        pbar = tqdm(desc="Predicting", unit="% done", bar_format=bar_format,
                    total=100, dynamic_ncols=True)

        while not self.done():
            progress = self.progress()
            total_iterations = (progress.total_iterations
                                if progress is not None else 0)
            # Nothing to show (and nothing to divide by) until the total
            # number of iterations is known:
            if not total_iterations:
                time.sleep(2)
                continue

            completed_iterations = progress.completed_iterations
            pbar.update(
                (completed_iterations - elapsed) / total_iterations * 100)
            elapsed = completed_iterations
            pbar.refresh()
            time.sleep(2)

        # Fill the remainder of the bar, in the bar's own (percent) units:
        pbar.update(pbar.total - pbar.n)
        pbar.close()

        # Future is done:
        return self.result()

    def progress(self) -> PredictionProgress:
        r"""Returns the progress of an ongoing or completed batch prediction
        job.
        """
        api = global_state.client.batch_prediction_job_api
        return api.get_progress(self.job_id)

    def status(self) -> JobStatusReport:
        r"""Returns the status of a running batch prediction job."""
        return _get_batch_prediction_status(self.job_id)

    def cancel(self) -> bool:
        r"""Cancels a running batch prediction job, and returns ``True`` if
        cancellation succeeded.
        """
        api = global_state.client.batch_prediction_job_api
        return api.cancel(self.job_id).is_cancelled

    def delete_tags(self, tags: List[str]) -> bool:
        r"""Removes the tags from the job.

        Args:
            tags (List[str]): The tags to remove.
        """
        api = global_state.client.batch_prediction_job_api
        return api.delete_tags(self.job_id, tags)

    def update_tags(self, tags: Mapping[str, Optional[str]]) -> bool:
        r"""Updates the tags of the job.

        Args:
            tags (Mapping[str, Optional[str]]): The tags to update.
                Note that the value 'none' will remove the tag. If the tag is
                not present, it will be added.
        """
        api = global_state.client.batch_prediction_job_api
        return api.update_tags(self.job_id, tags)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(job_id={self.job_id})'
def _get_batch_prediction_job(job_id: str) -> BatchPredictionJobResource:
    r"""Fetches the batch prediction job resource for ``job_id``."""
    api = global_state.client.batch_prediction_job_api
    return api.get(job_id)


def _get_batch_prediction_status(job_id: str) -> JobStatusReport:
    r"""Fetches the batch prediction job and returns its status report."""
    return _get_batch_prediction_job(job_id).job_status_report


async def _poll_batch_prediction(job_id: str) -> BatchPredictionJobResult:
    r"""Polls the batch prediction job every 10 seconds until its status is
    terminal.

    Raises:
        ValueError: If the job ends in any status other than ``DONE``; the
            message includes the validation details when they are available.
    """
    # TODO(manan): make asynchronous natively with aiohttp:
    job_resource = _get_batch_prediction_job(job_id)
    status = job_resource.job_status_report.status
    while not status.is_terminal:
        await asyncio.sleep(10)
        job_resource = _get_batch_prediction_job(job_id)
        status = job_resource.job_status_report.status

    # TODO(manan, siyang): improve
    if status != JobStatus.DONE:
        validation_resp = job_resource.job_status_report.validation_response
        validation_message = ""
        if validation_resp:
            validation_message = validation_resp.message()
        if len(validation_message) > 0:
            validation_message = f"Details: {validation_message}"

        raise ValueError(
            f"Batch prediction job {job_id} completed with status "
            f"{status}, and was therefore unable to produce metrics. "
            f"{validation_message}")

    # TODO(manan): improve
    return BatchPredictionJobResult(job_id=job_id)


# Baseline Job Future #################################################


class BaselineJob(KumoFuture[BaselineJobResult]):
    r"""Represents an in-progress baseline job.

    A :class:`BaselineJob` object can either be created as the result of
    :meth:`~kumoai.trainer.BaselineTrainer.run` with ``non_blocking=True``, or
    directly with a baseline job ID (*e.g.* of a job created asynchronously in
    a different environment).

    Args:
        job_id: The baseline job ID to await completion of.

    Example:
        >>> import kumoai  # doctest: +SKIP
        >>> job_id = "some_baseline_job_id"  # doctest: +SKIP
        >>> job_future = kumoai.BaselineJob(job_id)  # doctest: +SKIP
        >>> job_future.attach()  # doctest: +SKIP
        Attaching to baseline job <id>. To detach from this job...
    """  # noqa

    def __init__(self, job_id: BaselineJobID) -> None:
        self.job_id = job_id
        # Polling is scheduled immediately through `create_future`:
        self._fut: concurrent.futures.Future = create_future(
            _poll_baseline(job_id))

    @property
    def id(self) -> BaselineJobID:
        r"""The unique ID of this baseline job."""
        return self.job_id

    @override
    def result(self) -> BaselineJobResult:
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[BaselineJobResult]':
        return self._fut

    @property
    def tracking_url(self) -> str:
        r"""Returns the tracking URL of this job. This implementation always
        returns an empty string.
        """
        return ""

    def attach(self) -> BaselineJobResult:
        r"""Allows a user to attach to a running baseline job, and view its
        progress inline.

        A status line (start time, elapsed time, status and latest stage) is
        printed every 10 seconds until the job's future resolves.

        Returns:
            The result of the job's future (which raises if the job did not
            complete successfully).

        Example:
            >>> job_future = kumoai.BaselineJob(job_id="...")  # doctest: +SKIP
            >>> job_future.attach()  # doctest: +SKIP
            Attaching to baseline job <id>. To detach from this job...
        """  # noqa
        print(f"Attaching to baseline job {self.job_id}. To detach from "
              f"this job, please enter Ctrl+C (the job will continue to run, "
              f"and you can re-attach anytime).")

        while not self.done():
            report = self.status()
            status = report.status
            latest_event = report.event_log[-1]
            stage = latest_event.stage_name
            detail = ", " + latest_event.detail if latest_event.detail else ""

            start = report.start_time
            # NOTE(review): the subtraction below assumes `start_time` is
            # timezone-aware -- confirm against the API model.
            now = datetime.now(timezone.utc)
            print(f"Baseline job (job_id={self.job_id} start={start}, elapsed="
                  f"{now-start}, status={status}). Stage: {stage}{detail}")
            time.sleep(10)

        # Future is done:
        return self.result()

    def status(self) -> JobStatusReport:
        r"""Returns the status of a running baseline job."""
        return _get_baseline_status(self.job_id)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(job_id={self.job_id})'


def _get_baseline_status(job_id: str) -> JobStatusReport:
    r"""Fetches the baseline job resource and returns its status report."""
    api = global_state.client.baseline_job_api
    resource: BaselineJobResource = api.get(job_id)
    return resource.job_status_report


async def _poll_baseline(job_id: str) -> BaselineJobResult:
    r"""Polls the baseline job every 10 seconds until its status is terminal.

    Raises:
        RuntimeError: If the job ends in any status other than ``DONE``.
    """
    status = _get_baseline_status(job_id).status
    while not status.is_terminal:
        await asyncio.sleep(10)
        status = _get_baseline_status(job_id).status

    if status != JobStatus.DONE:
        raise RuntimeError(
            f"Baseline job {job_id} failed with job status {status}.")

    return BaselineJobResult(job_id=job_id)


def _rewrite_tracking_url(tracking_url: str) -> str:
    r"""Rewrites tracking URL to account for deployment subdomains.

    The scheme and network location of ``tracking_url`` are replaced by those
    of the client's base URL; path, params, query and fragment are kept.
    Values that are empty or do not contain ``'http'`` are returned
    unchanged.
    """
    # TODO(manan): improve...
    # An empty/missing URL cannot be rewritten (and `in` would fail on None):
    if not tracking_url or 'http' not in tracking_url:
        return tracking_url
    parsed_base = urlparse(global_state.client._url)
    parsed_tracking = urlparse(tracking_url)
    tracking_url = urlunparse((
        parsed_base.scheme,
        parsed_base.netloc,
        parsed_tracking.path,
        parsed_tracking.params,
        parsed_tracking.query,
        parsed_tracking.fragment,
    ))
    return tracking_url