kumoai.trainer.Trainer
- class kumoai.trainer.Trainer
Bases: object
A trainer supports creating a Kumo machine learning model on a PredictiveQuery. It is primarily oriented around two methods: fit(), which accepts a Graph and TrainingTable and produces a TrainingJobResult, and predict(), which accepts a Graph and PredictionTable and produces a BatchPredictionJobResult.

A Trainer can also be loaded from a training job, with load().

    import kumoai

    # See `Graph` and `PredictiveQuery` documentation:
    graph = kumoai.Graph(...)
    pquery = kumoai.PredictiveQuery(graph=graph, query=...)

    # Create a `Trainer` object, using a suggested model plan given the
    # predictive query:
    model_plan = pquery.suggest_model_plan()
    trainer = kumoai.Trainer(model_plan)

    # Create a training table from the predictive query:
    training_table = pquery.generate_training_table()

    # Fit a model (the keyword is `train_table`, per the fit() signature):
    training_job = trainer.fit(
        graph=graph,
        train_table=training_table,
    )

    # Create a prediction table from the predictive query:
    prediction_table = pquery.generate_prediction_table()

    # Generate predictions:
    prediction_job = trainer.predict(
        graph=graph,
        prediction_table=prediction_table,
        # other arguments here...
    )

    # Load a trained query to generate predictions:
    pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
    trainer = kumoai.Trainer.load("trainingjob-...")
    prediction_job = trainer.predict(
        pquery.graph,
        pquery.generate_prediction_table(),
        output_types={'predictions'},  # required by the predict() signature
    )
- Parameters:
model_plan (ModelPlan) – A model plan that the trainer should follow when fitting a Kumo model to a predictive query. This model plan can either be generated from a predictive query, with suggest_model_plan(), or can be manually specified.
- property is_trained: bool
Returns True if this trainer instance has successfully been trained (and is therefore ready for prediction); False otherwise.
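One way to use is_trained is as a guard before calling predict(). The following is a minimal sketch, not part of the kumoai API: predict_if_ready is a hypothetical helper, and it assumes only what this page documents, namely that a trainer exposes is_trained and predict(), and that predict() uses the model from the last fit() call when training_job_id is not given.

```python
def predict_if_ready(trainer, graph, prediction_table, **predict_kwargs):
    """Hypothetical helper (not part of kumoai): guard a predict() call.

    predict() uses the model named by `training_job_id` if given, and
    otherwise the model from the last fit(). An untrained trainer is
    therefore only a problem when no explicit `training_job_id` is passed.
    """
    has_explicit_job = predict_kwargs.get("training_job_id") is not None
    if not has_explicit_job and not trainer.is_trained:
        raise RuntimeError(
            "Trainer is not trained: call fit(), use Trainer.load(), "
            "or pass training_job_id."
        )
    # Forward everything else (output_types, num_workers, ...) unchanged.
    return trainer.predict(graph, prediction_table, **predict_kwargs)
```

Failing early this way avoids preparing a prediction call that has no model to run against.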
- fit(graph, train_table, *, non_blocking=False, custom_tags={})
Fits a model to the specified graph and training table, with the strategy defined by this Trainer’s model_plan.
- Parameters:
graph (Graph) – The Graph object that represents the tables and relationships that Kumo will learn from.
train_table (Union[TrainingTable, TrainingTableJob]) – The TrainingTable, or in-progress TrainingTableJob, that represents the training data produced by a PredictiveQuery on graph.
non_blocking (bool) – Whether this operation should return immediately after launching the training job, or await completion of the training job.
custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to be associated with the job to be launched. Job tags are useful for grouping and searching jobs.
- Returns:
If non_blocking=False, returns a training job result object (TrainingJobResult). If non_blocking=True, returns a training job future object (TrainingJob).
- Return type:
Union[TrainingJobResult, TrainingJob]
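As an illustration of the two keyword-only arguments, the sketch below launches training without waiting for completion and attaches tags to the job. launch_training is a hypothetical wrapper, not part of kumoai, and the tag keys in the usage note are made up for the example; the wrapper relies only on the fit() signature documented above.

```python
def launch_training(trainer, graph, train_table, tags=None):
    """Hypothetical wrapper (not part of kumoai) around Trainer.fit().

    Starts the job with non_blocking=True, so fit() returns right after
    launch with a training job future rather than a finished result.
    """
    # Copy the tags and coerce keys and values to str, matching the
    # documented Mapping[str, str] type of `custom_tags`.
    custom_tags = {str(k): str(v) for k, v in (tags or {}).items()}
    return trainer.fit(
        graph,
        train_table,
        non_blocking=True,
        custom_tags=custom_tags,
    )
```

For instance, launch_training(trainer, graph, training_table, tags={"team": "growth", "run": 3}) would call fit() with custom_tags={"team": "growth", "run": "3"}.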
- predict(graph, prediction_table, output_types, output_connector=None, output_table_name=None, training_job_id=None, binary_classification_threshold=None, num_classes_to_return=None, num_workers=1, non_blocking=False, custom_tags={})
Using the trained model specified by training_job_id (or the model corresponding to the last invocation of fit(), if not present), predicts the future values of the entities in prediction_table leveraging current information from graph.
- Parameters:
graph (Graph) – The Graph object that represents the tables and relationships that Kumo will use to make predictions.
prediction_table (Union[PredictionTable, PredictionTableJob]) – The PredictionTable, or in-progress PredictionTableJob, that represents the prediction data produced by a PredictiveQuery on graph. This table may also be custom-generated.
output_types (Set[str]) – The types of outputs that should be produced by the prediction job. Can include either 'predictions', 'embeddings', or both.
output_connector (Optional[Connector]) – The output data source that Kumo should write batch predictions to. If None, Kumo produces output for local download only.
output_table_name (Union[str, Tuple, None]) – The name of the table in the output data source that Kumo should write batch predictions to. In the case of a Databricks connector, this should be a tuple of two strings: the schema name and the output prediction table name.
training_job_id (Optional[str]) – The job ID of the training job whose model will be used for making predictions.
binary_classification_threshold (Optional[float]) – If this model corresponds to a binary classification task, the threshold for which higher values correspond to 1, and lower values correspond to 0.
num_classes_to_return (Optional[int]) – If this model corresponds to a ranking task, the number of classes to return in the prediction output.
num_workers (int) – Number of workers to use when generating batch predictions. When set to a value greater than 1, the prediction table is partitioned into smaller parts and processed in parallel. The default is 1, which implies sequential inference over the prediction table.
non_blocking (bool) – Whether this operation should return immediately after launching the batch prediction job, or await completion of the batch prediction job.
custom_tags (Mapping[str, str]) – Additional, customer-defined key-value tags to be associated with the job to be launched. Job tags are useful for grouping and searching jobs.
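The argument rules above can be collected in a small helper that assembles keyword arguments for predict(). This is a sketch under stated assumptions: build_predict_kwargs and its is_databricks, table and schema parameters are hypothetical names, not part of kumoai. The checks on the allowed output types and the two-string Databricks tuple come from the parameter descriptions; requiring a non-empty output_types, a table name whenever a connector is given, and num_workers of at least 1 are assumptions of this example.

```python
ALLOWED_OUTPUT_TYPES = {"predictions", "embeddings"}


def build_predict_kwargs(output_types, output_connector=None, table=None,
                         schema=None, is_databricks=False, num_workers=1):
    """Hypothetical helper (not part of kumoai): validate and assemble
    keyword arguments for Trainer.predict()."""
    output_types = set(output_types)
    # Assumption: at least one output type must be requested.
    if not output_types or not output_types <= ALLOWED_OUTPUT_TYPES:
        raise ValueError(
            f"output_types must be a non-empty subset of "
            f"{sorted(ALLOWED_OUTPUT_TYPES)}"
        )
    # Assumption: worker counts below the documented default of 1 are invalid.
    if num_workers < 1:
        raise ValueError("num_workers must be at least 1")

    kwargs = {"output_types": output_types, "num_workers": num_workers}
    if output_connector is None:
        # No connector: output is produced for local download only,
        # so no output table name is passed.
        return kwargs

    # Assumption: writing to a connector needs a destination table name.
    if table is None:
        raise ValueError("a table name is required when writing to a connector")
    if is_databricks:
        # Databricks expects (schema name, output prediction table name).
        if schema is None:
            raise ValueError("a schema name is required for Databricks")
        output_table_name = (schema, table)
    else:
        output_table_name = table
    kwargs["output_connector"] = output_connector
    kwargs["output_table_name"] = output_table_name
    return kwargs
```

The result can be unpacked into the call, for example trainer.predict(graph, prediction_table, **build_predict_kwargs({'predictions'})).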
- Return type: