import logging
from typing import (
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Union,
overload,
)
from kumoapi.jobs import (
BatchPredictionOptions,
BatchPredictionRequest,
JobStatus,
MetadataField,
PredictionOutputConfig,
TrainingJobRequest,
TrainingJobResource,
)
from kumoapi.model_plan import ModelPlan
from kumoai import global_state
from kumoai.artifact_export.config import OutputConfig
from kumoai.client.jobs import (
GeneratePredictionTableJobID,
TrainingJobAPI,
TrainingJobID,
)
from kumoai.connector.base import Connector
from kumoai.connector.s3_connector import S3URI
from kumoai.databricks import to_db_table_name
from kumoai.graph import Graph
from kumoai.pquery.prediction_table import PredictionTable, PredictionTableJob
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
from kumoai.trainer.job import (
BatchPredictionJob,
BatchPredictionJobResult,
TrainingJob,
TrainingJobResult,
)
from kumoai.trainer.util import (
build_prediction_output_config,
validate_output_arguments,
)
# Module-level logger, namespaced to this module per stdlib convention.
logger = logging.getLogger(__name__)
class Trainer:
    r"""A trainer supports creating a Kumo machine learning model on a
    :class:`~kumoai.pquery.PredictiveQuery`. It is primarily oriented around
    two methods: :meth:`~kumoai.trainer.Trainer.fit`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.TrainingTable` and
    produces a :class:`~kumoai.trainer.TrainingJobResult`, and
    :meth:`~kumoai.trainer.Trainer.predict`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.PredictionTable`
    and produces a :class:`~kumoai.trainer.BatchPredictionJobResult`.

    A :class:`~kumoai.trainer.Trainer` can also be loaded from a training job,
    with :meth:`~kumoai.trainer.Trainer.load`.

    .. code-block:: python

        import kumoai

        # See `Graph` and `PredictiveQuery` documentation:
        graph = kumoai.Graph(...)
        pquery = kumoai.PredictiveQuery(graph=graph, query=...)

        # Create a `Trainer` object, using a suggested model plan given the
        # predictive query:
        model_plan = pquery.suggest_model_plan()
        trainer = kumoai.Trainer(model_plan)

        # Create a training table from the predictive query:
        training_table = pquery.generate_training_table()

        # Fit a model:
        training_job = trainer.fit(
            graph = graph,
            training_table = training_table,
        )

        # Create a prediction table from the predictive query:
        prediction_table = pquery.generate_prediction_table()

        # Generate predictions:
        prediction_job = trainer.predict(
            graph = graph,
            prediction_table = prediction_table,
            # other arguments here...
        )

        # Load a trained query to generate predictions:
        pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
        trainer = kumoai.Trainer.load("trainingjob-...")
        prediction_job = trainer.predict(
            pquery.graph,
            pquery.generate_prediction_table(),
        )

    Args:
        model_plan: A model plan that the trainer should follow when fitting a
            Kumo model to a predictive query. This model plan can either be
            generated from a predictive query, with
            :meth:`~kumoai.pquery.PredictiveQuery.suggest_model_plan`, or can
            be manually specified.
    """  # noqa: E501

    def __init__(self, model_plan: ModelPlan) -> None:
        self._model_plan: Optional[ModelPlan] = model_plan
        # Cached from backend (set by `fit(...)` or `load(...)`):
        self._training_job_id: Optional[TrainingJobID] = None

    @property
    def model_plan(self) -> Optional[ModelPlan]:
        r"""The model plan this trainer follows when fitting a model."""
        return self._model_plan

    @model_plan.setter
    def model_plan(self, model_plan: ModelPlan) -> None:
        self._model_plan = model_plan

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.
        """
        if not self._training_job_id:
            return False
        try:
            api = global_state.client.training_job_api
            res: TrainingJobResource = api.get(self._training_job_id)
        except Exception as e:  # noqa
            # Best-effort: a failed status fetch is reported as "not trained"
            # rather than raised to the caller.
            logger.exception(
                "Failed to fetch training status for training job with ID %s",
                self._training_job_id, exc_info=e)
            return False
        return res.job_status_report.status == JobStatus.DONE

    # Fit / predict ###########################################################

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by this :class:`Trainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, returns a training job result
                object. If ``non_blocking=True``, returns a training job
                future object.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        assert job_id is not None
        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        custom_table = None
        if isinstance(train_table, TrainingTable):
            custom_table = train_table._custom_train_table

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.training_job_api
        self._training_job_id = api.create(
            TrainingJobRequest(
                dict(custom_tags or {}), pquery_id=pq_id,
                model_plan=self._model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
                custom_train_table=custom_table,))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    def predict(
        self,
        graph: Graph,
        prediction_table: Union[PredictionTable, PredictionTableJob],
        output_types: Optional[Set[str]] = None,
        output_connector: Optional[Connector] = None,
        output_table_name: Optional[Union[str, Tuple]] = None,
        output_metadata_fields: Optional[List[MetadataField]] = None,
        output_config: Optional[OutputConfig] = None,
        training_job_id: Optional[TrainingJobID] = None,
        binary_classification_threshold: Optional[float] = None,
        num_classes_to_return: Optional[int] = None,
        num_workers: int = 1,
        non_blocking: bool = False,
        custom_tags: Optional[Mapping[str, str]] = None,
    ) -> Union[BatchPredictionJob, BatchPredictionJobResult]:
        """Using the trained model specified by :obj:`training_job_id` (or
        the model corresponding to the last invocation of
        :meth:`~kumoai.trainer.Trainer.fit`, if not present), predicts the
        future values of the entities in :obj:`prediction_table` leveraging
        current information from :obj:`graph`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will use to make
                predictions.
            prediction_table: The :class:`~kumoai.pquery.PredictionTable`, or
                in-progress :class:`~kumoai.pquery.PredictionTableJob`, that
                represents the prediction data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`. This
                table may also be custom-generated.
            output_config: Output configuration defining properties of the
                generated batch prediction outputs. Typically aspects like
                destination for the output, any additional columns needed,
                whether to override or append to an existing table etc. is
                supplied via this.
            output_types: (Deprecated; use ``output_config``.) The types of
                outputs that should be produced by the prediction job. Can
                include either ``'predictions'``, ``'embeddings'``, or both.
            output_connector: (Deprecated; use ``output_config``.) The output
                data source that Kumo should write batch predictions to, if it
                is None, produce local download output only.
            output_table_name: (Deprecated; use ``output_config``.) The name
                of the table in the output data source that Kumo should write
                batch predictions to. In the case of a Databricks connector,
                this should be a tuple of two strings: the schema name and the
                output prediction table name.
            output_metadata_fields: (Deprecated; use ``output_config``.) Any
                additional metadata fields to include as new columns in the
                produced ``'predictions'`` output. Currently, allowed options
                are ``JOB_TIMESTAMP`` and ``ANCHOR_TIMESTAMP``.
            training_job_id: The job ID of the training job whose model will be
                used for making predictions.
            binary_classification_threshold: If this model corresponds to
                a binary classification task, the threshold for which higher
                values correspond to ``1``, and lower values correspond to
                ``0``.
            num_classes_to_return: If this model corresponds to a ranking task,
                the number of classes to return in the prediction output.
            num_workers: Number of workers to use when generating batch
                predictions. When set to a value greater than 1, the prediction
                table is partitioned into smaller parts and processed in
                parallel. The default is 1, which implies sequential inference
                over the prediction table.
            non_blocking: Whether this operation should return immediately
                after launching the batch prediction job, or await
                completion of the batch prediction job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.

        Returns:
            Union[BatchPredictionJob, BatchPredictionJobResult]:
                If ``non_blocking=False``, returns a batch prediction job
                result object. If ``non_blocking=True``, returns a batch
                prediction job future object.
        """
        if (output_types is not None or output_connector is not None
                or output_table_name is not None
                or output_metadata_fields is not None):
            raise ValueError(
                'Specifying output_types, output_connector, '
                'output_metadata_fields '
                'and output_table_name as direct inputs to predict() is '
                'deprecated. Please use output_config to specify these '
                'parameters.')

        # Explicit validation (not `assert`): an `assert` is stripped under
        # `python -O`, which would defer the failure to a confusing
        # `AttributeError` below:
        if output_config is None:
            raise ValueError(
                "predict() requires an 'output_config' to define the batch "
                "prediction outputs.")

        output_table_name = to_db_table_name(output_config.output_table_name)
        validate_output_arguments(
            output_config.output_types,
            output_config.output_connector,
            output_table_name,
        )

        # Create outputs:
        outputs: List[PredictionOutputConfig] = []
        for output_type in output_config.output_types:
            if output_config.output_connector is None:
                # Predictions are generated to the Kumo dataplane, and can
                # only be exported via the UI for now:
                pass
            else:
                outputs.append(
                    build_prediction_output_config(
                        output_type,
                        output_config.output_connector,
                        output_table_name,
                        output_config.output_metadata_fields,
                        output_config,
                    ))

        training_job_id = training_job_id or self._training_job_id
        if training_job_id is None:
            raise ValueError(
                "Cannot run batch prediction without a specified or saved "
                "training job ID. Please either call `fit(...)` or pass a "
                "job ID of a completed training job to proceed.")

        pred_table_job_id: Optional[GeneratePredictionTableJobID] = \
            prediction_table.job_id
        pred_table_data_path = None
        if pred_table_job_id is None:
            # A custom prediction table has no generating job; pass its data
            # location instead:
            assert isinstance(prediction_table, PredictionTable)
            if isinstance(prediction_table.table_data_uri, S3URI):
                pred_table_data_path = prediction_table.table_data_uri.uri
            else:
                pred_table_data_path = prediction_table.table_data_uri

        api = global_state.client.batch_prediction_job_api
        job_id, response = api.maybe_create(
            BatchPredictionRequest(
                dict(custom_tags or {}),
                model_training_job_id=training_job_id,
                predict_options=BatchPredictionOptions(
                    binary_classification_threshold=(
                        binary_classification_threshold),
                    num_classes_to_return=num_classes_to_return,
                    num_workers=num_workers,
                ),
                outputs=outputs,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                pred_table_job_id=pred_table_job_id,
                pred_table_path=pred_table_data_path,
            ))
        message = response.message()
        if not response.ok:
            raise RuntimeError(f"Prediction failed. {message}")
        elif not response.empty():
            logger.warning("Prediction produced the following warnings: %s",
                           message)
        assert job_id is not None
        self._batch_prediction_job_id = job_id

        out = BatchPredictionJob(job_id=self._batch_prediction_job_id)
        if non_blocking:
            return out
        return out.attach()

    # Persistence #############################################################

    @classmethod
    def _load_from_job(cls, job: TrainingJobResource) -> 'Trainer':
        # Rehydrate a trainer from a backend training job resource:
        trainer = cls(job.config.model_plan)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'Trainer':
        r"""Creates a :class:`~kumoai.trainer.Trainer` instance from a training
        job ID.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        job = api.get(job_id)
        return cls._load_from_job(job)

    # TODO(siyang): load trainer by searching training job via tags.