kumoai.trainer.Trainer

class kumoai.trainer.Trainer

Bases: object

A Trainer creates a Kumo machine learning model for a PredictiveQuery. It is organized around two methods: fit(), which accepts a Graph and a TrainingTable and produces a TrainingJobResult, and predict(), which accepts a Graph and a PredictionTable and produces a BatchPredictionJobResult.

A Trainer can also be loaded from an existing training job with load().

import kumoai

# See `Graph` and `PredictiveQuery` documentation:
graph = kumoai.Graph(...)
pquery = kumoai.PredictiveQuery(graph=graph, query=...)

# Create a `Trainer` object, using a suggested model plan given the
# predictive query:
model_plan = pquery.suggest_model_plan()
trainer = kumoai.Trainer(model_plan)

# Create a training table from the predictive query:
training_table = pquery.generate_training_table()

# Fit a model:
training_job = trainer.fit(
    graph=graph,
    train_table=training_table,
)

# Create a prediction table from the predictive query:
prediction_table = pquery.generate_prediction_table()

# Generate predictions:
prediction_job = trainer.predict(
    graph=graph,
    prediction_table=prediction_table,
    output_types={'predictions'},
    # other arguments here...
)

# Load a trained model to generate predictions:
pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
trainer = kumoai.Trainer.load("trainingjob-...")
prediction_job = trainer.predict(
    pquery.graph,
    pquery.generate_prediction_table(),
    output_types={'predictions'},
)

Parameters:

model_plan (ModelPlan) – The model plan the trainer follows when fitting a Kumo model to a predictive query. It can be generated from a predictive query with suggest_model_plan(), or specified manually.

__init__(model_plan)

property is_trained: bool

Returns True if this trainer instance has successfully been trained (and is therefore ready for prediction); False otherwise.

fit(graph, train_table, *, non_blocking=False, custom_tags={})

Fits a model to the specified graph and training table, with the strategy defined by this Trainer’s model_plan.

Parameters:
  • graph (Graph) – The Graph object that represents the tables and relationships that Kumo will learn from.

  • train_table (Union[TrainingTable, TrainingTableJob]) – The TrainingTable, or in-progress TrainingTableJob, that represents the training data produced by a PredictiveQuery on graph.

  • non_blocking (bool) – If True, return immediately after launching the training job; if False (the default), wait for the training job to complete.

  • custom_tags (Mapping[str, str]) – Additional customer-defined key-value tags to associate with the launched job. Job tags are useful for grouping and searching jobs.

Returns:

If non_blocking=False, returns a TrainingJobResult for the completed training job. If non_blocking=True, returns a TrainingJob, a future that tracks the in-progress training job.

Return type:

Union[TrainingJobResult, TrainingJob]
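The two return modes of fit() follow a familiar blocking-versus-future pattern. The sketch below illustrates that contract with the standard library only: `launch_job` and `_run_job` are hypothetical stand-ins, and `concurrent.futures.Future` plays the role of Kumo's TrainingJob future. It is an analogy under those assumptions, not Kumo's implementation.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Union

# Hypothetical stand-in for the remote work a training job performs.
def _run_job(job_name: str) -> str:
    return f"{job_name}: done"

_executor = ThreadPoolExecutor(max_workers=1)

def launch_job(job_name: str, non_blocking: bool = False) -> Union[str, "Future[str]"]:
    """Mimic fit(): return a future if non_blocking=True, else the finished result."""
    future = _executor.submit(_run_job, job_name)
    if non_blocking:
        # Return immediately; the caller decides when to wait.
        return future
    # Block until the job finishes and hand back its result.
    return future.result()

# Blocking call: the result is available as soon as the call returns.
print(launch_job("train-1"))

# Non-blocking call: get a handle now, wait for the result later.
handle = launch_job("train-2", non_blocking=True)
print(handle.result())
```

The non-blocking mode is useful when launching several jobs at once and collecting their results later.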

predict(graph, prediction_table, output_types, output_connector=None, output_table_name=None, training_job_id=None, binary_classification_threshold=None, num_classes_to_return=None, num_workers=1, non_blocking=False, custom_tags={})

Predicts the future values of the entities in prediction_table, leveraging current information from graph. Uses the trained model specified by training_job_id or, if training_job_id is not provided, the model from the most recent call to fit().

Parameters:
  • graph (Graph) – The Graph object that represents the tables and relationships that Kumo will use to make predictions.

  • prediction_table (Union[PredictionTable, PredictionTableJob]) – The PredictionTable, or in-progress PredictionTableJob, that represents the prediction data produced by a PredictiveQuery on graph. This table may also be custom-generated.

  • output_types (Set[str]) – The types of outputs that the prediction job should produce. May contain 'predictions', 'embeddings', or both.

  • output_connector (Optional[Connector]) – The output data source that Kumo should write batch predictions to. If None, Kumo produces output for local download only.

  • output_table_name (Union[str, Tuple, None]) – The name of the table in the output data source that Kumo should write batch predictions to. In the case of a Databricks connector, this should be a tuple of two strings: the schema name and the output prediction table name.

  • training_job_id (Optional[str]) – The job ID of the training job whose model will be used for making predictions.

  • binary_classification_threshold (Optional[float]) – If this model corresponds to a binary classification task, the threshold that separates the two classes: values above it correspond to 1, and values below it correspond to 0.

  • num_classes_to_return (Optional[int]) – If this model corresponds to a ranking task, the number of classes to return in the prediction output.

  • num_workers (int) – Number of workers to use when generating batch predictions. When set to a value greater than 1, the prediction table is partitioned into smaller parts and processed in parallel. The default is 1, which implies sequential inference over the prediction table.

  • non_blocking (bool) – If True, return immediately after launching the batch prediction job; if False (the default), wait for the batch prediction job to complete.

  • custom_tags (Mapping[str, str]) – Additional customer-defined key-value tags to associate with the launched job. Job tags are useful for grouping and searching jobs.
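The binary_classification_threshold and num_classes_to_return parameters govern how raw model scores become outputs. The plain-Python sketch below shows the usual interpretation: scores above the threshold become 1, and a ranking keeps the num_classes_to_return highest-scoring classes. The helper names `apply_threshold` and `top_k_classes`, the example scores, and mapping a score exactly equal to the threshold to 0 are assumptions for illustration, not Kumo's code.

```python
import heapq
from typing import Dict, List

def apply_threshold(scores: List[float], threshold: float) -> List[int]:
    # Scores strictly above the threshold map to 1; all others map to 0
    # (the handling of ties is an assumption of this sketch).
    return [1 if s > threshold else 0 for s in scores]

def top_k_classes(class_scores: Dict[str, float], k: int) -> List[str]:
    # Keep the k highest-scoring classes, best first.
    return heapq.nlargest(k, class_scores, key=lambda c: class_scores[c])

print(apply_threshold([0.2, 0.5, 0.9], threshold=0.5))  # [0, 0, 1]
print(top_k_classes({"item_a": 0.1, "item_b": 0.7, "item_c": 0.4}, k=2))  # ['item_b', 'item_c']
```

heapq.nlargest avoids sorting every class when only a few are requested, which matters when a ranking task has many candidate classes.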

Return type:

Union[BatchPredictionJob, BatchPredictionJobResult]
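When num_workers is greater than 1, the prediction table is partitioned into smaller parts that are processed in parallel. The sketch below shows one simple scheme: split the rows into num_workers contiguous, near-equal parts and score them concurrently. The `partition` and `predict_partition` helpers, the contiguous split, and the thread pool are all assumptions for illustration; the documentation does not say how Kumo partitions the table.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def partition(rows: Sequence[T], num_workers: int) -> List[List[T]]:
    """Split rows into num_workers contiguous parts whose sizes differ by at most 1."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    base, extra = divmod(len(rows), num_workers)
    parts, start = [], 0
    for i in range(num_workers):
        # The first `extra` parts each take one additional row.
        size = base + (1 if i < extra else 0)
        parts.append(list(rows[start:start + size]))
        start += size
    return parts

def predict_partition(rows: List[int]) -> List[float]:
    # Hypothetical per-partition model call; here it just rescales each row.
    return [r / 10 for r in rows]

rows = list(range(10))
parts = partition(rows, num_workers=3)
print([len(p) for p in parts])  # [4, 3, 3]

# Score the partitions in parallel; pool.map preserves partition order,
# so the flattened results line up with the original rows.
with ThreadPoolExecutor(max_workers=3) as pool:
    results = [x for part in pool.map(predict_partition, parts) for x in part]
```

With num_workers=1 the table stays in a single part, which matches the documented default of sequential inference.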

classmethod load(job_id)

Creates a Trainer instance from a training job ID.

Return type:

Trainer

classmethod load_from_tags(tags)

Creates a Trainer instance from a set of job tags. If multiple jobs match the tags, only one is selected.

Return type:

Trainer
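load_from_tags() selects a training job by the custom_tags it was launched with. The sketch below shows one plausible matching rule over an in-memory list of job records: a job matches when every requested key is present with an equal value, and the first match wins. The `JobRecord` type, the job IDs, the matching rule, and the first-match choice are assumptions; the documentation states only that a single job is selected when several match.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Mapping, Optional

@dataclass
class JobRecord:
    # Hypothetical stand-in for a training job and its custom tags.
    job_id: str
    tags: Dict[str, str] = field(default_factory=dict)

def matches(job: JobRecord, query: Mapping[str, str]) -> bool:
    # Every requested tag must be present on the job with an equal value.
    return all(job.tags.get(k) == v for k, v in query.items())

def select_job(jobs: List[JobRecord], query: Mapping[str, str]) -> Optional[str]:
    # Return the ID of the first matching job, or None if nothing matches.
    for job in jobs:
        if matches(job, query):
            return job.job_id
    return None

jobs = [
    JobRecord("trainingjob-a", {"team": "growth", "stage": "dev"}),
    JobRecord("trainingjob-b", {"team": "growth", "stage": "prod"}),
]
print(select_job(jobs, {"team": "growth", "stage": "prod"}))  # trainingjob-b
```

Under this rule, tagging jobs with distinguishing keys (for example a stage tag) at fit() time is what makes a later lookup unambiguous; a query such as {"team": "growth"} alone matches both jobs.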