Source code for kumoai.pquery.prediction_table
from __future__ import annotations
import asyncio
import logging
from concurrent.futures import Future
from datetime import datetime
from functools import reduce
from typing import List, Mapping, Optional, Union
import pandas as pd
from kumoapi.common import JobStatus
from kumoapi.jobs import GeneratePredictionTableJobResource, JobStatusReport
from typing_extensions import override
from kumoai import global_state
from kumoai.client.jobs import GeneratePredictionTableJobID
from kumoai.connector.s3_connector import S3URI
from kumoai.formatting import pretty_print_error_details
from kumoai.futures import KumoFuture, create_future
logger = logging.getLogger(__name__)
class PredictionTable:
r"""A prediction table in the Kumo platform. A prediction table can
either be initialized from the job ID of a completed prediction table
generation job, or from a path on a supported object store (S3 for a
SaaS or Databricks deployment, and Snowflake session storage for
Snowflake deployments).
.. warning::
Custom prediction tables are an experimental feature; please work
with your Kumo POC to ensure you are using them correctly!
.. code-block:: python
import kumoai
# Create a Prediction Table from a prediction table generation job.
# Note that the job ID passed here must be in a completed state:
prediction_table = kumoai.PredictionTable("gen-predtable-job-...")
# Read the prediction table as a Pandas DataFrame:
prediction_df = prediction_table.data_df()
# Get URLs to download the prediction table:
prediction_download_urls = prediction_table.data_urls()
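# A prediction table can also be created from a custom table stored on
# an object store to which Kumo has read access. The S3 path below is
# an illustrative placeholder, not a real location:
prediction_table = kumoai.PredictionTable(
    table_data_path="s3://your-bucket/path/to/prediction_table/")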
Args:
job_id: ID of the prediction table generation job which
generated this prediction table. If a custom table data path is
specified, this parameter should be left as ``None``.
table_data_path: S3 path of the table data location, for which Kumo
must have at least read access. If a job ID is specified, this
parameter should be left as ``None``.
"""
def __init__(
self,
job_id: Optional[GeneratePredictionTableJobID] = None,
table_data_path: Optional[str] = None,
) -> None:
# Validation:
if not (job_id or table_data_path):
raise ValueError(
"A PredictionTable must either be initialized with a table "
"data path, or a job ID of a completed prediction table "
"generation job.")
if job_id and table_data_path:
raise ValueError(
"Please either pass a table data path, or a job ID of a "
"completed prediction table generation job; passing both "
"is not allowed.")
# Custom path:
self.table_data_uri: Optional[Union[str, S3URI]] = None
if table_data_path is not None:
if global_state.is_spcs:
if table_data_path.startswith('s3://'):
raise ValueError(
"SPCS does not support S3 paths for prediction tables."
)
# TODO(zeyuan): support custom stage path on SPCS:
self.table_data_uri = table_data_path
else:
self.table_data_uri = S3URI(table_data_path).validate()
# Job ID:
self.job_id = job_id
if job_id:
status = _get_status(job_id).status
if status != JobStatus.DONE:
raise ValueError(
f"Job {job_id} is not yet complete (status: {status}). If "
f"you would like to create a future (waiting for "
f"prediction table generation success), please use "
f"`PredictionTableJob`.")
def data_urls(self) -> List[str]:
r"""Returns a list of URLs that can be used to view generated
prediction table data; if a custom data path was passed, this path is
simply returned.
The list will contain more than one element if the table is
partitioned; paths will be relative to the location of the Kumo data
plane.
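For example, a minimal sketch of per-partition processing, assuming each
URL points to a Parquet file (as :meth:`data_df` assumes when reading the
same URLs); ``process`` is a placeholder for user-defined analysis:
.. code-block:: python
import pandas as pd
for url in prediction_table.data_urls():
    # Read one partition at a time to keep memory usage bounded:
    partition_df = pd.read_parquet(url)
    process(partition_df)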
"""
api = global_state.client.generate_prediction_table_job_api
if not self.job_id:
# Custom prediction table:
if global_state.is_spcs:
assert isinstance(self.table_data_uri, str)
return [self.table_data_uri]
else:
assert isinstance(self.table_data_uri, S3URI)
return [self.table_data_uri.uri]
return api.get_table_data(self.job_id, presigned=True)
def data_df(self) -> pd.DataFrame:
r"""Returns a Pandas DataFrame object representing the generated
or custom-specified prediction table data.
.. warning::
This method will load the full prediction table into memory as a
:class:`~pandas.DataFrame` object. If you are working on a machine
with limited resources, please use
:meth:`~kumoai.pquery.PredictionTable.data_urls` instead to
download the data and perform analysis per-partition.
"""
if global_state.is_spcs:
from snowflake.snowpark import DataFrame
from kumoai.spcs import _parquet_to_df
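# On SPCS, read each partition into a Snowpark DataFrame and union
# the partitions into a single result: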
df_list = [_parquet_to_df(url) for url in self.data_urls()]
# TODO(manan): return type hint is wrong
return reduce(DataFrame.union_all, df_list)
else:
urls = self.data_urls()
try:
return pd.concat([pd.read_parquet(x) for x in urls])
except Exception as e:
raise ValueError(
f"Could not create a Pandas DataFrame object from data "
f"paths {urls}. Please construct the DataFrame manually."
) from e
@property
def anchor_time(self) -> Optional[datetime]:
r"""Returns the anchor time corresponding to the generated prediction
table data, if the data was not custom-specified.
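For example, a minimal usage sketch (``None`` is returned for a
custom-specified prediction table):
.. code-block:: python
anchor = prediction_table.anchor_time
if anchor is not None:
    print(f"Prediction table anchored at {anchor}")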
"""
if self.job_id is None:
logger.warning(
"Fetching the anchor time is not supported for a custom "
"prediction table (path: %s)", self.table_data_uri)
return None
api = global_state.client.generate_prediction_table_job_api
return api.get_anchor_time(self.job_id)
# Prediction Table Future #####################################################
class PredictionTableJob(KumoFuture[PredictionTable]):
r"""A representation of an ongoing prediction table generation job in the
Kumo platform.
.. code-block:: python
import kumoai
# See `PredictiveQuery` documentation:
pquery = kumoai.PredictiveQuery(...)
# If a prediction table is generated in non-blocking mode, the response
# will be of type `PredictionTableJob`:
prediction_table_job = pquery.generate_prediction_table(non_blocking=True)
# You can also construct a `PredictionTableJob` from a job ID, e.g.
# one that is present in the Kumo Jobs page:
prediction_table_job = kumoai.PredictionTableJob("gen-predtable-job-...")
# Get the status of the job:
print(prediction_table_job.status())
# Cancel the job:
prediction_table_job.cancel()
# Wait for the job to complete, and return a `PredictionTable`:
prediction_table_job.result()
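# The returned `PredictionTable` can then be inspected directly:
prediction_table = prediction_table_job.result()
prediction_df = prediction_table.data_df()
# A simple (illustrative) polling loop that waits for a terminal job
# state without blocking on `result()`:
import time
while not prediction_table_job.status().status.is_terminal:
    time.sleep(10)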
Args:
job_id: ID of the prediction table generation job.
""" # noqa
def __init__(
self,
job_id: GeneratePredictionTableJobID,
) -> None:
self.job_id = job_id
self.job: Optional[GeneratePredictionTableJobResource] = None
# A prediction table job holds a reference to the future that tracks
# the execution of its generation.
self._fut: Future = create_future(self._poll())
@property
def id(self) -> GeneratePredictionTableJobID:
r"""The unique ID of this prediction table generation process."""
return self.job_id
@override
def result(self) -> PredictionTable:
return self._fut.result()
@override
def future(self) -> Future[PredictionTable]:
return self._fut
def status(self) -> JobStatusReport:
r"""Returns the status of a running prediction table generation job."""
return self._poll_job().job_status_report
def cancel(self) -> None:
r"""Cancels a running prediction table generation job, and raises an
error if cancellation fails.
"""
api = global_state.client.generate_prediction_table_job_api
return api.cancel(self.job_id)
# TODO(manan): make asynchronous natively with aiohttp:
def _poll_job(self) -> GeneratePredictionTableJobResource:
# Skip polling if job is already in terminal state.
api = global_state.client.generate_prediction_table_job_api
if not self.job or not self.job.job_status_report.status.is_terminal:
self.job = api.get(self.job_id)
return self.job
async def _poll(self) -> PredictionTable:
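# Poll the job status periodically until it reaches a terminal state: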
while not self.status().status.is_terminal:
await asyncio.sleep(10)
api = global_state.client.generate_prediction_table_job_api
status = self.status().status
if status != JobStatus.DONE:
error_details = api.get_job_error(self.job_id)
error_str = pretty_print_error_details(error_details)
raise RuntimeError(
f"Prediction table generation for job {self.job_id} failed "
f"with job status {status}. Encountered below"
f" errors: {error_str}")
return PredictionTable(self.job_id)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(job_id={self.job_id})'
def _get_status(job_id: str) -> JobStatusReport:
api = global_state.client.generate_prediction_table_job_api
resource: GeneratePredictionTableJobResource = api.get(job_id)
return resource.job_status_report