kumoai.trainer.Trainer
- class kumoai.trainer.Trainer
Bases: object
A trainer supports creating a Kumo machine learning model on a PredictiveQuery. It is primarily oriented around two methods: fit(), which accepts a Graph and TrainingTable and produces a TrainingJobResult, and predict(), which accepts a Graph and PredictionTable and produces a BatchPredictionJobResult. A Trainer can also be loaded from a training job, with load().

    import kumoai

    # See `Graph` and `PredictiveQuery` documentation:
    graph = kumoai.Graph(...)
    pquery = kumoai.PredictiveQuery(graph=graph, query=...)

    # Create a `Trainer` object, using a suggested model plan given the
    # predictive query:
    model_plan = pquery.suggest_model_plan()
    trainer = kumoai.Trainer(model_plan)

    # Create a training table from the predictive query:
    training_table = pquery.generate_training_table()

    # Fit a model (the keyword matches the `train_table` parameter of `fit()`):
    training_job = trainer.fit(
        graph=graph,
        train_table=training_table,
    )

    # Create a prediction table from the predictive query:
    prediction_table = pquery.generate_prediction_table()

    # Generate predictions:
    prediction_job = trainer.predict(
        graph=graph,
        prediction_table=prediction_table,
        # other arguments here...
    )

    # Load a trained query to generate predictions:
    pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
    trainer = kumoai.Trainer.load("trainingjob-...")
    prediction_job = trainer.predict(
        pquery.graph,
        pquery.generate_prediction_table(),
    )
- Parameters:
model_plan (ModelPlan) – A model plan that the trainer should follow when fitting a Kumo model to a predictive query. This model plan can either be generated from a predictive query, with suggest_model_plan(), or be specified manually.
- property is_trained: bool
Returns True if this trainer instance has been successfully trained (and is therefore ready for prediction); False otherwise.
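A minimal sketch of how is_trained might guard a prediction call when no training_job_id is supplied. The helper name `predict_if_trained` is hypothetical and not part of kumoai; it relies only on the is_trained property and the predict() method documented on this page, and accepts any object that exposes them.

```python
def predict_if_trained(trainer, graph, prediction_table):
    """Call predict() only when the trainer reports a trained model.

    Hypothetical helper for illustration; not part of the kumoai SDK.
    """
    if not trainer.is_trained:
        # Fail early with a clear message instead of launching a job
        # that has no model to predict with.
        raise RuntimeError("Trainer is not trained yet; fit a model first.")
    return trainer.predict(graph, prediction_table)
```

Raising early keeps the failure close to its cause; a caller that prefers to train on demand could call fit() in that branch instead.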
- fit(graph, train_table, *, non_blocking=False, custom_tags={})
Fits a model to the specified graph and training table, with the strategy defined by this Trainer’s model_plan.
- Parameters:
graph (Graph) – The Graph object that represents the tables and relationships that Kumo will learn from.
train_table (Union[TrainingTable, TrainingTableJob]) – The TrainingTable, or in-progress TrainingTableJob, that represents the training data produced by a PredictiveQuery on graph.
non_blocking (bool) – Whether this operation should return immediately after launching the training job, or await completion of the training job.
custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to associate with the launched job. Job tags are useful for grouping and searching jobs.
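Because custom_tags is typed as Mapping[str, str], it can help to normalise tag values to strings before launching a job. The following is a minimal sketch under that assumption; `build_job_tags` and the example tag names are hypothetical and not part of kumoai.

```python
def build_job_tags(**tags):
    """Return a str-to-str mapping suitable for a `custom_tags` argument.

    Hypothetical helper for illustration; not part of the kumoai SDK.
    """
    cleaned = {}
    for key, value in tags.items():
        if value is None:
            continue  # skip unset tags rather than storing the string "None"
        cleaned[key] = str(value)
    return cleaned


tags = build_job_tags(team="growth", experiment=7, owner=None)
print(tags)  # {'team': 'growth', 'experiment': '7'}
```

The resulting mapping could then be passed along, for example as `trainer.fit(graph, train_table, custom_tags=tags)`.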
- Returns:
If non_blocking=False, returns a training job object. If non_blocking=True, returns a training job future object.
- Return type:
Union[TrainingJobResult, TrainingJob]
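The two return modes resemble the familiar result-versus-future pattern. The sketch below illustrates that pattern with Python's standard `concurrent.futures` module; it is an analogy only, not Kumo's implementation, and `launch_job` and `train` are hypothetical names.

```python
from concurrent.futures import Future, ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)


def launch_job(work, *, non_blocking=False):
    """Run `work` in the background; return its result, or a Future for it."""
    future = _executor.submit(work)
    if non_blocking:
        # Return immediately; the caller decides when to wait.
        return future
    # Block until the job finishes and hand back the finished result.
    return future.result()


def train():
    # Stand-in for a long-running training job.
    return "model-ready"


result = launch_job(train)                      # blocks, returns the result
pending = launch_job(train, non_blocking=True)  # returns a Future right away
print(result, pending.result())
```

The blocking form is convenient in scripts and notebooks, while the non-blocking form suits launching several jobs and collecting their results later.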
- predict(graph, prediction_table, output_types=None, output_connector=None, output_table_name=None, output_metadata_fields=None, output_config=None, training_job_id=None, binary_classification_threshold=None, num_classes_to_return=None, num_workers=1, non_blocking=False, custom_tags={})
Using the trained model specified by training_job_id (or, if training_job_id is not provided, the model from the most recent invocation of fit()), predicts the future values of the entities in prediction_table, leveraging current information from graph.
- Parameters:
graph (Graph) – The Graph object that represents the tables and relationships that Kumo will use to make predictions.
prediction_table (Union[PredictionTable, PredictionTableJob]) – The PredictionTable, or in-progress PredictionTableJob, that represents the prediction data produced by a PredictiveQuery on graph. This table may also be custom-generated.
output_config (Optional[OutputConfig]) – Output configuration defining properties of the generated batch prediction outputs. Settings such as the output destination, any additional columns needed, and whether to overwrite or append to an existing table are typically supplied through this object.
output_types (Optional[Set[str]]) – The types of outputs that should be produced by the prediction job. Can include 'predictions', 'embeddings', or both.
output_connector (Optional[Connector]) – The output data source that Kumo should write batch predictions to. If None, only local download output is produced.
output_table_name (Union[str, Tuple, None]) – The name of the table in the output data source that Kumo should write batch predictions to. In the case of a Databricks connector, this should be a tuple of two strings: the schema name and the output prediction table name.
output_metadata_fields (Optional[List[MetadataField]]) – Any additional metadata fields to include as new columns in the produced 'predictions' output. Currently, the allowed options are JOB_TIMESTAMP and ANCHOR_TIMESTAMP.
training_job_id (Optional[str]) – The job ID of the training job whose model will be used for making predictions.
binary_classification_threshold (Optional[float]) – If this model corresponds to a binary classification task, the threshold above which values correspond to 1 and below which values correspond to 0.
num_classes_to_return (Optional[int]) – If this model corresponds to a ranking task, the number of classes to return in the prediction output.
num_workers (int) – Number of workers to use when generating batch predictions. When set to a value greater than 1, the prediction table is partitioned into smaller parts that are processed in parallel. The default is 1, which implies sequential inference over the prediction table.
non_blocking (bool) – Whether this operation should return immediately after launching the batch prediction job, or await completion of the batch prediction job.
custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to associate with the launched job. Job tags are useful for grouping and searching jobs.
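To make the num_workers and binary_classification_threshold descriptions more concrete, here is a small self-contained sketch of the two ideas: splitting rows into parts for parallel workers, and mapping scores to 0 or 1. It is not Kumo's implementation; `partition`, `apply_threshold`, the sample scores, and the choice to map a score equal to the threshold to 1 are all assumptions made for illustration.

```python
def partition(rows, num_workers):
    """Split `rows` into at most `num_workers` contiguous, near-equal parts."""
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")
    size, extra = divmod(len(rows), num_workers)
    parts, start = [], 0
    for i in range(num_workers):
        # The first `extra` parts take one additional row each.
        end = start + size + (1 if i < extra else 0)
        if end > start:
            parts.append(rows[start:end])
        start = end
    return parts


def apply_threshold(scores, threshold):
    """Map each score to 1 or 0 (assumption: a tie with the threshold gives 1)."""
    return [1 if score >= threshold else 0 for score in scores]


scores = [0.91, 0.12, 0.55, 0.40, 0.73]
print(partition(scores, 2))          # [[0.91, 0.12, 0.55], [0.4, 0.73]]
print(apply_threshold(scores, 0.5))  # [1, 0, 1, 0, 1]
```

With num_workers=1 the sketch yields a single part, which corresponds to the sequential default described above.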
- Returns:
If non_blocking=False, returns a batch prediction job result object. If non_blocking=True, returns a batch prediction job future object.
- Return type: