kumoai.trainer.DistillationTrainer
- class kumoai.trainer.DistillationTrainer
Bases: object

Trains a shallow model for online serving using embeddings from a base (deep GNN) job. Distillation fits a PredictiveQuery while reusing representations from the base model identified by base_training_job_id.
- Parameters:
  - model_plan (DistilledModelPlan) – The distilled model plan to use for the distillation process.
  - base_training_job_id (str) – The ID of the base training job to use for the distillation process.
- property is_trained: bool
Returns True if this trainer instance has successfully been trained (and is therefore ready for prediction); False otherwise.
- fit(graph, train_table, *, non_blocking=False, custom_tags={})
Fits a model to the specified graph and training table, with the strategy defined by the DistillationTrainer's model_plan.
- Parameters:
  - graph (Graph) – The Graph object that represents the tables and relationships that Kumo will learn from.
  - train_table (Union[TrainingTable, TrainingTableJob]) – The TrainingTable, or in-progress TrainingTableJob, that represents the training data produced by a PredictiveQuery on graph.
  - non_blocking (bool) – Whether this operation should return immediately after launching the training job, or await completion of the training job.
  - custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to be associated with the job to be launched. Job tags are useful for grouping and searching jobs.
- Returns:
If non_blocking=False, returns a training job object. If non_blocking=True, returns a training job future object.
- Return type:
  Union[TrainingJobResult, TrainingJob]