Source code for kumoai.experimental.rfm.rfm

import logging
import time
from typing import Any, Dict, Optional

import requests
from kumoapi.online_serving import (
    OnlinePredictionOptions,
    OnlineServingEndpointRequest,
    OnlineServingStatusCode,
)

from kumoai import global_state
from kumoai.experimental.rfm.local_graph import LocalGraph
from kumoai.trainer.online_serving import (
    OnlineServingEndpoint,
    OnlineServingEndpointFuture,
)

# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)

# Fixed name used both when saving the converted graph and when creating the
# serving endpoint (see `KumoRFM._start` / `KumoRFM._create_endpoint`).
# NOTE(review): because the name is constant, a second `KumoRFM` instance
# presumably overwrites the graph saved by the first one -- confirm against
# the backend; the in-code TODO mentions multi-graph support as future work.
DEFAULT_GRAPH_NAME = "rfm_byoc_graph"


class KumoRFM:
    r"""The :class:`KumoRFM` class is an interface to the Kumo Relational
    Foundation model (RFM), see
    `KumoRFM <https://kumo.ai/research/kumo_relational_foundation_model.pdf>`_.

    KumoRFM is a relational foundation model, which can generate
    predictions in context for any relational dataset. The model is
    pretrained and the class provides an interface to query the model.

    The class is constructed from a :class:`LocalGraph` object.

    Example:
        .. code-block:: python

            # raw data
            df_users = pd.DataFrame(...)
            df_items = pd.DataFrame(...)
            df_transactions = pd.DataFrame(...)

            # construct LocalGraph from raw data
            graph = LocalGraph.from_data(
                {
                    'users': df_users,
                    'items': df_items,
                    'transactions': df_transactions,
                }
            )

            # start KumoRFM with this graph
            rfm = KumoRFM(graph)
            rfm.wait_until_ready()

            # query the model
            query_str = ("PREDICT COUNT(transactions.*, 0, 30, days) > 0 "
                         "FOR users.user_id=1")
            result = rfm.query(query_str)

            # Result is the decoded JSON response of the predict endpoint
            # (a dictionary)
            print(result)

            # destroy the endpoint when done
            rfm.shutdown()

    Args:
        graph: The local graph to serve. Unless ``endpoint`` is given, the
            constructor adds a default primary key to every table lacking
            one, validates the graph, saves it under ``DEFAULT_GRAPH_NAME``
            and requests a new serving endpoint.
        endpoint: An already existing endpoint to query. When provided, no
            graph is saved and no endpoint is created.
    """  # noqa: E501

    def __init__(
        self,
        graph: LocalGraph,
        endpoint: Optional[OnlineServingEndpoint] = None,
    ) -> None:
        self.graph = graph
        self._endpoint: Optional[OnlineServingEndpoint] = endpoint
        # Only populated by `_start()`, i.e. when this instance created the
        # endpoint itself.
        self._endpoint_future: Optional[OnlineServingEndpointFuture] = None
        if self._endpoint is None:
            self._start()

    def query(
        self,
        query: str,
        wait_result_timeout: float = 10.0,
    ) -> Dict[str, Any]:
        """Query the RFM model with a query string.

        If the endpoint has not been resolved yet, this call first blocks on
        the pending endpoint future.

        Args:
            query: The RFM query string (e.g.,
                "PREDICT COUNT(orders.*, 0, 30, days) > 0 "
                "FOR users.user_id=1")
            wait_result_timeout: Timeout in seconds to wait for the result;
                forwarded to the server in the request payload.

        Returns:
            Dictionary containing the prediction result (the decoded JSON
            body of the response).

        Raises:
            RuntimeError: If no endpoint exists and none is being created,
                or if resolving the pending endpoint fails.
            requests.HTTPError: If the server answers with an error status.
        """  # noqa: E501
        if self._endpoint is None:
            # Raise the "not initialized" error directly instead of wrapping
            # it, so the caller sees the actual cause.
            if self._endpoint_future is None:
                raise RuntimeError("Endpoint not initialized")
            logger.info("Waiting for endpoint to be ready...")
            try:
                self._endpoint = self._endpoint_future.result()
            except Exception as e:
                raise RuntimeError("Failed to get endpoint result") from e

        # Execute query via direct HTTP request to predict endpoint.
        # NOTE(review): no client-side `timeout` is passed to requests.post,
        # so this call can block indefinitely if the server never answers;
        # consider adding one that comfortably exceeds `wait_result_timeout`.
        payload = {"query": query, 'wait_result_timeout': wait_result_timeout}
        resp = requests.post(
            self._endpoint._predict_url,
            headers=global_state.client._session.headers,
            json=payload,
        )
        resp.raise_for_status()
        return resp.json()

    def shutdown(self) -> None:
        r"""Clean up resources by destroying the RFM endpoint.

        This method attempts to destroy the RFM endpoint if one exists.
        Failures are logged rather than raised. After a successful destroy
        the instance forgets the endpoint, so a later :meth:`query` raises
        instead of contacting a destroyed endpoint.

        Example:
            .. code-block:: python

                rfm = KumoRFM(graph)
                rfm.wait_until_ready()
                # ... use RFM ...
                rfm.shutdown()  # Clean up endpoint when done
        """  # noqa: E501
        if self._endpoint is None:
            # NOTE(review): an endpoint whose creation was requested but
            # whose future was never resolved is not destroyed here.
            logger.info("No endpoint found, skipping...")
            return

        logger.info("Found endpoint for KumoRFM, destroying...")
        try:
            self._endpoint.destroy()
        except Exception as e:
            # Best effort: keep the reference so the caller may retry.
            logger.error(f"Failed to destroy endpoint: {e}")
            return

        # Drop stale handles so the destroyed endpoint is never reused.
        self._endpoint = None
        self._endpoint_future = None

    def poll_status(self) -> Optional[OnlineServingStatusCode]:
        """Poll the current status of the RFM endpoint.

        Returns:
            The current status code of the endpoint, or None if no endpoint
            exists. Possible values:
            - OnlineServingStatusCode.IN_PROGRESS: Endpoint is being
              created/updated
            - OnlineServingStatusCode.READY: Endpoint is ready for queries
            - OnlineServingStatusCode.FAILED: Endpoint creation/update failed
        """
        if self._endpoint_future is None:
            logger.info("Endpoint creation not started yet, "
                        "make sure to call _start() first")
            return None

        # Get the endpoint resource to check current status
        endpoint_api = global_state.client.online_serving_endpoint_api
        res = endpoint_api.get_if_exists(self._endpoint_future.id)
        if res is None:
            return None
        return res.status.status_code

    def is_ready(self) -> bool:
        """Check if the RFM endpoint is ready for queries.

        The answer is derived from :meth:`poll_status`, which only knows
        about an endpoint created by this instance; for an endpoint passed
        to the constructor it therefore reports ``False``.

        Returns:
            True if the endpoint is ready, False otherwise.
        """
        return self.poll_status() == OnlineServingStatusCode.READY

    def wait_until_ready(
        self,
        timeout: Optional[float] = None,
        sleep_interval: float = 10.0,
    ) -> None:
        """Wait until the RFM endpoint is ready.

        Returns immediately when an endpoint is already available (either
        passed to the constructor or resolved by an earlier call).

        Args:
            timeout: Maximum time to wait in seconds. If None, waits
                indefinitely. A value of ``0`` polls once and then times out.
            sleep_interval: Time to sleep between polls in seconds.

        Raises:
            TimeoutError: If the endpoint is not ready within the timeout
                period.
            RuntimeError: If the endpoint failed to start, its resource
                cannot be found, or endpoint creation was never started.
        """
        if self._endpoint is not None:
            return
        if self._endpoint_future is None:
            raise RuntimeError("Endpoint creation not started yet, "
                               "make sure to call _start() first")

        # Monotonic clock: elapsed time must not be affected by wall-clock
        # adjustments.
        started = time.monotonic()
        while True:
            endpoint_api = global_state.client.online_serving_endpoint_api
            res = endpoint_api.get_if_exists(self._endpoint_future.id)
            if res is None:
                raise RuntimeError("Endpoint resource not found")

            status = res.status.status_code
            if status == OnlineServingStatusCode.READY:
                logger.info(
                    "RFM endpoint is ready. Details:\n"
                    f"ID: {self._endpoint_future.id}\n"
                    f"URL: {res.endpoint_url}\n"
                    f"Launched at: {res.launched_at}\n"
                    f"Config: {res.config}\n"
                    f"Status: {res.status}\n"
                    f"Update status: {res.update if res.update else 'None'}")
                self._endpoint = self._endpoint_future.result()
                return
            if status == OnlineServingStatusCode.FAILED:
                raise RuntimeError("RFM endpoint failed to start")

            elapsed = time.monotonic() - started
            # `is not None` so that an explicit timeout of 0 is honoured.
            if timeout is not None and elapsed >= timeout:
                raise TimeoutError(
                    f"RFM endpoint not ready within {timeout} seconds")

            logger.info(
                f"RFM endpoint not ready yet. Current status: {status}. ")
            # Never sleep past the deadline.
            if timeout is None:
                delay = sleep_interval
            else:
                delay = min(sleep_interval, timeout - elapsed)
            time.sleep(delay)

    def _start(self) -> None:
        """Initialize KumoRFM by uploading data and creating endpoint.

        Note that tables of ``self.graph`` without a primary key are
        modified in place: a default primary key is added to them.
        """
        for table in self.graph.tables.values():
            if not table.has_primary_key():
                table._add_default_primary_key()

        self.graph.validate()
        kumo_graph = self.graph.to_kumo_graph()
        # skip validation as we validate the graph manually
        kumo_graph.save(DEFAULT_GRAPH_NAME, skip_validation=True)

        self._endpoint_future = self._create_endpoint()

    def _create_endpoint(self) -> OnlineServingEndpointFuture:
        """Create online serving endpoint for RFM.

        Returns:
            A future wrapping the id of the newly requested endpoint.
        """
        endpoint_api = global_state.client.online_serving_endpoint_api
        request = OnlineServingEndpointRequest(
            model_training_job_id='FOUNDATION_MODEL',
            predict_options=OnlinePredictionOptions(),
        )
        # TODO(blaz): fix when we allow multiple graphs
        endpoint_id = endpoint_api.create(
            request,
            graph_name=DEFAULT_GRAPH_NAME,
            use_ge=False,
        )
        return OnlineServingEndpointFuture(endpoint_id)