Source code for kumoai.utils.forecasting

from typing import Dict, List

import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots


class ForecastVisualizer:
    r"""A tool for quickly visualizing forecast results from the holdout
    dataframe of a forecasting job.

    The figure has three stacked subplots (actual vs. predicted, residuals
    over time, residuals distribution) and a dropdown menu that switches
    between entities.

    .. code-block:: python

        import kumoai

        # Retrieve job results from a training job. Note that the
        # job ID passed here must be in a completed state:
        job_result = kumoai.TrainingJob("trainingjob-...").result()

        # Read the holdout table as a Pandas DataFrame:
        holdout_df = job_result.holdout_df()

        # Pass holdout table to ForecastVisualizer and visualize results
        holdout_forecast = kumoai.utils.ForecastVisualizer(holdout_df)
        holdout_forecast.visualize()

    Args:
        holdout_df: Holdout dataframe. It must contain the columns
            ``ENTITY``, ``TIMESTAMP``, ``TARGET`` and ``TARGET_PRED``.
        title: Title displayed at the top of the figure.
        target_axis_title: Y-axis label of the forecast-vs-actual subplot.

    Raises:
        KeyError: If ``holdout_df`` lacks any of the required columns.
    """

    # Columns read from the holdout dataframe when building the traces:
    _REQUIRED_COLUMNS = ('ENTITY', 'TIMESTAMP', 'TARGET', 'TARGET_PRED')

    # Every entity contributes exactly four traces, added in this order:
    # target, prediction, residuals over time, residuals histogram.
    _TRACES_PER_ENTITY = 4

    def __init__(
        self,
        holdout_df: pd.DataFrame,
        *,
        title: str = "Forecast Results by Department",
        target_axis_title: str = "Patient Volume",
    ) -> None:
        # Fail early (and with a readable message) instead of deep inside
        # `visualize()` when a column is absent:
        missing = [
            column for column in self._REQUIRED_COLUMNS
            if column not in holdout_df.columns
        ]
        if missing:
            raise KeyError(
                f"holdout_df is missing required column(s): {missing}")

        # Sort the holdout dataframe and extract unique entities:
        self.forecast = holdout_df.sort_values(['ENTITY', 'TIMESTAMP'])
        self.entities = holdout_df['ENTITY'].unique().tolist()
        self.title = title
        self.target_axis_title = target_axis_title
        self.fig = self._initialize_subplot()
        self.buttons: List[Dict] = []
        # Set once traces and layout were added to `self.fig`, so repeated
        # `visualize()` calls do not add them a second time:
        self._is_built = False
        self.plot_config = {
            'target': {
                'color': 'blue',
                'name': 'TARGET'
            },
            'prediction': {
                'color': 'red',
                'name': 'TARGET_PRED'
            },
            'residuals': {
                'color': 'green',
                'name': 'Residuals'
            },
            'residuals_time': {
                'color': 'orange',
                'name': 'Residuals Over Time'
            }
        }

    @staticmethod
    def _initialize_subplot() -> "go.Figure":
        r"""Initializes the subplot structure with three rows:

        Row 1: Line plot of actual forecast vs predicted
        Row 2: Line plot of residuals over time
        Row 3: Histogram distribution of residuals
        """
        return make_subplots(
            rows=3,
            cols=1,
            specs=[[{
                "type": "scatter"
            }], [{
                "type": "scatter"
            }], [{
                "type": "xy"
            }]],
            row_heights=[0.5, 0.25, 0.25],
            subplot_titles=('Forecast vs Actual', 'Residuals Over Time',
                            'Residuals Distribution'),
        )

    @staticmethod
    def _residuals(data: pd.DataFrame) -> pd.Series:
        r"""Returns ``TARGET - TARGET_PRED`` for the rows of ``data``."""
        return data["TARGET"] - data["TARGET_PRED"]

    def _create_time_series_trace(
        self,
        data: pd.DataFrame,
        entity: str,
        trace_type: str,
        visibility: bool,
    ) -> "go.Scatter":
        r"""Create a time series trace for either target or prediction.

        Args:
            data: Rows of the holdout dataframe belonging to ``entity``.
            entity: Entity identifier, used as prefix of the trace name.
            trace_type: Key into ``self.plot_config`` (``'target'`` or
                ``'prediction'``) selecting the column and color.
            visibility: Whether the trace is initially visible.
        """
        config = self.plot_config[trace_type]
        return go.Scatter(
            x=data["TIMESTAMP"],
            y=data[config['name']],
            name=f"{entity} - {config['name']}",
            mode="lines",
            line=dict(color=config['color']),
            visible=visibility,
            opacity=0.75,
        )

    def _create_residuals_time_trace(
        self,
        data: pd.DataFrame,
        entity: str,
        visibility: bool,
    ) -> "go.Scatter":
        r"""Create a time series trace for residuals over time.

        Args:
            data: Rows of the holdout dataframe belonging to ``entity``.
            entity: Entity identifier, used as prefix of the trace name.
            visibility: Whether the trace is initially visible.
        """
        return go.Scatter(
            x=data["TIMESTAMP"],
            y=self._residuals(data),
            name=f"{entity} - Residuals Over Time",
            mode="lines+markers",
            line=dict(color=self.plot_config['residuals_time']['color']),
            visible=visibility,
            opacity=0.75,
        )

    def _create_residuals_hist_trace(
        self,
        data: pd.DataFrame,
        entity: str,
        visibility: bool,
    ) -> "go.Histogram":
        r"""Create a histogram trace for residuals distribution.

        Args:
            data: Rows of the holdout dataframe belonging to ``entity``.
            entity: Entity identifier, used as prefix of the trace name.
            visibility: Whether the trace is initially visible.
        """
        return go.Histogram(
            x=self._residuals(data),
            name=f"{entity} - Residuals Distribution",
            marker=dict(color=self.plot_config['residuals']['color']),
            visible=visibility,
            opacity=0.75,
            nbinsx=30,
        )

    def _create_button(self, index: int, entity: str) -> None:
        r"""Create visibility toggle button for an entity.

        The button carries one visibility flag per trace of the figure and
        enables only the traces of the entity at position ``index``.

        Args:
            index: Position of ``entity`` within ``self.entities``.
            entity: Entity identifier, shown (as a string) as button label.
        """
        traces_per_entity = self._TRACES_PER_ENTITY
        visible = [False] * (len(self.entities) * traces_per_entity)

        # Set visibility for the entity's traces:
        base_index = index * traces_per_entity
        for offset in range(traces_per_entity):
            visible[base_index + offset] = True

        self.buttons.append(
            dict(
                # Stringified the same way the trace names format entities:
                label=str(entity),
                method="update",
                args=[{
                    "visible": visible
                }],
            ))

    def _create_traces(self) -> None:
        r"""Create all traces for the visualization.

        Adds the four traces of every entity to ``self.fig`` and appends
        the matching dropdown button to ``self.buttons``.
        """
        for i, entity in enumerate(self.entities):
            entity_data = self.forecast.loc[self.forecast['ENTITY'] == entity]

            # First entity's traces are visible by default:
            visibility = (i == 0)

            # Create traces
            trace_target = self._create_time_series_trace(
                entity_data, entity, 'target', visibility)
            trace_pred = self._create_time_series_trace(
                entity_data, entity, 'prediction', visibility)
            trace_residuals_time = self._create_residuals_time_trace(
                entity_data, entity, visibility)
            trace_residuals_hist = self._create_residuals_hist_trace(
                entity_data, entity, visibility)

            # Add traces to appropriate subplots. The order must stay in
            # sync with the index arithmetic in `_create_button`:
            self.fig.add_trace(trace_target, row=1, col=1)
            self.fig.add_trace(trace_pred, row=1, col=1)
            self.fig.add_trace(trace_residuals_time, row=2, col=1)
            self.fig.add_trace(trace_residuals_hist, row=3, col=1)

            self._create_button(i, entity)

    def _update_layout(self) -> None:
        r"""Update the figure layout with all necessary configurations."""
        self.fig.update_layout(
            updatemenus=[
                dict(active=0, buttons=self.buttons, direction="down", pad={
                    "r": 10,
                    "t": 10
                }, showactive=True, x=1, xanchor="left", y=1.07,
                     yanchor="top")
            ],
            title=self.title,
            height=1000,  # Increased height to accommodate third plot
            width=1300,
            showlegend=True,
            hovermode='x unified')

        # Update axis labels
        self.fig.update_xaxes(title_text="Timestamp", row=1, col=1)
        self.fig.update_xaxes(title_text="Timestamp", row=2, col=1)
        self.fig.update_xaxes(title_text="Residual Value", row=3, col=1)
        self.fig.update_yaxes(title_text=self.target_axis_title, row=1, col=1)
        self.fig.update_yaxes(title_text="Residual Value", row=2, col=1)
        self.fig.update_yaxes(title_text="Frequency", row=3, col=1)

        # Add zero reference line for residuals time series
        self.fig.add_hline(
            y=0,
            line_dash="dash",
            line_color="gray",
            row=2,
            col=1,
        )

    def visualize(self) -> None:
        r"""Generate and display the complete visualization.

        Traces, buttons and layout are added to the figure on the first
        call only; later calls simply display the existing figure again.
        """
        if not self._is_built:
            self._create_traces()
            self._update_layout()
            self._is_built = True
        self.fig.show()