[docs]classForecastVisualizer:r"""A tool for quickly visualizing forecast results from the holdout dataframe of a forecasting job. .. code-block:: python import kumoai # Retrieve job results from a training training job. Note # that the job ID passed here must be in a completed state: job_result = kumoai.TrainingJob("trainingjob-...").result() # Read the holdout table as a Pandas DataFrame: holdout_df = job_result.holdout_df() # Pass holdout table to ForecastVisualizer and visualize results holdout_forecast = kumoai.utils.ForecastVisualizer(holdout_df) holdout_forecast.visualize() """
[docs]def__init__(self,holdout_df:pd.DataFrame)->None:# Sort the holdout dataframe and extract unique entities:self.forecast=holdout_df.sort_values(['ENTITY','TIMESTAMP'])self.entities=holdout_df['ENTITY'].unique().tolist()self.fig=self._initialize_subplot()self.buttons:List[Dict]=[]self.plot_config={'target':{'color':'blue','name':'TARGET'},'prediction':{'color':'red','name':'TARGET_PRED'},'residuals':{'color':'green','name':'Residuals'},'residuals_time':{'color':'orange','name':'Residuals Over Time'}}
@staticmethoddef_initialize_subplot()->go.Figure:r"""Initializes the subplot structure with three rows: Row 1: Line plot of actual forecast vs predicted Row 2: Line plot of residuals overtime Row 3: Histogram distribution of residuals """returnmake_subplots(rows=3,cols=1,specs=[[{"type":"scatter"}],[{"type":"scatter"}],[{"type":"xy"}]],row_heights=[0.5,0.25,0.25],subplot_titles=('Forecast vs Actual','Residuals Over Time','Residuals Distribution'),)def_create_time_series_trace(self,data:pd.Series,entity:str,trace_type:str,visibility:bool,)->go.Scatter:r"""Create a time series trace for either target or prediction"""config=self.plot_config[trace_type]returngo.Scatter(x=data["TIMESTAMP"],y=data[config['name']],name=f"{entity} - {config['name']}",mode="lines",line=dict(color=config['color']),visible=visibility,opacity=0.75,)def_create_residuals_time_trace(self,data:pd.Series,entity:str,visibility:bool,)->go.Scatter:r"""Create a time series trace for residuals over time"""residuals=data["TARGET"]-data["TARGET_PRED"]returngo.Scatter(x=data["TIMESTAMP"],y=residuals,name=f"{entity} - Residuals Over Time",mode="lines+markers",line=dict(color=self.plot_config['residuals_time']['color']),visible=visibility,opacity=0.75,)def_create_residuals_hist_trace(self,data:pd.Series,entity:str,visibility:bool,)->go.Histogram:r"""Create a histogram trace for residuals distribution."""residuals=data["TARGET"]-data["TARGET_PRED"]returngo.Histogram(x=residuals,name=f"{entity} - Residuals Distribution",marker=dict(color=self.plot_config['residuals']['color']),visible=visibility,opacity=0.75,nbinsx=30,)def_create_button(self,index:int,entity:str)->None:r"""Create visibility toggle button for an entity."""# target, prediction, residuals time, and residuals hist:num_traces_per_entity=4total_traces=len(self.entities)*num_traces_per_entitybutton=dict(label=entity,method="update",args=[{"visible":[False]*total_traces}])# Set visibility for the entity's traces:base_index=index*num_traces_per_entityforiinrange(num_traces_per_entity):button["args"][0]["visible"][base_index+i]=True# type: ignoreself.buttons.append(button)def_create_traces(self)->None:"""Create all traces for the visualization."""fori,entityinenumerate(self.entities):entity_data=self.forecast.loc[self.forecast.ENTITY==entity]# First entity's traces are visible by default:visibility=(i==0)# Create tracestrace_target=self._create_time_series_trace(entity_data,entity,'target',visibility)trace_pred=self._create_time_series_trace(entity_data,entity,'prediction',visibility)trace_residuals_time=self._create_residuals_time_trace(entity_data,entity,visibility)trace_residuals_hist=self._create_residuals_hist_trace(entity_data,entity,visibility)# Add traces to appropriate subplotsself.fig.add_trace(trace_target,row=1,col=1)self.fig.add_trace(trace_pred,row=1,col=1)self.fig.add_trace(trace_residuals_time,row=2,col=1)self.fig.add_trace(trace_residuals_hist,row=3,col=1)self._create_button(i,entity)def_update_layout(self)->None:r"""Update the figure layout with all necessary configurations."""self.fig.update_layout(updatemenus=[dict(active=0,buttons=self.buttons,direction="down",pad={"r":10,"t":10},showactive=True,x=1,xanchor="left",y=1.07,yanchor="top")],title="Forecast Results by Department",height=1000,# Increased height to accommodate third plotwidth=1300,showlegend=True,hovermode='x unified')# Update axis labels and add zero reference line for residualsself.fig.update_xaxes(title_text="Timestamp",row=1,col=1)self.fig.update_xaxes(title_text="Timestamp",row=2,col=1)self.fig.update_xaxes(title_text="Residual Value",row=3,col=1)self.fig.update_yaxes(title_text="Patient Volume",row=1,col=1)self.fig.update_yaxes(title_text="Residual Value",row=2,col=1)self.fig.update_yaxes(title_text="Frequency",row=3,col=1)# Add zero reference line for residuals time seriesself.fig.add_hline(y=0,line_dash="dash",line_color="gray",row=2,col=1,)
[docs]defvisualize(self)->None:r"""Generate and display the complete visualization."""self._create_traces()self._update_layout()self.fig.show()