# Source code for kumoai.trainer.distilled_trainer

import logging
from typing import Literal, Mapping, Optional, Union, overload

from kumoapi.distilled_model_plan import DistilledModelPlan
from kumoapi.jobs import DistillationJobRequest, DistillationJobResource

from kumoai import global_state
from kumoai.client.jobs import TrainingJobID
from kumoai.graph import Graph
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
from kumoai.trainer.job import TrainingJob, TrainingJobResult

logger = logging.getLogger(__name__)


class DistillationTrainer:
    r"""Trains a shallow model for online serving using embeddings from a
    base (deep GNN) job.

    Distillation fits a :class:`~kumoai.pquery.PredictiveQuery` while reusing
    representations from the base model identified by
    ``base_training_job_id``.

    Args:
        model_plan: The distilled model plan to use for the distillation
            process.
        base_training_job_id: The ID of the base training job to use for the
            distillation process.
    """
    def __init__(
        self,
        model_plan: DistilledModelPlan,
        base_training_job_id: TrainingJobID,
    ) -> None:
        self.model_plan: DistilledModelPlan = model_plan
        self.base_training_job_id: TrainingJobID = base_training_job_id

        # Cached from backend; set by `fit` and `_load_from_job`:
        self._training_job_id: Optional[TrainingJobID] = None

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Intended to return ``True`` if this trainer instance has
        successfully been trained (and is therefore ready for prediction);
        ``False`` otherwise.

        Raises:
            NotImplementedError: Always; this check is not implemented yet
                for distilled trainers.
        """
        raise NotImplementedError(
            "Checking if a distilled trainer is trained is not "
            "implemented yet.")

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by :class:`DistillationTrainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be
                associated with the job to be launched. Job tags are useful
                for grouping and searching jobs. Defaults to no tags.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, returns the training job result
                obtained by attaching to the launched job. If
                ``non_blocking=True``, returns a training job future object.

        Raises:
            ValueError: If :obj:`train_table` has no job ID, or if the
                training table job has no associated predictive query ID.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        if job_id is None:
            # Explicit raise rather than `assert`, which is stripped under
            # `python -O` and would let `None` reach the backend lookup.
            raise ValueError(
                "Cannot launch a distillation job: 'train_table' has no "
                "job ID.")

        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        if pq_id is None:
            raise ValueError(
                f"Cannot launch a distillation job: training table job "
                f"'{job_id}' has no associated predictive query ID.")

        # Only a materialized `TrainingTable` is asked for a custom training
        # table; an in-progress `TrainingTableJob` contributes `None`.
        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This
        # preserves semantics for both futures, ensures that Kumo works as
        # expected if used only via REST API, and allows us to avoid chaining
        # callbacks in an ugly way here:
        api = global_state.client.distillation_job_api
        self._training_job_id = api.create(
            DistillationJobRequest(
                # Always send a fresh dict so the caller's mapping is never
                # shared with the request object.
                dict(custom_tags) if custom_tags is not None else {},
                pquery_id=pq_id,
                base_training_job_id=self.base_training_job_id,
                distilled_model_plan=self.model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    @classmethod
    def _load_from_job(
        cls,
        job: DistillationJobResource,
    ) -> 'DistillationTrainer':
        r"""Builds a trainer from a distillation job resource, reusing the
        job's distilled model plan and base training job ID, and caching the
        job's own ID on the returned trainer.
        """
        trainer = cls(job.config.distilled_model_plan,
                      job.config.base_training_job_id)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from a training
        job ID.

        Raises:
            NotImplementedError: Always; loading from a job ID is not
                implemented yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from a job ID is not implemented yet."
        )

    @classmethod
    def load_from_tags(cls, tags: Mapping[str, str]) -> 'DistillationTrainer':
        r"""Creates a :class:`DistillationTrainer` instance from job tags.

        Raises:
            NotImplementedError: Always; loading from tags is not implemented
                yet.
        """
        raise NotImplementedError(
            "Loading a distilled trainer from tags is not implemented yet.")