kumoai.experimental.rfm.KumoRFM

class kumoai.experimental.rfm.KumoRFM

Bases: object

The Kumo Relational Foundation Model (RFM), introduced in the paper KumoRFM: A Foundation Model for In-Context Learning on Relational Data.

KumoRFM is a foundation model that generates predictions for any relational dataset without training. The model is pre-trained, and this class provides an interface for querying it on a Graph object.

import pandas as pd

from kumoai.experimental.rfm import Graph, KumoRFM

# One DataFrame per table (contents elided).
df_users = pd.DataFrame(...)
df_items = pd.DataFrame(...)
df_orders = pd.DataFrame(...)

# Build a Graph from the tables.
graph = Graph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

rfm = KumoRFM(graph)

# Will user 1 place at least one order in the next 30 days?
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
         "FOR users.user_id=1")
result = rfm.predict(query)

print(result)  # user_id  COUNT(orders.*, 0, 30, days)>0
               # 1        0.85
Parameters:
  • graph (Graph) – The graph.

  • verbose (bool | ProgressLogger) – Whether to print verbose output.

  • optimize (bool) – If set to True, optimizes the underlying data backend for querying; for example, on transactional database backends, any missing indices are created. Requires write access to the data backend.
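To illustrate what "creating missing indices" involves on a transactional backend, here is a minimal sketch using Python's built-in sqlite3 module. It is not what KumoRFM runs: the `orders` table layout and the helpers `has_index_on` and `ensure_index` are hypothetical. It does show why the option requires write access, since `CREATE INDEX` modifies the database schema.

```python
import sqlite3


def has_index_on(conn: sqlite3.Connection, table: str, column: str) -> bool:
    """Return True if an index on `table` has `column` as its leading column."""
    # Identifiers are interpolated directly; acceptable only for trusted names.
    for _, index_name, *_ in conn.execute(f"PRAGMA index_list('{table}')").fetchall():
        # index_info rows are (seqno, cid, column_name); sorting puts seqno 0 first.
        info = sorted(conn.execute(f"PRAGMA index_info('{index_name}')").fetchall())
        if info and info[0][2] == column:
            return True
    return False


def ensure_index(conn: sqlite3.Connection, table: str, column: str) -> bool:
    """Create an index on table(column) if none exists; return True if one was created."""
    if has_index_on(conn, table, column):
        return False
    # CREATE INDEX changes the schema, which is why write access is required.
    conn.execute(f"CREATE INDEX idx_{table}_{column} ON {table} ({column})")
    return True


# Hypothetical orders table whose foreign key `user_id` has no index yet.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (order_id INTEGER PRIMARY KEY, user_id INTEGER, ts TEXT)")
created_first = ensure_index(conn, "orders", "user_id")   # creates the index
created_second = ensure_index(conn, "orders", "user_id")  # index already exists
```

Indexing foreign-key columns such as `user_id` is what speeds up the per-entity lookups that sampling related rows typically needs.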

__init__(graph, verbose=True, optimize=False)

retry(num_retries=1)

Context manager that retries queries which fail due to unexpected server issues.

# `model` is a KumoRFM instance and `query` a predictive query string.
with model.retry(num_retries=1):
    df = model.predict(query, indices=...)
Parameters:

num_retries (int) – The maximum number of retries.

Return type:

Generator[None, None, None]
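The Generator[None, None, None] return type indicates a generator-based context manager. Such a context manager yields exactly once, so it cannot re-run the body of the `with` block; instead it can set state that the calls inside the block consult. The sketch below illustrates that pattern. `RetryingClient`, `TransientServerError`, and `call` are hypothetical names, and this is not KumoRFM's implementation.

```python
from contextlib import contextmanager
from typing import Callable, Generator, TypeVar

T = TypeVar("T")


class TransientServerError(Exception):
    """Hypothetical stand-in for an unexpected server-side failure."""


class RetryingClient:
    """Hypothetical client illustrating a retry context manager."""

    def __init__(self) -> None:
        self._num_retries = 0  # no retries outside of a retry() block

    @contextmanager
    def retry(self, num_retries: int = 1) -> Generator[None, None, None]:
        previous = self._num_retries
        self._num_retries = num_retries
        try:
            yield
        finally:
            self._num_retries = previous  # restore on exit, even after an error

    def call(self, fn: Callable[[], T]) -> T:
        # One initial attempt plus up to `_num_retries` retries.
        for attempt in range(self._num_retries + 1):
            try:
                return fn()
            except TransientServerError:
                if attempt == self._num_retries:
                    raise
        raise AssertionError("unreachable")


attempts = {"count": 0}


def flaky_query() -> str:
    """Fails on the first call and succeeds afterwards."""
    attempts["count"] += 1
    if attempts["count"] == 1:
        raise TransientServerError("simulated server hiccup")
    return "ok"


client = RetryingClient()
with client.retry(num_retries=1):
    result = client.call(flaky_query)  # first attempt fails, the retry succeeds
```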

batch_mode(batch_size='max', num_retries=1)

Context manager to predict in batches.

# `model` is a KumoRFM instance and `query` a predictive query string.
with model.batch_mode(batch_size='max', num_retries=1):
    df = model.predict(query, indices=...)
Parameters:
  • batch_size (Union[int, Literal['max']]) – The batch size. If set to "max", uses the maximum batch size applicable to the given task.

  • num_retries (int) – The maximum number of retries for failed queries due to unexpected server issues.

Return type:

Generator[None, None, None]
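Conceptually, predicting in batches means splitting the requested indices into chunks and issuing one request per chunk. The sketch below shows only that splitting step. The helper `iter_batches` and its `max_batch_size` argument are hypothetical; how KumoRFM determines the maximum applicable batch size for a task is not described here, and the sketch simply uses an integer batch size as given.

```python
from typing import Iterator, List, Literal, Sequence, TypeVar, Union

T = TypeVar("T")


def iter_batches(
    indices: Sequence[T],
    batch_size: Union[int, Literal["max"]],
    max_batch_size: int,
) -> Iterator[List[T]]:
    """Yield consecutive batches of `indices`.

    `max_batch_size` is a hypothetical stand-in for the largest batch size
    applicable to the task; "max" resolves to it.
    """
    size = max_batch_size if batch_size == "max" else int(batch_size)
    if size < 1:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(indices), size):
        yield list(indices[start:start + size])


user_ids = [1, 2, 3, 4, 5]
fixed = list(iter_batches(user_ids, batch_size=2, max_batch_size=3))
largest = list(iter_batches(user_ids, batch_size="max", max_batch_size=3))
```

Larger batches mean fewer round trips, which is presumably why "max" is the default.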

predict(query, indices=None, *, explain=False, anchor_time=None, context_anchor_time=None, run_mode=fast, num_neighbors=None, num_hops=2, max_pq_iterations=10, random_seed=42, verbose=True, use_prediction_time=False)

Returns predictions for a predictive query.

Parameters:
  • query (str) – The predictive query.

  • indices (Union[list[str], list[float], list[int], None]) – The primary keys of the entities to predict for. If given, they override the indices in the predictive query. Predictions are generated for all given indices, regardless of whether they satisfy the query's entity filter constraints.

  • explain (bool | ExplainConfig | dict[str, Any]) – Configuration for explainability. If set to True, the prediction is additionally explained. Passing an ExplainConfig instance controls which parts of the explanation are generated. Explainability is currently supported only for single-entity predictions with run_mode="FAST".

  • anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the prediction. If set to None, will use the maximum timestamp in the data. If set to "entity", will use the timestamp of the entity.

  • context_anchor_time (Optional[Timestamp]) – The maximum anchor timestamp for context examples. If set to None, anchor_time will determine the anchor time for context examples.

  • run_mode (RunMode | str) – The RunMode for the query.

  • num_neighbors (Optional[list[int]]) – The number of neighbors to sample for each hop. If specified, the num_hops option will be ignored.

  • num_hops (int) – The number of hops to sample when generating the context.

  • max_pq_iterations (int) – The maximum number of iterations to perform when collecting valid labels. Consider increasing it if the predictive query has strict entity filters, since KumoRFM then needs to sample more entities to find valid labels.

  • random_seed (int | None) – A manual seed for generating pseudo-random numbers.

  • verbose (bool | ProgressLogger) – Whether to print verbose output.

  • use_prediction_time (bool) – Whether to use the anchor timestamp as an additional feature during prediction. This is typically beneficial for time series forecasting tasks.
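The three accepted forms of anchor_time can be summarized by a small resolver. This sketch encodes only the documented rules; the function `resolve_anchor_times` and its arguments are hypothetical, and it uses `datetime` instead of `pandas.Timestamp` to stay dependency-free.

```python
from datetime import datetime
from typing import Dict, List, Literal, Union

AnchorTime = Union[datetime, Literal["entity"], None]


def resolve_anchor_times(
    anchor_time: AnchorTime,
    entity_ids: List[int],
    entity_times: Dict[int, datetime],
    all_times: List[datetime],
) -> Dict[int, datetime]:
    """Map each entity to the anchor time its prediction is made at."""
    if anchor_time is None:
        # None: every entity uses the maximum timestamp in the data.
        latest = max(all_times)
        return {eid: latest for eid in entity_ids}
    if anchor_time == "entity":
        # "entity": each entity uses its own timestamp.
        return {eid: entity_times[eid] for eid in entity_ids}
    # Explicit timestamp: shared by all entities.
    return {eid: anchor_time for eid in entity_ids}


entity_times = {1: datetime(2024, 3, 1), 2: datetime(2024, 5, 1)}
all_times = [datetime(2024, 3, 1), datetime(2024, 5, 1), datetime(2024, 6, 30)]
by_default = resolve_anchor_times(None, [1, 2], entity_times, all_times)
by_entity = resolve_anchor_times("entity", [1, 2], entity_times, all_times)
```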

Return type:

DataFrame | Explanation

Returns:

The predictions as a pandas.DataFrame. If explain is enabled, returns an Explanation object containing the prediction, summary, and details instead.
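The num_neighbors and num_hops options of predict overlap: an explicit per-hop list takes precedence over num_hops. The sketch below assumes the list's length then sets the sampling depth. It also computes an upper bound on the sampled context, assuming each sampled node draws at most num_neighbors[k] neighbors at hop k. The helper names and the fan-out model are illustrative assumptions, not KumoRFM's sampler.

```python
from typing import List, Optional


def effective_num_hops(num_neighbors: Optional[List[int]], num_hops: int = 2) -> int:
    """Sampling depth (assumed to be len(num_neighbors) when that is given)."""
    if num_neighbors is not None:
        return len(num_neighbors)  # num_hops is ignored in this case
    return num_hops


def max_context_nodes(num_neighbors: List[int]) -> int:
    """Upper bound on sampled nodes: 1 seed + n1 + n1*n2 + ... (assumed fan-out model)."""
    total, frontier = 1, 1
    for fanout in num_neighbors:
        frontier *= fanout  # nodes reachable at this hop
        total += frontier
    return total


depth = effective_num_hops([16, 16], num_hops=3)  # the list wins: 2 hops
bound = max_context_nodes([16, 16])               # 1 + 16 + 16 * 16
```

Because the bound grows multiplicatively with each hop, adding a hop is far more expensive than widening an existing one.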

predict_task(task, *, explain=False, run_mode=fast, num_neighbors=None, num_hops=2, verbose=True, exclude_cols_dict=None, use_prediction_time=False, top_k=None)

Returns predictions for a custom task specification.

Parameters:
  • task (TaskTable) – The custom TaskTable.

  • explain (bool | ExplainConfig | dict[str, Any]) – Configuration for explainability. If set to True, the prediction is additionally explained. Passing an ExplainConfig instance controls which parts of the explanation are generated. Explainability is currently supported only for single-entity predictions with run_mode="FAST".

  • run_mode (RunMode | str) – The RunMode for the query.

  • num_neighbors (Optional[list[int]]) – The number of neighbors to sample for each hop. If specified, the num_hops option will be ignored.

  • num_hops (int) – The number of hops to sample when generating the context.

  • verbose (bool | ProgressLogger) – Whether to print verbose output.

  • exclude_cols_dict (Optional[dict[str, list[str]]]) – Columns to exclude from the model input, given as a mapping from table name to the list of column names to exclude from that table.

  • use_prediction_time (bool) – Whether to use the anchor timestamp as an additional feature during prediction. This is typically beneficial for time series forecasting tasks.

  • top_k (Optional[int]) – The number of predictions to return per entity.
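When a task yields several candidate predictions per entity (for example, recommended items), top_k presumably keeps only the highest-scoring ones. As a sketch of that post-processing step alone, assuming predictions arrive as hypothetical `(entity, candidate, score)` tuples where a higher score is better:

```python
import heapq
from collections import defaultdict
from typing import Dict, Hashable, Iterable, List, Tuple

Prediction = Tuple[Hashable, Hashable, float]  # (entity, candidate, score)


def top_k_per_entity(
    predictions: Iterable[Prediction], k: int
) -> Dict[Hashable, List[Tuple[Hashable, float]]]:
    """Keep the k highest-scoring candidates for each entity, best first."""
    grouped: Dict[Hashable, List[Tuple[Hashable, float]]] = defaultdict(list)
    for entity, candidate, score in predictions:
        grouped[entity].append((candidate, score))
    return {
        entity: heapq.nlargest(k, candidates, key=lambda pair: pair[1])
        for entity, candidates in grouped.items()
    }


scores = [
    (1, "item_a", 0.2), (1, "item_b", 0.9), (1, "item_c", 0.5),
    (2, "item_a", 0.7), (2, "item_c", 0.1),
]
top2 = top_k_per_entity(scores, k=2)
```

`heapq.nlargest` avoids fully sorting each entity's candidate list when k is small.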

Return type:

DataFrame | Explanation

Returns:

The predictions as a pandas.DataFrame. If explain is enabled, returns an Explanation object containing the prediction, summary, and details instead.

evaluate(query, *, metrics=None, anchor_time=None, context_anchor_time=None, run_mode=fast, num_neighbors=None, num_hops=2, max_pq_iterations=10, random_seed=42, verbose=True, use_prediction_time=False)

Evaluates a predictive query.

Parameters:
  • query (str) – The predictive query.

  • metrics (Optional[list[str]]) – The metrics to use.

  • anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the prediction. If set to None, will use the maximum timestamp in the data. If set to "entity", will use the timestamp of the entity.

  • context_anchor_time (Optional[Timestamp]) – The maximum anchor timestamp for context examples. If set to None, anchor_time will determine the anchor time for context examples.

  • run_mode (RunMode | str) – The RunMode for the query.

  • num_neighbors (Optional[list[int]]) – The number of neighbors to sample for each hop. If specified, the num_hops option will be ignored.

  • num_hops (int) – The number of hops to sample when generating the context.

  • max_pq_iterations (int) – The maximum number of iterations to perform when collecting valid labels. Consider increasing it if the predictive query has strict entity filters, since KumoRFM then needs to sample more entities to find valid labels.

  • random_seed (int | None) – A manual seed for generating pseudo-random numbers.

  • verbose (bool | ProgressLogger) – Whether to print verbose output.

  • use_prediction_time (bool) – Whether to use the anchor timestamp as an additional feature during prediction. This is typically beneficial for time series forecasting tasks.
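The role of max_pq_iterations can be pictured as a bounded sampling loop: draw a round of entities, keep those that pass the entity filter and have a valid label, and stop once enough labels are collected or the iteration budget runs out. The sketch below is a hypothetical illustration of that idea (the helper `collect_labels`, its round size, and the toy filter are all assumptions), not KumoRFM's algorithm. It shows why a strict filter calls for more iterations: fewer entities survive each round.

```python
import random
from typing import Callable, Dict, List, Optional


def collect_labels(
    entity_ids: List[int],
    passes_filter: Callable[[int], bool],
    label_of: Callable[[int], Optional[int]],
    target: int,
    max_iterations: int = 10,
    per_iteration: int = 4,
    seed: int = 42,
) -> Dict[int, int]:
    """Sample entities in rounds until `target` labels are found or the budget is spent."""
    rng = random.Random(seed)
    pool = list(entity_ids)
    rng.shuffle(pool)  # random sampling order, reproducible via `seed`
    labels: Dict[int, int] = {}
    for round_idx in range(max_iterations):
        batch = pool[round_idx * per_iteration:(round_idx + 1) * per_iteration]
        if not batch:
            break  # every entity has already been sampled
        for eid in batch:
            if not passes_filter(eid):
                continue  # rejected by the entity filter
            label = label_of(eid)
            if label is None:
                continue  # no valid label for this entity
            labels[eid] = label
            if len(labels) == target:
                return labels
    return labels


def every_tenth(eid: int) -> bool:
    """A strict filter: only 10% of entities pass."""
    return eid % 10 == 0


def parity_label(eid: int) -> Optional[int]:
    """Toy label that is defined for every entity."""
    return eid % 2


entities = list(range(100))
small_budget = collect_labels(entities, every_tenth, parity_label, target=5, max_iterations=2)
large_budget = collect_labels(entities, every_tenth, parity_label, target=5, max_iterations=25)
```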

Return type:

DataFrame

Returns:

The metrics as a pandas.DataFrame.

evaluate_task(task, *, metrics=None, run_mode=fast, num_neighbors=None, num_hops=2, verbose=True, exclude_cols_dict=None, use_prediction_time=False)

Evaluates a custom task specification.

Parameters:
  • task (TaskTable) – The custom TaskTable.

  • metrics (Optional[list[str]]) – The metrics to use.

  • run_mode (RunMode | str) – The RunMode for the query.

  • num_neighbors (Optional[list[int]]) – The number of neighbors to sample for each hop. If specified, the num_hops option will be ignored.

  • num_hops (int) – The number of hops to sample when generating the context.

  • verbose (bool | ProgressLogger) – Whether to print verbose output.

  • exclude_cols_dict (Optional[dict[str, list[str]]]) – Columns to exclude from the model input, given as a mapping from table name to the list of column names to exclude from that table.

  • use_prediction_time (bool) – Whether to use the anchor timestamp as an additional feature during prediction. This is typically beneficial for time series forecasting tasks.
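exclude_cols_dict is keyed by table name, with a list of that table's columns to hide from the model, for instance a column you suspect leaks the target. The sketch below shows that shape and applies it to hypothetical in-memory tables (plain dicts of column name to values rather than DataFrames). The helper `drop_excluded` and all table and column names are made up for illustration; KumoRFM's own handling of the argument is not shown.

```python
from typing import Dict, List

Table = Dict[str, List[object]]  # column name -> column values


def drop_excluded(
    tables: Dict[str, Table], exclude_cols_dict: Dict[str, List[str]]
) -> Dict[str, Table]:
    """Return copies of `tables` without the columns listed for each table."""
    result: Dict[str, Table] = {}
    for name, columns in tables.items():
        excluded = set(exclude_cols_dict.get(name, []))  # tables not listed keep all columns
        result[name] = {col: values for col, values in columns.items() if col not in excluded}
    return result


tables: Dict[str, Table] = {
    "users": {"user_id": [1, 2], "signup_channel": ["web", "app"], "churned_flag": [0, 1]},
    "orders": {"order_id": [10, 11], "user_id": [1, 2], "amount": [9.5, 20.0]},
}
exclude_cols_dict = {"users": ["churned_flag"]}  # hypothetical leaky column
model_input = drop_excluded(tables, exclude_cols_dict)
```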

Return type:

DataFrame

Returns:

The metrics as a pandas.DataFrame.

get_train_table(query, size, *, anchor_time=None, random_seed=42, max_iterations=10)

Returns the labels of a predictive query for a specified anchor time.

Parameters:
  • query (str) – The predictive query.

  • size (int) – The maximum number of entities to generate labels for.

  • anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the query. If set to None, will use the maximum timestamp in the data. If set to "entity", will use the timestamp of the entity.

  • random_seed (int | None) – A manual seed for generating pseudo-random numbers.

  • max_iterations (int) – The maximum number of iterations to run before aborting.

Return type:

DataFrame

Returns:

The labels as a pandas.DataFrame.
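To make "labels for a specified anchor time" concrete, the sketch below computes the target of the class-level example query, COUNT(orders.*, 0, 30, days)>0, by hand for a few users. It assumes the window covers the 30 days following the anchor time, excluding the anchor itself and including the window's end. That boundary convention, the helper `count_label`, and the toy data are assumptions for illustration only.

```python
from datetime import datetime, timedelta
from typing import Dict, List


def count_label(
    order_times: List[datetime], anchor: datetime, start_days: int = 0, end_days: int = 30
) -> bool:
    """True if at least one order falls in (anchor + start, anchor + end] (assumed bounds)."""
    lo = anchor + timedelta(days=start_days)
    hi = anchor + timedelta(days=end_days)
    return sum(lo < t <= hi for t in order_times) > 0  # COUNT(...) > 0


orders_by_user: Dict[int, List[datetime]] = {
    1: [datetime(2024, 1, 5), datetime(2024, 2, 10)],  # one order inside the window
    2: [datetime(2023, 12, 20)],                       # only orders before the anchor
    3: [],                                             # no orders at all
}
anchor = datetime(2024, 1, 31)
labels = {uid: count_label(times, anchor) for uid, times in orders_by_user.items()}
```

Only events after the anchor time determine the label, which is what lets historical anchor times produce training examples without leaking future information into the model input.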