# Source code for kumoai.artifact_export.job
import asyncio
import concurrent
import concurrent.futures
import time
from typing import Literal, Union, overload
from kumoapi.common import JobStatus
from kumoapi.jobs import ArtifactExportRequest, ModelOutputConfig
from typing_extensions import override
from kumoai import global_state
from kumoai.futures import KumoProgressFuture, create_future
class ArtifactExportResult:
    r"""Represents a completed artifact export job.

    Args:
        job_id: The unique ID of the completed export job.
    """

    def __init__(self, job_id: str) -> None:
        self.job_id = job_id

    def tracking_url(self) -> str:
        r"""Returns a tracking URL pointing to the UI display of
        this artifact export job.

        Raises:
            NotImplementedError: Always; no tracking URL is currently
                provided for artifact export jobs.
        """
        raise NotImplementedError

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(job_id={self.job_id})"
class ArtifactExportJob(KumoProgressFuture[ArtifactExportResult]):
    """Represents an in-progress artifact export job.

    On construction, a coroutine that polls the export job's status is
    wrapped into a :class:`concurrent.futures.Future` via ``create_future``.
    """

    def __init__(self, job_id: str) -> None:
        self.job_id = job_id
        # The wrapped coroutine returns an ArtifactExportResult once the job
        # reaches DONE, and raises RuntimeError if it ends in any other
        # terminal state.
        self._fut: concurrent.futures.Future = create_future(
            _poll_export(job_id))

    @property
    def id(self) -> str:
        """The unique ID of this export job."""
        return self.job_id

    @override
    def result(self) -> ArtifactExportResult:
        """Blocks until the export job finishes and returns its result.

        Raises:
            RuntimeError: If the job reached a terminal state other than
                ``DONE``.
        """
        return self._fut.result()

    @override
    def future(self) -> 'concurrent.futures.Future[ArtifactExportResult]':
        """Returns the underlying future tracking this export job."""
        return self._fut

    @override
    def _attach_internal(
        self,
        interval_s: float = 20.0,
    ) -> ArtifactExportResult:
        """Allows a user to attach to a running export job and view
        its progress.

        Args:
            interval_s (float): Time interval (seconds) between polls, minimum
                value allowed is 4 seconds.

        Returns:
            ArtifactExportResult: The result of the finished export job.

        Raises:
            ValueError: If ``interval_s`` is less than 4 seconds (or NaN).
        """
        # Validate explicitly instead of with ``assert``: assertions are
        # stripped under ``python -O``, which would let a tiny interval call
        # ``status()`` in a tight loop. ``not >=`` also rejects NaN, exactly
        # like the previous assertion did.
        if not interval_s >= 4.0:
            raise ValueError(f"interval_s must be at least 4 seconds, got "
                             f"{interval_s}.")
        print(f"Attaching to export job {self.job_id}. To detach from "
              "this job, please enter Ctrl+C (the job will continue to run, "
              "and you can re-attach anytime).")
        # TODO improve print statements.
        # Will require changes to status to return
        # JobStatusReport instead of JobStatus.
        while not self.done():
            status = self.status()
            print(f"Export job {self.job_id} status: {status}")
            time.sleep(interval_s)
        return self.result()

    def status(self) -> JobStatus:
        """Returns the current status of this export job."""
        return get_export_status(self.job_id)

    def cancel(self) -> bool:
        """Cancels a running export job.

        Returns:
            bool: ``True`` if the status returned by the cancel request is
            ``JobStatus.CANCELLED``, ``False`` otherwise.
        """
        api = global_state.client.artifact_export_api
        return api.cancel(self.job_id) == JobStatus.CANCELLED
# Typing-only overloads: the static return type of ``export_model`` depends
# on the ``non_blocking`` flag (a job future when True or omitted, the final
# result when False). The runtime implementation is defined below.
@overload
def export_model(config: ModelOutputConfig) -> 'ArtifactExportJob':
    ...


@overload
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: Literal[True],
) -> 'ArtifactExportJob':
    ...


@overload
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: Literal[False],
) -> 'ArtifactExportResult':
    ...


@overload
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: bool,
) -> Union['ArtifactExportJob', 'ArtifactExportResult']:
    ...
def export_model(
    config: ModelOutputConfig,
    *,
    non_blocking: bool = True,
) -> Union['ArtifactExportJob', 'ArtifactExportResult']:
    r"""Export model files and batch prediction embeddings to an external
    storage location for use in online serving.

    The export copies the online serving model directory in its entirety to
    the output path, along with the ``embeddings.parquet`` result from the
    specified batch prediction job.

    Args:
        config: A :class:`~kumoapi.jobs.ModelOutputConfig` specifying the
            training job, output path, and batch prediction job to bundle.
        non_blocking: If ``True``, returns an
            :class:`~kumoai.artifact_export.ArtifactExportJob` future. If
            ``False``, blocks until complete and returns an
            :class:`~kumoai.artifact_export.ArtifactExportResult`.

    Returns:
        The export job future (non-blocking) or its final result (blocking).
    """
    export_api = global_state.client.artifact_export_api
    # The export request is keyed by the training job that produced the
    # model; the full config travels along as ``model_output``.
    request = ArtifactExportRequest(
        job_id=config.training_job_id,
        model_output=config,
    )
    job = ArtifactExportJob(export_api.create(request))
    # Blocking mode attaches to the job and returns its final result.
    return job if non_blocking else job.attach()
def get_export_status(job_id: str) -> JobStatus:
    """Fetches the current status of the artifact export job ``job_id``.

    Args:
        job_id: ID of the export job, as returned by the artifact export
            API's ``create`` call.

    Returns:
        Whatever the artifact export API's ``get`` call reports for this job
        (annotated as a :class:`JobStatus`).
    """
    return global_state.client.artifact_export_api.get(job_id)
async def _poll_export(job_id: str) -> ArtifactExportResult:
    """Waits for export job ``job_id`` to reach a terminal state.

    The status is checked immediately, then re-checked after each 10-second
    sleep until it is terminal.

    Returns:
        An :class:`ArtifactExportResult` for the job if it finished with
        status ``DONE``.

    Raises:
        RuntimeError: If the job reached any other terminal status.
    """
    while True:
        current_status = get_export_status(job_id)
        if current_status.is_terminal:
            break
        # NOTE(review): get_export_status is a synchronous (non-awaited)
        # call, so each check runs on the event loop's thread; if it performs
        # network I/O it briefly blocks other coroutines on that loop.
        await asyncio.sleep(10)
    if current_status != JobStatus.DONE:
        raise RuntimeError(f"Export job {job_id} failed "
                           f"with job status {current_status}.")
    return ArtifactExportResult(job_id=job_id)