import logging
from typing import Literal, Mapping, Optional, Union, overload
from kumoapi.distilled_model_plan import DistilledModelPlan
from kumoapi.jobs import DistillationJobRequest, DistillationJobResource
from kumoai import global_state
from kumoai.client.jobs import TrainingJobID
from kumoai.graph import Graph
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
from kumoai.trainer.job import TrainingJob, TrainingJobResult
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class DistillationTrainer:
    r"""Trains a shallow model for online serving using embeddings from a base
    (deep GNN) job. Distillation fits a :class:`~kumoai.pquery.PredictiveQuery`
    while reusing representations from the base model identified by
    ``base_training_job_id``.

    Args:
        model_plan: The distilled model plan to use for the distillation
            process.
        base_training_job_id: The ID of the base training job to use for the
            distillation process.
    """
    def __init__(
        self,
        model_plan: DistilledModelPlan,
        base_training_job_id: TrainingJobID,
    ) -> None:
        self.model_plan: DistilledModelPlan = model_plan
        self.base_training_job_id: TrainingJobID = base_training_job_id

        # Cached from backend: populated by `fit` (or `_load_from_job`) once
        # a distillation job exists for this trainer.
        self._training_job_id: Optional[TrainingJobID] = None

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Whether this trainer instance has successfully been trained (and
        is therefore ready for prediction).

        .. note::
            Not implemented yet for distilled trainers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Checking if a distilled trainer is trained is not "
            "implemented yet.")

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingJobResult:
        ...

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingJobResult:
        ...

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> TrainingJob:
        ...

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        ...

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by :class:`DistillationTrainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs. Defaults to no tags.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, awaits completion of the training
                job (via :meth:`TrainingJob.attach`) and returns its
                :class:`~kumoai.trainer.job.TrainingJobResult`. If
                ``non_blocking=True``, immediately returns the
                :class:`~kumoai.trainer.job.TrainingJob` future.

        Raises:
            ValueError: If :obj:`train_table` has no associated job ID, or if
                no predictive query ID can be resolved for that job.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        if job_id is None:
            raise ValueError(
                "The provided training table has no associated job ID; "
                "cannot launch a distillation job.")
        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        if pq_id is None:
            raise ValueError(
                f"Could not resolve a predictive query ID for training table "
                f"job '{job_id}'.")

        # A custom training table is only read from a `TrainingTable`; for a
        # `TrainingTableJob`, no custom table is forwarded.
        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.distillation_job_api
        self._training_job_id = api.create(
            DistillationJobRequest(
                dict(custom_tags) if custom_tags is not None else {},
                pquery_id=pq_id,
                base_training_job_id=self.base_training_job_id,
                distilled_model_plan=self.model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    @classmethod
    def _load_from_job(
        cls,
        job: DistillationJobResource,
    ) -> 'DistillationTrainer':
        r"""Builds a trainer from a distillation job resource, reusing the
        job's distilled model plan and base training job ID, and caching the
        job's ID as this trainer's training job ID.
        """
        trainer = cls(job.config.distilled_model_plan,
                      job.config.base_training_job_id)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from a training
        job ID.

        .. note::
            Not implemented yet for distilled trainers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from a job ID is not implemented yet."
        )

    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from the custom
        tags of a previously launched job.

        .. note::
            Not implemented yet for distilled trainers.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from tags is not implemented yet.")