kumoai.trainer.Trainer

class kumoai.trainer.Trainer

Bases: object

A trainer supports creating a Kumo machine learning model on a PredictiveQuery. It is primarily oriented around two methods: fit(), which accepts a Graph and TrainingTable and produces a TrainingJobResult, and predict(), which accepts a Graph and PredictionTable and produces a BatchPredictionJobResult.

A Trainer can also be loaded from a training job, with load().

import kumoai

# See `Graph` and `PredictiveQuery` documentation:
graph = kumoai.Graph(...)
pquery = kumoai.PredictiveQuery(graph=graph, query=...)

# Create a `Trainer` object, using a suggested model plan given the
# predictive query:
model_plan = pquery.suggest_model_plan()
trainer = kumoai.Trainer(model_plan)

# Create a training table from the predictive query:
training_table = pquery.generate_training_table()

# Fit a model:
training_job = trainer.fit(
    graph=graph,
    train_table=training_table,
)

# Create a prediction table from the predictive query:
prediction_table = pquery.generate_prediction_table()

# Generate predictions:
prediction_job = trainer.predict(
    graph=graph,
    prediction_table=prediction_table,
    # other arguments here...
)

# Load a trained query to generate predictions:
pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
trainer = kumoai.Trainer.load("trainingjob-...")
prediction_job = trainer.predict(
    pquery.graph,
    pquery.generate_prediction_table(),
)

Parameters:

model_plan (ModelPlan) – A model plan that the trainer should follow when fitting a Kumo model to a predictive query. This model plan can either be generated from a predictive query, with suggest_model_plan(), or can be manually specified.

__init__(model_plan)

property is_trained: bool

Returns True if this trainer instance has successfully been trained (and is therefore ready for prediction); False otherwise.
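The flag can serve as a guard before requesting predictions. The sketch below is a minimal illustration, not part of the kumoai API: `predict_if_trained` is a hypothetical helper that relies only on the `is_trained` property and the `predict(graph, prediction_table)` call documented on this page, so any object exposing both can be passed in.

```python
from typing import Any, Optional


def predict_if_trained(trainer: Any, graph: Any, prediction_table: Any) -> Optional[Any]:
    """Hypothetical helper: call `predict()` only when the trainer is trained.

    Returns whatever `trainer.predict(...)` returns, or None when
    `trainer.is_trained` is False.
    """
    if not trainer.is_trained:
        return None
    return trainer.predict(graph, prediction_table)
```

Returning None instead of raising is a design choice of this sketch; a caller that treats an untrained trainer as an error may prefer to raise an exception.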

fit(graph, train_table, *, non_blocking=False, custom_tags={})

Fits a model to the specified graph and training table, with the strategy defined by this Trainer’s model_plan.

Parameters:
  • graph (Graph) – The Graph object that represents the tables and relationships that Kumo will learn from.

  • train_table (Union[TrainingTable, TrainingTableJob]) – The TrainingTable, or in-progress TrainingTableJob, that represents the training data produced by a PredictiveQuery on graph.

  • non_blocking (bool) – Whether this operation should return immediately after launching the training job, or await completion of the training job.

  • custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to associate with the job to be launched. Job tags are useful for grouping and searching jobs.

Returns:

If non_blocking=False, returns a training job result object (TrainingJobResult). If non_blocking=True, returns a training job future object (TrainingJob).

Return type:

Union[TrainingJobResult, TrainingJob]
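The difference between the blocking and non-blocking return values can be pictured with a plain-Python analogy. This is only a sketch of the general pattern, not how Kumo implements jobs: `run_job` and `launch` are hypothetical names, and the standard-library `Future` stands in for the training job future object.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Union

# A single background worker that plays the role of the remote job runner.
_executor = ThreadPoolExecutor(max_workers=1)


def run_job() -> str:
    """Stand-in for a long-running job; returns a made-up result label."""
    return "job-result"


def launch(non_blocking: bool = False) -> Union[str, Future]:
    """Hypothetical launcher mirroring the `non_blocking` switch.

    non_blocking=False: wait for the job and return its result.
    non_blocking=True: return a future immediately; the caller decides
    when to wait by calling `.result()`.
    """
    future = _executor.submit(run_job)
    if non_blocking:
        return future
    return future.result()
```

With `non_blocking=True` the caller can start other work first and collect the outcome later, at the cost of having to handle the future explicitly.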

predict(graph, prediction_table, output_types=None, output_connector=None, output_table_name=None, output_metadata_fields=None, output_config=None, training_job_id=None, binary_classification_threshold=None, num_classes_to_return=None, num_workers=1, non_blocking=False, custom_tags={})

Using the trained model specified by training_job_id (or the model corresponding to the last invocation of fit(), if not present), predicts the future values of the entities in prediction_table leveraging current information from graph.

Parameters:
  • graph (Graph) – The Graph object that represents the tables and relationships that Kumo will use to make predictions.

  • prediction_table (Union[PredictionTable, PredictionTableJob]) – The PredictionTable, or in-progress PredictionTableJob, that represents the prediction data produced by a PredictiveQuery on graph. This table may also be custom-generated.

  • output_config (Optional[OutputConfig]) – Output configuration defining properties of the generated batch prediction outputs. It typically specifies the output destination, any additional columns needed, and whether to overwrite or append to an existing table.

  • output_types (Optional[Set[str]]) – The types of outputs that should be produced by the prediction job. Can include either 'predictions', 'embeddings', or both.

  • output_connector (Optional[Connector]) – The output data source that Kumo should write batch predictions to. If None, only local download output is produced.

  • output_table_name (Union[str, Tuple, None]) – The name of the table in the output data source that Kumo should write batch predictions to. In the case of a Databricks connector, this should be a tuple of two strings: the schema name and the output prediction table name.

  • output_metadata_fields (Optional[List[MetadataField]]) – Any additional metadata fields to include as new columns in the produced 'predictions' output. Currently, allowed options are JOB_TIMESTAMP and ANCHOR_TIMESTAMP.

  • training_job_id (Optional[str]) – The job ID of the training job whose model will be used for making predictions.

  • binary_classification_threshold (Optional[float]) – If this model corresponds to a binary classification task, the decision threshold: predicted values above it are mapped to 1, and values below it are mapped to 0.

  • num_classes_to_return (Optional[int]) – If this model corresponds to a ranking task, the number of classes to return in the prediction output.

  • num_workers (int) – Number of workers to use when generating batch predictions. When set to a value greater than 1, the prediction table is partitioned into smaller parts and processed in parallel. The default is 1, which implies sequential inference over the prediction table.

  • non_blocking (bool) – Whether this operation should return immediately after launching the batch prediction job, or await completion of the batch prediction job.

  • custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to associate with the job to be launched. Job tags are useful for grouping and searching jobs.
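To make the effect of binary_classification_threshold concrete, the following sketch applies a threshold to a list of predicted scores. It is an illustration in plain Python, not Kumo's implementation: `apply_threshold` is a hypothetical helper, and the treatment of a score exactly equal to the threshold is an assumption, since the description above does not cover ties.

```python
from typing import List, Sequence


def apply_threshold(scores: Sequence[float], threshold: float) -> List[int]:
    """Map predicted scores to 0/1 labels.

    Assumption: a score exactly equal to `threshold` maps to 1; the
    parameter description does not say how ties are handled.
    """
    return [1 if score >= threshold else 0 for score in scores]
```

For example, `apply_threshold([0.1, 0.5, 0.9], 0.7)` labels only the last score as 1, so raising the threshold yields fewer positive labels.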

Returns:

If non_blocking=False, returns a batch prediction job result object (BatchPredictionJobResult). If non_blocking=True, returns a batch prediction job future object (BatchPredictionJob).

Return type:

Union[BatchPredictionJob, BatchPredictionJobResult]
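The num_workers behavior described in the parameter list (partitioning the prediction table and processing the parts in parallel) follows a common pattern that can be sketched in plain Python. How Kumo actually partitions and schedules the work is not described here, so `partition`, `predict_in_parallel`, and the `score_part` callback are hypothetical stand-ins operating on an in-memory list of rows.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def partition(rows: Sequence[T], num_parts: int) -> List[Sequence[T]]:
    """Split `rows` into at most `num_parts` contiguous, near-equal parts."""
    if num_parts < 1:
        raise ValueError("num_parts must be at least 1")
    size, remainder = divmod(len(rows), num_parts)
    parts, start = [], 0
    for index in range(num_parts):
        # The first `remainder` parts each take one extra row.
        end = start + size + (1 if index < remainder else 0)
        if end > start:  # skip empty parts when there are few rows
            parts.append(rows[start:end])
        start = end
    return parts


def predict_in_parallel(
    rows: Sequence[T],
    score_part: Callable[[Sequence[T]], List[R]],
    num_workers: int = 1,
) -> List[R]:
    """Score each part on its own worker and concatenate in input order."""
    parts = partition(rows, num_workers)
    if not parts:
        return []
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        results = pool.map(score_part, parts)  # map() preserves part order
        return [value for part_result in results for value in part_result]
```

With `num_workers=1` the sketch reduces to a single sequential pass over all rows, which matches the default described above.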

classmethod load(job_id)

Creates a Trainer instance from a training job ID.

Return type:

Trainer

classmethod load_from_tags(tags)

Creates a Trainer instance from a set of job tags. If multiple jobs match the list of tags, only one will be selected.

Return type:

Trainer
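Because only one job is selected when several match, it helps to reason about what a tag lookup might return. The sketch below is a plain-Python illustration over an in-memory mapping, not Kumo's lookup logic: `find_job_by_tags` and the job IDs in the usage note are hypothetical, and both the matching rule and the tie-breaking rule are assumptions.

```python
from typing import Mapping, Optional


def find_job_by_tags(
    jobs: Mapping[str, Mapping[str, str]],
    tags: Mapping[str, str],
) -> Optional[str]:
    """Return the ID of one job whose tags include every requested pair.

    Assumptions (not stated in the documentation above): a job matches
    when all requested key-value pairs are present among its tags, and
    when several jobs match, the first one in iteration order is chosen.
    """
    for job_id, job_tags in jobs.items():
        if all(job_tags.get(key) == value for key, value in tags.items()):
            return job_id
    return None
```

Given two made-up jobs, `"job-a"` tagged `{"team": "growth", "env": "dev"}` and `"job-b"` tagged `{"team": "growth", "env": "prod"}`, a lookup with `{"team": "growth"}` matches both, so the result depends on the selection rule. Tagging jobs with a distinguishing pair, such as `{"env": "prod"}`, avoids that ambiguity.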