import logging
from typing import (
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Union,
overload,
)
from kumoapi.jobs import (
BatchPredictionOptions,
BatchPredictionRequest,
JobStatus,
PredictionOutputConfig,
TrainingJobRequest,
TrainingJobResource,
)
from kumoapi.model_plan import ModelPlan
from kumoai import global_state
from kumoai.client.jobs import (
GeneratePredictionTableJobID,
TrainingJobAPI,
TrainingJobID,
)
from kumoai.connector import Connector
from kumoai.graph import Graph
from kumoai.pquery.prediction_table import PredictionTable, PredictionTableJob
from kumoai.pquery.training_table import TrainingTable, TrainingTableJob
from kumoai.trainer.job import (
BatchPredictionJob,
BatchPredictionJobResult,
TrainingJob,
TrainingJobResult,
)
from kumoai.trainer.util import (
build_prediction_output_config,
validate_output_arguments,
)
# Module-level logger, named after this module:
logger = logging.getLogger(__name__)
class Trainer:
    r"""A trainer supports creating a Kumo machine learning model on a
    :class:`~kumoai.pquery.PredictiveQuery`. It is primarily oriented around
    two methods: :meth:`~kumoai.trainer.Trainer.fit`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.TrainingTable` and
    produces a :class:`~kumoai.trainer.TrainingJobResult`, and
    :meth:`~kumoai.trainer.Trainer.predict`, which accepts a
    :class:`~kumoai.graph.Graph` and :class:`~kumoai.pquery.PredictionTable`
    and produces a :class:`~kumoai.trainer.BatchPredictionJobResult`.

    A :class:`~kumoai.trainer.Trainer` can also be loaded from a training job,
    with :meth:`~kumoai.trainer.Trainer.load`.

    .. code-block:: python

        import kumoai

        # See `Graph` and `PredictiveQuery` documentation:
        graph = kumoai.Graph(...)
        pquery = kumoai.PredictiveQuery(graph=graph, query=...)

        # Create a `Trainer` object, using a suggested model plan given the
        # predictive query:
        model_plan = pquery.suggest_model_plan()
        trainer = kumoai.Trainer(model_plan)

        # Create a training table from the predictive query:
        training_table = pquery.generate_training_table()

        # Fit a model:
        training_job = trainer.fit(
            graph=graph,
            train_table=training_table,
        )

        # Create a prediction table from the predictive query:
        prediction_table = pquery.generate_prediction_table()

        # Generate predictions:
        prediction_job = trainer.predict(
            graph=graph,
            prediction_table=prediction_table,
            output_types={'predictions'},
            # other arguments here...
        )

        # Load a trained query to generate predictions:
        pquery = kumoai.PredictiveQuery.load_from_training_job("trainingjob-...")
        trainer = kumoai.Trainer.load("trainingjob-...")
        prediction_job = trainer.predict(
            pquery.graph,
            pquery.generate_prediction_table(),
            output_types={'predictions'},
        )

    Args:
        model_plan: A model plan that the trainer should follow when fitting a
            Kumo model to a predictive query. This model plan can either be
            generated from a predictive query, with
            :meth:`~kumoai.pquery.PredictiveQuery.suggest_model_plan`, or can
            be manually specified.
    """  # noqa: E501

    def __init__(self, model_plan: ModelPlan) -> None:
        self._model_plan: Optional[ModelPlan] = model_plan

        # Cached from backend:
        self._training_job_id: Optional[TrainingJobID] = None

    @property
    def model_plan(self) -> Optional[ModelPlan]:
        r"""The model plan this trainer passes to the training job created
        by :meth:`~kumoai.trainer.Trainer.fit`.
        """
        return self._model_plan

    @model_plan.setter
    def model_plan(self, model_plan: ModelPlan) -> None:
        self._model_plan = model_plan

    # Metadata ################################################################

    @property
    def is_trained(self) -> bool:
        r"""Returns ``True`` if this trainer instance has successfully been
        trained (and is therefore ready for prediction); ``False`` otherwise.

        If the status of the training job cannot be fetched, the failure is
        logged and ``False`` is returned rather than raising.
        """
        if not self._training_job_id:
            return False
        try:
            api = global_state.client.training_job_api
            res: TrainingJobResource = api.get(self._training_job_id)
        except Exception:  # noqa
            # Deliberately best-effort: `logger.exception` already attaches
            # the active exception and its traceback to the log record.
            logger.exception(
                "Failed to fetch training status for training job with ID %s",
                self._training_job_id)
            return False
        return res.job_status_report.status == JobStatus.DONE

    # Fit / predict ###########################################################

    # NOTE every overload declares `custom_tags` so that type checkers accept
    # it in combination with any `non_blocking` variant.
    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        custom_tags: Mapping[str, str] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[False],
        custom_tags: Mapping[str, str] = ...,
    ) -> TrainingJobResult:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: Literal[True],
        custom_tags: Mapping[str, str] = ...,
    ) -> TrainingJob:
        pass

    @overload
    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool,
        custom_tags: Mapping[str, str] = ...,
    ) -> Union[TrainingJob, TrainingJobResult]:
        pass

    def fit(
        self,
        graph: Graph,
        train_table: Union[TrainingTable, TrainingTableJob],
        *,
        non_blocking: bool = False,
        # The shared default is never mutated: it is copied with `dict()`
        # before being handed to the request below.
        custom_tags: Mapping[str, str] = {},
    ) -> Union[TrainingJob, TrainingJobResult]:
        r"""Fits a model to the specified graph and training table, with the
        strategy defined by this :class:`Trainer`'s :obj:`model_plan`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will learn from.
            train_table: The :class:`~kumoai.pquery.TrainingTable`, or
                in-progress :class:`~kumoai.pquery.TrainingTableJob`, that
                represents the training data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`.
            non_blocking: Whether this operation should return immediately
                after launching the training job, or await completion of the
                training job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.

        Returns:
            Union[TrainingJobResult, TrainingJob]:
                If ``non_blocking=False``, returns the
                :class:`~kumoai.trainer.TrainingJobResult` obtained by
                attaching to the launched job. If ``non_blocking=True``,
                returns a :class:`~kumoai.trainer.TrainingJob` future object.

        Raises:
            ValueError: If :obj:`train_table` has no associated job ID.
        """
        # TODO(manan, siyang): remove soon:
        job_id = train_table.job_id
        if job_id is None:
            # Raised explicitly (rather than asserted) so that the check on
            # this caller-supplied argument survives `python -O`:
            raise ValueError(
                "Cannot fit a model on a training table that has no "
                "associated job ID (`train_table.job_id` is None).")
        train_table_job_api = global_state.client.generate_train_table_job_api
        pq_id = train_table_job_api.get(job_id).config.pquery_id
        assert pq_id is not None

        # NOTE the backend implementation currently handles sequentialization
        # between a training table future and a training job; that is, if the
        # training table future is still executing, the backend will wait on
        # the job ID completion before executing a training job. This preserves
        # semantics for both futures, ensures that Kumo works as expected if
        # used only via REST API, and allows us to avoid chaining callbacks
        # in an ugly way here:
        api = global_state.client.training_job_api
        self._training_job_id = api.create(
            TrainingJobRequest(
                dict(custom_tags),
                pquery_id=pq_id,
                model_plan=self._model_plan,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                train_table_job_id=job_id,
            ))

        out = TrainingJob(job_id=self._training_job_id)
        if non_blocking:
            return out
        return out.attach()

    def predict(
        self,
        graph: Graph,
        prediction_table: Union[PredictionTable, PredictionTableJob],
        output_types: Set[str],
        output_connector: Optional[Connector] = None,
        output_table_name: Optional[Union[str, Tuple]] = None,
        training_job_id: Optional[TrainingJobID] = None,
        binary_classification_threshold: Optional[float] = None,
        num_classes_to_return: Optional[int] = None,
        num_workers: int = 1,
        non_blocking: bool = False,
        # The shared default is never mutated: it is copied with `dict()`
        # before being handed to the request below.
        custom_tags: Mapping[str, str] = {},
    ) -> Union[BatchPredictionJob, BatchPredictionJobResult]:
        r"""Using the trained model specified by :obj:`training_job_id` (or
        the model corresponding to the last invocation of
        :meth:`~kumoai.trainer.Trainer.fit`, if not present), predicts the
        future values of the entities in :obj:`prediction_table` leveraging
        current information from :obj:`graph`.

        Args:
            graph: The :class:`~kumoai.graph.Graph` object that represents the
                tables and relationships that Kumo will use to make
                predictions.
            prediction_table: The :class:`~kumoai.pquery.PredictionTable`, or
                in-progress :class:`~kumoai.pquery.PredictionTableJob`, that
                represents the prediction data produced by a
                :class:`~kumoai.pquery.PredictiveQuery` on :obj:`graph`. This
                table may also be custom-generated.
            output_types: The types of outputs that should be produced by
                the prediction job. Can include either ``'predictions'``,
                ``'embeddings'``, or both.
            output_connector: The output data source that Kumo should write
                batch predictions to, if it is None, produce local download
                output only.
            output_table_name: The name of the table in the output data source
                that Kumo should write batch predictions to. In the case of
                a Databricks connector, this should be a tuple of two strings:
                the schema name and the output prediction table name.
            training_job_id: The job ID of the training job whose model will be
                used for making predictions.
            binary_classification_threshold: If this model corresponds to
                a binary classification task, the threshold for which higher
                values correspond to ``1``, and lower values correspond to
                ``0``.
            num_classes_to_return: If this model corresponds to a ranking task,
                the number of classes to return in the prediction output.
            num_workers: Number of workers to use when generating batch
                predictions. When set to a value greater than 1, the prediction
                table is partitioned into smaller parts and processed in
                parallel. The default is 1, which implies sequential inference
                over the prediction table.
            non_blocking: Whether this operation should return immediately
                after launching the batch prediction job, or await
                completion of the batch prediction job.
            custom_tags: Additional, customer defined k-v tags to be associated
                with the job to be launched. Job tags are useful for grouping
                and searching jobs.

        Returns:
            Union[BatchPredictionJobResult, BatchPredictionJob]:
                If ``non_blocking=False``, returns the
                :class:`~kumoai.trainer.BatchPredictionJobResult` obtained by
                attaching to the launched job. If ``non_blocking=True``,
                returns a :class:`~kumoai.trainer.BatchPredictionJob` future
                object.

        Raises:
            ValueError: If neither :obj:`training_job_id` nor a training job
                ID saved on this trainer is available.
            RuntimeError: If the response to the batch prediction request is
                not OK.
        """
        validate_output_arguments(output_types, output_connector,
                                  output_table_name)

        # Create outputs. Without an output connector, predictions are
        # generated to the Kumo dataplane, and can only be exported via the
        # UI for now, so no output configuration is built:
        outputs: List[PredictionOutputConfig] = []
        if output_connector is not None:
            outputs = [
                build_prediction_output_config(output_type, output_connector,
                                               output_table_name)
                for output_type in output_types
            ]

        training_job_id = training_job_id or self._training_job_id
        if training_job_id is None:
            raise ValueError(
                "Cannot run batch prediction without a specified or saved "
                "training job ID. Please either call `fit(...)` or pass a "
                "job ID of a completed training job to proceed.")

        pred_table_job_id: Optional[GeneratePredictionTableJobID] = \
            prediction_table.job_id
        pred_table_data_path = None
        if pred_table_job_id is None:
            # No generation job to reference, so the table's data location
            # is sent with the request instead:
            assert isinstance(prediction_table, PredictionTable)
            pred_table_data_path = prediction_table.table_data_uri \
                if global_state.is_spcs \
                else prediction_table.table_data_uri.uri  # type: ignore

        api = global_state.client.batch_prediction_job_api
        job_id, response = api.maybe_create(
            BatchPredictionRequest(
                dict(custom_tags),
                model_training_job_id=training_job_id,
                predict_options=BatchPredictionOptions(
                    binary_classification_threshold=(
                        binary_classification_threshold),
                    num_classes_to_return=num_classes_to_return,
                    num_workers=num_workers,
                ),
                outputs=outputs,
                graph_snapshot_id=graph.snapshot(non_blocking=non_blocking),
                pred_table_job_id=pred_table_job_id,
                pred_table_path=pred_table_data_path,
            ))

        message = response.message()
        if not response.ok:
            raise RuntimeError(f"Prediction failed. {message}")
        elif not response.empty():
            logger.warning("Prediction produced the following warnings: %s",
                           message)
        assert job_id is not None

        self._batch_prediction_job_id = job_id
        out = BatchPredictionJob(job_id=self._batch_prediction_job_id)
        if non_blocking:
            return out
        return out.attach()

    # Persistence #############################################################

    @classmethod
    def _load_from_job(cls, job: TrainingJobResource) -> 'Trainer':
        r"""Builds a trainer from a training job resource, reusing the job's
        model plan and remembering its job ID.
        """
        trainer = cls(job.config.model_plan)
        trainer._training_job_id = job.job_id
        return trainer

    @classmethod
    def load(cls, job_id: TrainingJobID) -> 'Trainer':
        r"""Creates a :class:`~kumoai.trainer.Trainer` instance from a training
        job ID.
        """
        api: TrainingJobAPI = global_state.client.training_job_api
        job = api.get(job_id)
        return cls._load_from_job(job)

    # TODO(siyang): load trainer by searching training job via tags.