Source code for kumoai.utils.datasets

import os
import tempfile

from kumoai.connector import FileUploadConnector
from kumoai.connector.utils import replace_table
from kumoai.graph import Edge, Graph, Table


def from_relbench(dataset_name: str) -> Graph:
    r"""Creates a Kumo graph from a RelBench dataset.

    This function processes the specified RelBench dataset, uploads its
    tables to the Kumo data plane, and registers them as part of a Kumo
    graph, including inferred table metadata and edges.

    .. note::

        Please note that this method is subject to the limitations for file
        upload in :class:`~kumoai.connector.FileUploadConnector`.

    .. code-block:: python

        import kumoai
        from kumoai.utils.datasets import from_relbench

        # Assume dataset `example_dataset` in the RelBench repository:
        graph = from_relbench(dataset_name="example_dataset")

    Args:
        dataset_name: The name of the RelBench dataset to be processed.

    Returns:
        A :class:`~kumoai.Graph` object containing the tables and edges
        derived from the RelBench dataset.

    Raises:
        RuntimeError: If the ``relbench`` package is not installed. Any
            error raised while downloading, writing, or uploading the
            dataset is not caught here and propagates to the caller.
    """
    # `relbench` is an optional dependency, so it is imported lazily and a
    # missing install is reported with an actionable message.
    try:
        from relbench.datasets import get_dataset
    except ImportError as e:
        raise RuntimeError(
            "Creating a Kumo Graph from a RelBench dataset requires the "
            "'relbench' package to be installed. Please install the package "
            "before proceeding.") from e

    connector = FileUploadConnector(file_type="parquet")
    dataset = get_dataset(dataset_name, download=True)
    db = dataset.get_db()

    tables = {}
    edges = []

    # Stage the parquet files in a private temporary directory so that
    # nothing is left behind in (or overwritten inside) the caller's working
    # directory. The directory is removed once all uploads are done.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for table_key, table in db.table_dict.items():
            # Save the table locally as a parquet file:
            parquet_path = os.path.join(tmp_dir, f"tmp_{table_key}.parquet")
            table.df.to_parquet(parquet_path, index=False)

            # Replace the table on the Kumo data plane:
            replace_table(
                name=table_key,
                path=parquet_path,
                file_type="parquet",
            )

            # Register the table with inferred metadata:
            tables[table_key] = Table.from_source_table(
                source_table=connector[table_key],
                primary_key=table.pkey_col,
                time_column=table.time_col,
            ).infer_metadata()

            # One edge per foreign key, pointing at the table that owns the
            # referenced primary key:
            edges.extend(
                Edge(src_table=table_key, fkey=fkey, dst_table=dst_table)
                for fkey, dst_table in table.fkey_col_to_pkey_table.items())

    # Create and return the graph
    return Graph(
        tables=tables,
        edges=edges,
    )