kumoai.experimental.rfm.KumoRFM
- class kumoai.experimental.rfm.KumoRFM
Bases: object
The Kumo Relational Foundation Model (RFM) from the paper "KumoRFM: A Foundation Model for In-Context Learning on Relational Data".
KumoRFM is a foundation model that generates predictions for any relational dataset without training. The model is pre-trained, and this class provides an interface for querying it from a LocalGraph object.
```python
import pandas as pd

from kumoai.experimental.rfm import LocalGraph, KumoRFM

# Placeholders: load one DataFrame per table.
df_users = pd.DataFrame(...)
df_items = pd.DataFrame(...)
df_orders = pd.DataFrame(...)

graph = LocalGraph.from_data({
    'users': df_users,
    'items': df_items,
    'orders': df_orders,
})

rfm = KumoRFM(graph)

# Will user 0 place at least one order in the next 30 days?
query = ("PREDICT COUNT(orders.*, 0, 30, days)>0 "
         "FOR users.user_id=0")
result = rfm.predict(query)
print(result)  # A pandas.DataFrame with the prediction for user 0, e.g. 0.85
```
- Parameters:
  - graph (LocalGraph) – The graph.
  - preprocess (bool) – Whether to preprocess the data in advance during graph materialization. This trades graph materialization time against model processing speed: preprocessing the data once and then running many queries on top of it gives the fastest model speed. However, if enabled, graph materialization can take much longer, especially on graphs with many large text columns. It is best to tune this option manually.
  - verbose (Union[bool, ProgressLogger]) – Whether to print verbose output.
- batch_mode(batch_size='max', num_retries=1)
Context manager for running predictions in batches.
```python
with model.batch_mode(batch_size='max', num_retries=1):
    df = model.predict(query, indices=...)
```
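To make the batching semantics concrete, here is a minimal pure-Python sketch of chunked prediction with retries. It is not kumoai's implementation: the helper predict_in_batches and its predict_fn argument are hypothetical, the sketch assumes num_retries means extra attempts for a failing batch, and it takes an integer batch size rather than modeling 'max'.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def predict_in_batches(
    predict_fn: Callable[[Sequence[T]], List[R]],
    indices: Sequence[T],
    batch_size: int,
    num_retries: int = 1,
) -> List[R]:
    """Hypothetical helper (not part of kumoai): split `indices` into chunks
    of `batch_size`, call `predict_fn` on each chunk, and retry a failing
    chunk up to `num_retries` extra times before re-raising the error."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    results: List[R] = []
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        for attempt in range(num_retries + 1):
            try:
                results.extend(predict_fn(batch))
                break  # this batch succeeded, move on to the next one
            except Exception:
                if attempt == num_retries:
                    raise  # out of retries: surface the last error
    return results


def fake_predict(batch):
    # Stand-in for a real prediction call: returns one value per key.
    return [key * 10 for key in batch]


print(predict_in_batches(fake_predict, [1, 2, 3, 4, 5], batch_size=2))
# [10, 20, 30, 40, 50]
```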
- predict(query, indices=None, *, explain=False, anchor_time=None, context_anchor_time=None, run_mode=fast, num_neighbors=None, num_hops=2, max_pq_iterations=20, random_seed=42, verbose=True)
Returns predictions for a predictive query.
- Parameters:
  - query (str) – The predictive query.
  - indices (Union[List[str], List[float], List[int], None]) – The entity primary keys to predict for. Overrides the indices given as part of the predictive query. Predictions are generated for all indices, regardless of whether they fulfill the entity filter constraints. To pre-filter entities, use is_valid_entity().
  - explain (bool) – If set to True, additionally explains the prediction. Explainability is currently only supported for single-entity predictions with run_mode="FAST".
  - anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the prediction. If set to None, the maximum timestamp in the data is used. If set to "entity", the timestamp of the entity is used.
  - context_anchor_time (Optional[Timestamp]) – The maximum anchor timestamp for context examples. If set to None, anchor_time determines the anchor time for context examples.
  - num_neighbors (Optional[List[int]]) – The number of neighbors to sample for each hop. If specified, the num_hops option is ignored.
  - num_hops (int) – The number of hops to sample when generating the context.
  - max_pq_iterations (int) – The maximum number of iterations to perform to collect valid labels. Consider increasing it when the predictive query has strict entity filters, since KumoRFM then needs to sample more entities to find valid labels.
  - random_seed (Optional[int]) – A manual seed for generating pseudo-random numbers.
  - verbose (Union[bool, ProgressLogger]) – Whether to print verbose output.
- Return type:
  Union[DataFrame, Explanation]
- Returns:
  The predictions as a pandas.DataFrame. If explain=True, additionally returns a textual summary that explains the prediction.
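When num_neighbors is given, its length already fixes how many hops are sampled, which is presumably why num_hops is then ignored. The sketch below estimates how large the sampled context around one entity can grow, assuming fan-out sampling in which hop i draws at most num_neighbors[i] neighbors for each node reached at the previous hop. The helper max_context_nodes is hypothetical and does not describe KumoRFM's actual sampler.

```python
from typing import List


def max_context_nodes(num_neighbors: List[int]) -> int:
    """Upper bound on nodes sampled around one seed entity.

    Assumption (hypothetical, not KumoRFM's documented sampler): hop i
    draws at most num_neighbors[i] neighbors per node reached in hop i - 1.
    """
    total = 0
    frontier = 1  # start from the single seed entity
    for fanout in num_neighbors:
        frontier *= fanout  # at most this many nodes newly reached at this hop
        total += frontier
    return total


# Two hops with 16 and 8 neighbors: at most 16 + 16 * 8 = 144 nodes.
print(max_context_nodes([16, 8]))  # 144
```

Under this assumption, the context size grows multiplicatively with each additional hop, so a longer num_neighbors list can enlarge the context much faster than larger per-hop values.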
- is_valid_entity(query, indices=None, *, anchor_time=None)
Returns a mask that denotes which entities are valid for the given predictive query, i.e., which entities fulfill (temporal) entity filter constraints.
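Since predict() generates predictions for all indices regardless of entity filters, the mask can be used to drop invalid keys first. A minimal sketch, assuming the mask is a boolean sequence aligned element-wise with the indices passed in (the return type is not documented here). The helper filter_valid is hypothetical, and the kumoai calls appear only as comments.

```python
from typing import List, Sequence, TypeVar

K = TypeVar("K")


def filter_valid(indices: Sequence[K], mask: Sequence[bool]) -> List[K]:
    """Hypothetical helper: keep only the primary keys whose mask entry is truthy.

    Assumption: `mask` is aligned element-wise with `indices`.
    """
    if len(indices) != len(mask):
        raise ValueError("indices and mask must have the same length")
    return [key for key, ok in zip(indices, mask) if ok]


# Intended workflow (kumoai calls shown as comments, not executed):
# mask = rfm.is_valid_entity(query, indices=keys)
# df = rfm.predict(query, indices=filter_valid(keys, mask))

# Toy mask for illustration only (not real model output):
print(filter_valid([101, 102, 103], [True, False, True]))  # [101, 103]
```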
- Parameters:
  - query (str) – The predictive query.
  - indices (Union[List[str], List[float], List[int], None]) – The entity primary keys to check. Overrides the indices given as part of the predictive query.
  - anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the prediction. If set to None, the maximum timestamp in the data is used. If set to "entity", the timestamp of the entity is used.
- Return type:
- evaluate(query, *, metrics=None, anchor_time=None, context_anchor_time=None, run_mode=fast, num_neighbors=None, num_hops=2, max_pq_iterations=20, random_seed=42, verbose=True)
Evaluates a predictive query.
- Parameters:
  - query (str) – The predictive query.
  - anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the prediction. If set to None, the maximum timestamp in the data is used. If set to "entity", the timestamp of the entity is used.
  - context_anchor_time (Optional[Timestamp]) – The maximum anchor timestamp for context examples. If set to None, anchor_time determines the anchor time for context examples.
  - num_neighbors (Optional[List[int]]) – The number of neighbors to sample for each hop. If specified, the num_hops option is ignored.
  - num_hops (int) – The number of hops to sample when generating the context.
  - max_pq_iterations (int) – The maximum number of iterations to perform to collect valid labels. Consider increasing it when the predictive query has strict entity filters, since KumoRFM then needs to sample more entities to find valid labels.
  - random_seed (Optional[int]) – A manual seed for generating pseudo-random numbers.
  - verbose (Union[bool, ProgressLogger]) – Whether to print verbose output.
- Return type:
- Returns:
  The metrics as a pandas.DataFrame.
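This section does not state which metrics evaluate() reports. For a binary query such as the COUNT(...)>0 example, one common choice is the area under the ROC curve (AUROC). The sketch below computes it in pure Python from held-out labels and predicted scores. Whether evaluate() reports AUROC is an assumption, the helper auroc is hypothetical, and the toy labels and scores are made up for illustration.

```python
from typing import Sequence


def auroc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """Hypothetical helper: AUROC via pairwise comparison, i.e. the
    probability that a random positive is scored above a random
    negative (ties count as half)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative label")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Toy data: 3 of the 4 positive/negative pairs are ranked correctly.
print(auroc([1, 0, 1, 0], [0.9, 0.3, 0.6, 0.7]))  # 0.75
```

The pairwise form is quadratic in the number of examples; it is chosen here for readability, not speed.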
- get_train_table(query, size, *, anchor_time=None, random_seed=42, max_iterations=20)
Returns the labels of a predictive query for a specified anchor time.
- Parameters:
  - query (str) – The predictive query.
  - size (int) – The maximum number of entities to generate labels for.
  - anchor_time (Union[Timestamp, Literal['entity'], None]) – The anchor timestamp for the query. If set to None, the maximum timestamp in the data is used. If set to "entity", the timestamp of the entity is used.
  - random_seed (Optional[int]) – A manual seed for generating pseudo-random numbers.
  - max_iterations (int) – The maximum number of iterations to run before aborting.
- Return type:
- Returns:
  The labels as a pandas.DataFrame.
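To illustrate what a label for a given anchor time means, here is a minimal pure-Python sketch for a binary query of the form COUNT(orders.*, 0, 30, days)>0. It assumes the aggregation window is (anchor_time, anchor_time + 30 days] and uses a hypothetical orders list of (user_id, timestamp) pairs. It does not model size, random_seed, or entity sampling, and it is not kumoai's implementation.

```python
from datetime import datetime, timedelta
from typing import Dict, List, Tuple


def binary_count_labels(
    users: List[int],
    orders: List[Tuple[int, datetime]],
    anchor_time: datetime,
    days: int = 30,
) -> Dict[int, bool]:
    """Hypothetical labels for COUNT(orders.*, 0, days, days) > 0.

    Assumption: an order counts if it falls in the window
    (anchor_time, anchor_time + days], i.e. start-exclusive, end-inclusive.
    """
    end = anchor_time + timedelta(days=days)
    labels = {user: False for user in users}  # default: no order in window
    for user, ts in orders:
        if user in labels and anchor_time < ts <= end:
            labels[user] = True
    return labels


t0 = datetime(2024, 1, 1)
orders = [
    (1, t0 + timedelta(days=5)),   # inside the window
    (1, t0 - timedelta(days=3)),   # before the anchor time: ignored
    (2, t0 + timedelta(days=45)),  # after the window: ignored
]
print(binary_count_labels([1, 2, 3], orders, anchor_time=t0))
# {1: True, 2: False, 3: False}
```

Moving anchor_time earlier shifts the window back in time, which is how labels for historical anchor times can be derived from the same order history.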