Source code for kumoai.artifact_export.job

import asyncio
import concurrent
import concurrent.futures
import time
from typing import Literal, Union, overload

from kumoapi.common import JobStatus
from kumoapi.jobs import ArtifactExportRequest, ModelOutputConfig
from typing_extensions import override

from kumoai import global_state
from kumoai.futures import KumoProgressFuture, create_future


class ArtifactExportResult:
    """Handle to an artifact export job that has finished.

    Args:
        job_id: Identifier of the completed export job.
    """

    def __init__(self, job_id: str) -> None:
        self.job_id = job_id

    def tracking_url(self) -> str:
        """Return a URL pointing to the UI display of this export job.

        Raises:
            NotImplementedError: Always; no tracking URL is implemented
                for export jobs here.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        # Use the runtime class name so subclasses render their own name.
        cls_name = type(self).__name__
        return f"{cls_name}(job_id={self.job_id})"
class ArtifactExportJob(KumoProgressFuture[ArtifactExportResult]):
    """Represents an in-progress artifact export job.

    Constructing an instance immediately schedules a background future
    (via ``create_future``) that runs ``_poll_export`` for the job and
    resolves to an :class:`ArtifactExportResult`.

    Args:
        job_id: Identifier of the export job to track.
    """

    def __init__(self, job_id: str) -> None:
        self.job_id = job_id
        # Polling starts as soon as the job object is created.
        self._fut: concurrent.futures.Future = create_future(
            _poll_export(job_id))

    @property
    def id(self) -> str:
        """The unique ID of this export job."""
        return self.job_id

    @override
    def result(self) -> ArtifactExportResult:
        """Return the result of the underlying polling future.

        Any exception raised by the polling coroutine is propagated by
        the future's ``result()`` call.
        """
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[ArtifactExportResult]':
        """Return the underlying future that tracks this export job."""
        return self._fut

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> ArtifactExportResult:
        """Allows a user to attach to a running export job and view its
        progress.

        Args:
            interval_s (float): Time interval (seconds) between polls,
                minimum value allowed is 4 seconds.

        Returns:
            The job's :class:`ArtifactExportResult` once it is done.

        Raises:
            ValueError: If ``interval_s`` is not at least 4 seconds.
        """
        # Explicit check instead of ``assert`` so the minimum is still
        # enforced under ``python -O``. Written as ``not >=`` so that NaN
        # is rejected as well.
        if not interval_s >= 4.0:
            raise ValueError(
                f"interval_s must be at least 4 seconds, got {interval_s}")

        print(f"Attaching to export job {self.job_id}. To detach from "
              f"this job, please enter Ctrl+C (the job will continue to run, "
              f"and you can re-attach anytime).")

        # TODO improve print statements.
        # Will require changes to status to return
        # JobStatusReport instead of JobStatus.
        # ``done()`` is not defined in this class; presumably inherited
        # from KumoProgressFuture.
        while not self.done():
            status = self.status()
            print(f"Export job {self.job_id} status: {status}")
            time.sleep(interval_s)

        return self.result()

    def status(self) -> JobStatus:
        """Returns the status of a running export job."""
        return get_export_status(self.job_id)

    def cancel(self) -> bool:
        """Requests cancellation of a running export job.

        Returns:
            bool: True if the status returned by the cancel call is
            ``JobStatus.CANCELLED``, False for any other status.
        """
        api = global_state.client.artifact_export_api
        return api.cancel(self.job_id) == JobStatus.CANCELLED
@overload
def export_model(config: ModelOutputConfig, ) -> 'ArtifactExportJob':
    ...


@overload
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: Literal[True],
) -> 'ArtifactExportJob':
    ...


@overload
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: Literal[False],
) -> 'ArtifactExportResult':
    ...


@overload
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: bool,
) -> Union['ArtifactExportJob', 'ArtifactExportResult']:
    ...


def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: bool = True,
) -> Union['ArtifactExportJob', 'ArtifactExportResult']:
    r"""Export model files and batch prediction embeddings to an external
    storage location for use in online serving.

    The export copies the online serving model directory in its entirety
    to the output path, along with the ``embeddings.parquet`` result from
    the batch prediction job specified.

    Args:
        config: A :class:`~kumoapi.jobs.ModelOutputConfig` specifying the
            training job, output path, and batch prediction job to bundle.
        non_blocking: If ``True`` (the default), returns an
            :class:`~kumoai.artifact_export.ArtifactExportJob` future.
            If ``False``, attaches to the job, blocks until it completes
            and returns an
            :class:`~kumoai.artifact_export.ArtifactExportResult`.

    Returns:
        The in-progress job when ``non_blocking`` is true, otherwise the
        result obtained by attaching to the job.
    """
    # The export request is keyed by the training job ID taken from config.
    request = ArtifactExportRequest(
        job_id=config.training_job_id,
        model_output=config,
    )
    export_job_id = global_state.client.artifact_export_api.create(request)

    job = ArtifactExportJob(export_job_id)
    return job if non_blocking else job.attach()
def get_export_status(job_id: str) -> JobStatus:
    """Fetch the current status of the export job ``job_id``.

    Returns whatever the artifact export API's ``get`` call yields for
    the job, annotated here as a :class:`JobStatus`.
    """
    return global_state.client.artifact_export_api.get(job_id)


async def _poll_export(job_id: str) -> ArtifactExportResult:
    """Poll an export job until it reaches a terminal status.

    The status is checked immediately and then every 10 seconds.

    Args:
        job_id: Identifier of the export job to poll.

    Returns:
        An :class:`ArtifactExportResult` for the job if it finished with
        ``JobStatus.DONE``.

    Raises:
        RuntimeError: If the job ends in any terminal status other than
            ``JobStatus.DONE``.
    """
    # NOTE(review): get_export_status is a synchronous call made from
    # inside a coroutine; presumably it performs a network request --
    # confirm that blocking the event loop here is acceptable.
    while True:
        current = get_export_status(job_id)
        if current.is_terminal:
            break
        await asyncio.sleep(10)

    if current != JobStatus.DONE:
        raise RuntimeError(f"Export job {job_id} failed "
                           f"with job status {current}.")

    return ArtifactExportResult(job_id=job_id)