kumoai.trainer.DistillationTrainer

class kumoai.trainer.DistillationTrainer

Bases: object

Trains a shallow model suitable for online serving, using embeddings from a base (deep GNN) training job. Distillation fits a PredictiveQuery while reusing the representations learned by the base model identified by base_training_job_id.

Parameters:
  • model_plan (DistilledModelPlan) – The distilled model plan that configures the distillation process.

  • base_training_job_id (str) – The ID of the base (deep GNN) training job whose representations are reused during distillation.
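A minimal sketch of the typical flow: construct a trainer from a distilled model plan and the base job ID, then fit it in blocking mode. To keep the sketch self-contained, the trainer class is passed in as an argument (kumoai.trainer.DistillationTrainer in practice); the helper name train_distilled is hypothetical, and model_plan, graph, and train_table are assumed to have been created beforehand.

```python
from typing import Any, Tuple


def train_distilled(trainer_cls: Any, model_plan: Any, base_training_job_id: str,
                    graph: Any, train_table: Any) -> Tuple[Any, Any]:
    """Hypothetical helper: build a distillation trainer and fit it (blocking).

    In real use, trainer_cls would be kumoai.trainer.DistillationTrainer.
    """
    # Point the trainer at the base (deep GNN) job whose representations are reused.
    trainer = trainer_cls(
        model_plan=model_plan,
        base_training_job_id=base_training_job_id,
    )
    # Blocking fit: waits for the training job to finish and returns its result.
    result = trainer.fit(graph, train_table, non_blocking=False)
    return trainer, result
```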

__init__(model_plan, base_training_job_id)
property is_trained: bool

Returns True if this trainer instance has successfully been trained (and is therefore ready for prediction); False otherwise.

fit(graph, train_table, *, non_blocking=False, custom_tags={})

Fits a model to the specified graph and training table, using the strategy defined by this DistillationTrainer's model_plan.

Parameters:
  • graph (Graph) – The Graph object that represents the tables and relationships that Kumo will learn from.

  • train_table (Union[TrainingTable, TrainingTableJob]) – The TrainingTable, or in-progress TrainingTableJob, that represents the training data produced by a PredictiveQuery on graph.

  • non_blocking (bool) – Whether this operation should return immediately after launching the training job, or await completion of the training job.

  • custom_tags (Mapping[str, str]) – Additional customer-defined key-value tags to associate with the launched job. Job tags are useful for grouping and searching jobs.

Returns:

If non_blocking=False, returns a TrainingJobResult once the training job completes. If non_blocking=True, returns a TrainingJob future for the in-progress job.

Return type:

Union[TrainingJobResult, TrainingJob]
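With non_blocking=True, fit returns a TrainingJob future right after launch, so other work can proceed while training runs. The sketch below launches a tagged job this way and waits on it later. It assumes the returned future exposes a blocking result() method that yields the TrainingJobResult (confirm this against the TrainingJob reference); the helper name launch_then_wait and the example tag keys are hypothetical.

```python
from typing import Any, Mapping


def launch_then_wait(trainer: Any, graph: Any, train_table: Any,
                     tags: Mapping[str, str]) -> Any:
    """Hypothetical helper: launch a distillation job without blocking, then wait."""
    # non_blocking=True returns a TrainingJob future immediately after launch.
    future = trainer.fit(
        graph,
        train_table,
        non_blocking=True,
        custom_tags=dict(tags),  # pass a fresh dict rather than the caller's mapping
    )
    # ... other work could run here while the job trains ...
    # Assumption: result() blocks until completion and returns a TrainingJobResult.
    return future.result()
```

For example, launch_then_wait(trainer, graph, train_table, {"team": "recs", "stage": "distill"}) tags the job so it can later be found by searching on those keys.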

classmethod load(job_id)

Creates a DistillationTrainer instance from a training job ID.

Return type:

DistillationTrainer
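load pairs naturally with the is_trained property: a trainer restored from an earlier job can be checked before it is used for prediction. A minimal sketch, with the trainer class passed in (kumoai.trainer.DistillationTrainer in practice) and a hypothetical helper name, load_trained:

```python
from typing import Any


def load_trained(trainer_cls: Any, job_id: str) -> Any:
    """Hypothetical helper: restore a trainer from a job ID and require it to be trained."""
    trainer = trainer_cls.load(job_id)
    # is_trained is True only once training has completed successfully.
    if not trainer.is_trained:
        raise RuntimeError(f"training job {job_id!r} has not finished successfully")
    return trainer
```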