import copy
import dataclasses
import datetime
import types
from dataclasses import field, fields, make_dataclass
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from pydantic import (
Field,
confloat,
conint,
conlist,
root_validator,
validator,
)
from pydantic.dataclasses import dataclass
from kumoapi.common import StrEnum
from kumoapi.encoder import EncoderType
from kumoapi.pquery import QueryType
from kumoapi.task import TaskType
class RunMode(StrEnum):
    r"""Defines the run mode for AutoML. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/whats-the-recommended-way-to-cut-down-on-model-training-time>`_
    for more information."""  # noqa
    #: Speeds up the search process—typically about 4x faster than
    #: using the normal mode.
    FAST = 'fast'
    #: Default value.
    NORMAL = 'normal'
    #: Typically takes 4x the time used by the normal mode.
    BEST = 'best'
    #: Debug mode.
    DEBUG = 'debug'
class RHSEmbeddingMode(StrEnum):
    r"""Specifies how to incorporate shallow RHS representations in link
    prediction tasks."""
    # Trainable look-up embeddings (transductive):
    LOOKUP = 'lookup'
    # Shallow RHS input features only (inductive):
    FEATURE = 'feature'
    # Shallow/single layer RHS input features (inductive):
    SHALLOW_FEATURE = 'shallow_feature'
    # Fusion of look-up embeddings and shallow RHS input features
    # (transductive):
    FUSION = 'fusion'

    @property
    def use_rhs_lookup(self) -> bool:
        r"""Whether this mode relies on look-up embeddings."""
        return self in (RHSEmbeddingMode.LOOKUP, RHSEmbeddingMode.FUSION)

    @property
    def use_rhs_feature(self) -> bool:
        r"""Whether this mode relies on shallow RHS input features."""
        return self in (
            RHSEmbeddingMode.FEATURE,
            RHSEmbeddingMode.SHALLOW_FEATURE,
            RHSEmbeddingMode.FUSION,
        )

    @property
    def only_use_rhs_feature(self) -> bool:
        r"""Whether this mode relies on shallow RHS input features without
        any look-up embeddings."""
        return self.use_rhs_feature and not self.use_rhs_lookup
class WeightMode(StrEnum):
    r"""Specifies how to deal with imbalanced datasets or training tables that
    contain a weight column."""
    # Sample training examples with replacement:
    SAMPLE = 'sample'
    # Weight training examples in the loss function:
    WEIGHTED_LOSS = 'weighted_loss'
    # Sample training examples with replacement, but re-weigh them in the loss
    # function according to the inverse of the given weight:
    MIX = 'mix'

    @property
    def use_sampling(self) -> bool:
        r"""Whether this mode samples training examples with replacement."""
        return self in (WeightMode.SAMPLE, WeightMode.MIX)

    @property
    def use_weighted_loss(self) -> bool:
        r"""Whether this mode weighs training examples in the loss."""
        return self in (WeightMode.WEIGHTED_LOSS, WeightMode.MIX)
class LinkPredOutputType(StrEnum):
    r"""The output type of a link prediction model."""
    RANKING = 'ranking'
    EMBEDDING = 'embedding'

    @classmethod
    def _missing_(cls, value: str):
        # Map legacy spellings onto the canonical members to ensure backward
        # compatibility; unknown values fall through to `None` so that the
        # enum machinery raises its usual `ValueError`.
        if not isinstance(value, str):
            value = value.value
        aliases = {
            'default': LinkPredOutputType.RANKING,
            'ranking': LinkPredOutputType.RANKING,
            'link_prediction_ranking': LinkPredOutputType.RANKING,
            'embedding': LinkPredOutputType.EMBEDDING,
            'link_prediction_embedding': LinkPredOutputType.EMBEDDING,
        }
        return aliases.get(value.lower())
class LossType(StrEnum):
    r"""Loss functions available for model optimization (see
    :class:`OptimizationPlan`)."""
    BINARY_CROSS_ENTROPY = 'binary_cross_entropy'
    CROSS_ENTROPY = 'cross_entropy'
    # Configurable via `FocalLossConfig`:
    FOCAL_LOSS = 'focal'
    MAE = 'mae'
    MSE = 'mse'
    # Configurable via `HuberLossConfig`:
    HUBER = 'huber'
    NORMAL_DISTRIBUTION = 'normal_distribution'
    NEGATIVE_BINOMIAL = 'negative_binomial_distribution'
    LOG_NORMAL_DISTRIBUTION = 'log_normal_distribution'
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class FocalLossConfig:
    r"""Configuration of the focal loss."""
    name: Literal['focal']
    # Weighting factor to balance positive vs. negative examples.
    alpha: float = Field(default=0.25, gt=0, lt=1)
    # Balance easy vs. hard examples.
    gamma: float = Field(default=2.0, ge=1)

    def __repr__(self) -> str:
        return 'FocalLoss(alpha={}, gamma={})'.format(self.alpha, self.gamma)
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class HuberLossConfig:
    r"""Configuration of the Huber loss."""
    name: Literal['huber']
    # The threshold at which to change between delta-scaled L1 and L2 loss.
    delta: float = Field(default=1.0, gt=0)

    def __repr__(self) -> str:
        return 'HuberLoss(delta={})'.format(self.delta)
class IntervalType(StrEnum):
    r"""The unit at which a learning rate scheduler is advanced (see
    :class:`LRSchedulerConfig`)."""
    STEP = 'step'
    EPOCH = 'epoch'
class AggregationType(StrEnum):
    r"""Aggregation operators available in the Graph Neural Network
    aggregation process (see :class:`ModelArchitecturePlan`)."""
    SUM = 'sum'
    MEAN = 'mean'
    MIN = 'min'
    MAX = 'max'
    STD = 'std'
    VAR = 'var'
class ActivationType(StrEnum):
    r"""Activation functions available in the Graph Neural Network model (see
    :class:`ModelArchitecturePlan`)."""
    RELU = 'relu'
    LEAKY_RELU = 'leaky_relu'
    ELU = 'elu'
    GELU = 'gelu'
class NormalizationType(StrEnum):
    r"""Normalization layers available in the Graph Neural Network model (see
    :class:`ModelArchitecturePlan`)."""
    LAYER_NORM = 'layer_norm'
    BATCH_NORM = 'batch_norm'
class PastEncoderType(StrEnum):
    r"""Encoders for auto-regressive labels in forecasting tasks (see
    :class:`ModelArchitecturePlan`)."""
    DECOMPOSED = 'decomposed'
    NORMALIZED = 'normalized'
    MLP = 'mlp'
    TRANSFORMER = 'transformer'
class DistanceMeasureType(StrEnum):
    r"""Distance measures between node embeddings for link prediction (see
    :class:`ModelArchitecturePlan`)."""
    DOT_PRODUCT = 'dot_product'
    COSINE = 'cosine'
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class EarlyStoppingConfig:
    r"""Configuration of an early stopping strategy for model optimization."""
    # Minimum required improvement (must be >= 0):
    min_delta: float = Field(ge=0)
    # Number of evaluations to wait for an improvement (must be > 0):
    patience: int = Field(gt=0)

    def __repr__(self) -> str:
        return 'EarlyStopping(min_delta={}, patience={})'.format(
            self.min_delta, self.patience)
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class LRSchedulerConfig:
    r"""Configuration of a learning rate scheduler for model optimization."""
    # The name of the scheduler:
    name: str
    # Whether the scheduler advances per step or per epoch:
    interval: IntervalType
    # Additional keyword arguments passed to the scheduler:
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self) -> str:
        # Compact single-line representation when no kwargs are present:
        if not self.kwargs:
            return f'LRScheduler(name={self.name}, interval={self.interval})'
        kwargs_repr = '{' + ', '.join(
            f'{key}={value}' for key, value in self.kwargs.items()) + '}'
        return (f'LRScheduler(\n'
                f'  name={self.name},\n'
                f'  interval={self.interval},\n'
                f'  kwargs={kwargs_repr},\n'
                f')')
class MissingType(StrEnum):
    r"""Sentinel (``'???'``) marking a model plan option as unset."""
    VALUE = '???'
class InferredType(StrEnum):
    r"""Sentinel (``'inferred'``) marking a model plan option as inferred."""
    VALUE = 'inferred'
@dataclass(
    config=dict(  # type: ignore
        validate_assignment=True,
        extra='allow',
    ),
    repr=False)
class HopConfig:
    r"""Neighbor sampling configuration for a single hop.

    Holds a ``default`` neighbor count plus arbitrary extra attributes, one
    per edge, keyed by ``'source_key->destination_key'`` strings (enabled via
    pydantic's ``extra='allow'``). Edge values are integers in ``[-1, 512]``
    or the literal ``'inferred'``.
    """
    # Fallback neighbor count applied to any edge without an override:
    default: conint(ge=-1, le=128)  # type: ignore

    def __getitem__(self, key: str) -> Union[int, InferredType]:
        r"""Returns the neighbor count registered for edge ``key``."""
        return getattr(self, key)

    def __setitem__(
        self,
        key: str,
        value: Union[int, str, InferredType],
    ) -> None:
        r"""Registers the neighbor count ``value`` for edge ``key``."""
        # Normalize the plain string 'inferred' to its enum representation:
        if isinstance(value, str):
            assert value == 'inferred'
            value = InferredType.VALUE
        setattr(self, key, value)

    @property
    def __pydantic_extra__(self) -> Dict[str, Union[int, InferredType]]:
        r"""Returns all per-edge overrides (everything but ``default``)."""
        # pydantic<2.0 stores extra fields directly in `__dict__`; strip out
        # the declared field and pydantic's internal bookkeeping flag:
        extra = copy.copy(self.__dict__)
        extra.pop('default')
        extra.pop('__pydantic_initialised__')
        return extra

    @root_validator()
    def validate_extra(
        cls,
        values: Dict[str, Any],
    ) -> Dict[str, Union[int, InferredType]]:
        r"""Validates that every extra field denotes a valid edge whose value
        is an integer in ``[-1, 512]`` or the literal ``'inferred'``."""
        for key, value in values.items():
            if key == 'default':
                continue
            if '->' not in key:
                raise ValueError(f"'{key}' is not a valid edge definition. "
                                 f"Ensure that the edge points from a source "
                                 f"key to a destination key via "
                                 f"'source_key->destination_key' syntax")
            if value == 'inferred':
                values[key] = InferredType.VALUE
                continue
            if not isinstance(value, int):
                raise ValueError(f"Value of '{key}' is not a valid integer "
                                 f"(got {value})")
            if isinstance(value, int) and value < -1:
                raise ValueError(f"Ensure the value of '{key}' is greater "
                                 f"than or equal to -1 (got {value})")
            if isinstance(value, int) and value > 512:
                raise ValueError(f"Ensure the value of '{key}' is less than "
                                 f"or equal to 512 (got {value})")
        return values

    def __repr__(self) -> str:
        extra_repr = ', '.join([
            f'{key}={value}' for key, value in self.__pydantic_extra__.items()
        ])
        if len(extra_repr) > 0:
            extra_repr = ', ' + extra_repr
        return f'{self.__class__.__name__}(default={self.default}{extra_repr})'
# NOTE We need to monkey-patch `dataclasses.asdict()` in order to correctly
# serialize pydantic extra fields in pydantic<2.0.
_asdict_inner_orig = dataclasses._asdict_inner  # type: ignore


def _asdict_inner(obj, dict_factory):
    # Serialize `HopConfig` by hand so that its dynamic per-edge extras are
    # included; defer to the original implementation for everything else.
    if isinstance(obj, HopConfig):
        return {**dict(default=obj.default), **obj.__pydantic_extra__}
    return _asdict_inner_orig(obj, dict_factory)


dataclasses._asdict_inner = _asdict_inner  # type: ignore
# Maximum number of hops supported in a `NumNeighborsConfig`:
MAX_NUM_HOPS = 6
def _to_dict(self) -> Dict[int, HopConfig]:
    r"""Collects all defined hops as a mapping from zero-based hop index to
    its :class:`HopConfig`."""
    out: Dict[int, HopConfig] = {}
    for f in fields(self):
        hop = getattr(self, f.name)
        if hop is not None:
            # Field names follow the 'hop{i}' pattern (1-based):
            out[int(f.name[len('hop'):]) - 1] = hop
    return out
def _validate_consecutive_hops(self) -> None:
    r"""Raises a :class:`ValueError` if the defined hops do not form a
    gap-free sequence starting at hop 1."""
    hop_ids = list(self.to_dict())
    if hop_ids != list(range(len(hop_ids))):
        raise ValueError(f"Found non-consecutive hop definition "
                         f"{[hop_id + 1 for hop_id in hop_ids]}.")
def _num_hops(self) -> int:
    r"""Returns the number of defined hops, ensuring they are consecutive."""
    self._validate_consecutive_hops()
    return len(self.to_dict())
def _repr(self) -> str:
    r"""Returns a multi-line representation of the hop configuration, e.g.,
    ``NumNeighbors(hop1=..., hop2=..., ...)``."""
    self._validate_consecutive_hops()
    hops_repr = []
    for i in range(1, self.num_hops() + 1):
        hop = getattr(self, f'hop{i}')
        hop_dict = {**dict(default=hop.default), **hop.__pydantic_extra__}
        if len(hop_dict) == 1:
            # Only the default is set; show it inline:
            hop_repr = str(hop.default)
        else:
            # Per-edge overrides exist; show one entry per line:
            hop_info = [f'{key}={value}' for key, value in hop_dict.items()]
            hop_info = [' ' * 2 + x for x in hop_info]
            hop_repr = '{\n' + ',\n'.join(hop_info) + ',\n}'
        hop_repr = f'hop{i}={hop_repr},'
        hops_repr.append(hop_repr)
    if len(hops_repr) == 0:
        return 'NumNeighbors()'
    return 'NumNeighbors(\n' + _add_indent('\n'.join(hops_repr), 2) + '\n)'
# Build the per-hop neighbor sampling configuration dynamically so that it
# exposes one optional `hop{i}` field per supported hop:
_NumNeighborsConfig = make_dataclass(
    '_NumNeighborsConfig',
    fields=[
        (f'hop{i}', Optional[HopConfig], None)  # type: ignore
        for i in range(1, MAX_NUM_HOPS + 1)
    ],
    namespace={
        'to_dict': _to_dict,
        '_validate_consecutive_hops': _validate_consecutive_hops,
        'num_hops': _num_hops,
        '__repr__': _repr,
        '__len__': lambda self: self.num_hops(),
    },
)
# NOTE(review): registering the generated class on the `types` module —
# presumably so the dynamically-created class can be resolved by name
# (e.g. for pickling); confirm before changing.
types._NumNeighborsConfig = _NumNeighborsConfig  # type: ignore
# Wrap the generated dataclass with pydantic validation:
NumNeighborsConfig = dataclass(
    config=dict(validate_assignment=True),  # type: ignore
    repr=False,
)(_NumNeighborsConfig)
@dataclass
class Metadata:
    r"""Metadata attached to every model plan option, controlling whether it
    is tunable, visible, and which task/query types it applies to."""
    tunable: bool = False  # Tunable during AutoML search.
    hidden: bool = False  # Excluded from `PlanMixin.__repr__` output.
    valid_task_types: List[TaskType] = field(  # all by default.
        default_factory=lambda: list(TaskType))
    valid_query_types: List[QueryType] = field(  # all by default.
        default_factory=lambda: list(QueryType))
class PlanMixin:
    r"""Shared functionality for all model plan sections."""
    def items(self) -> Iterable[Tuple[str, Any, Metadata]]:
        r"""Iterates over all attributes of this dataclass."""
        schema = self.__pydantic_model__.schema()  # type: ignore
        properties = schema['properties']
        for key in self.__dataclass_fields__:  # type: ignore
            yield key, getattr(self, key), properties[key]['metadata']

    def is_valid_option(
        self,
        name: str,
        metadata: Metadata,
        task_type: TaskType,
        query_type: QueryType,
    ) -> bool:
        """
        Whether the option is valid, given its task and query type.

        Args:
            name (str): The name of the field to check.
            metadata (Metadata): The metadata associated with the option.
            task_type (TaskType): The task type.
            query_type (QueryType): The query type.
        """
        if task_type not in metadata.valid_task_types:
            return False
        return query_type in metadata.valid_query_types

    def __repr__(self) -> str:
        field_reprs = []
        for key, value, metadata in self.items():
            # Hidden options are omitted from the representation:
            if metadata.hidden:
                continue
            if metadata.tunable and isinstance(value, list):
                # Tunable options hold candidate lists; print one per line:
                parts = [f'{_add_indent(repr(v), 2)},' for v in value]
                value_repr = '[\n' + '\n'.join(parts) + '\n]'
            else:
                value_repr = repr(value)
            field_reprs.append(f'{key}={value_repr},')
        body = _add_indent('\n'.join(field_reprs), num_spaces=2)
        return f'{self.__class__.__name__}(\n{body}\n)'
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class TrainingTableGenerationPlan(PlanMixin):
    r"""Configuration parameters that define the construction of a Kumo
    training table from a predictive query. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/advanced-operations
    #training-table-generation>`_ for more information.

    :ivar split: (``str``) A custom split that is used to generate a training,
        validation, and test set in the training table
        (*default:* ``"inferred"``).
        **Supported Task Types:** All

    :ivar train_start_offset: (``int`` | ``"inferred"``) Defines the numerical
        offset from the most recent entry to use to generate training data
        labels. Unless a custom time unit is specified in the aggregation, this
        value is in days (*default:* ``"inferred"``).
        **Supported Task Types:** Temporal

    :ivar train_end_offset: (``int`` | ``"inferred"``) Defines the numerical
        offset from the most recent entry to not use to generate training data
        labels. Unless a custom time unit is specified in the aggregation, this
        value is in days (*default:* ``"inferred"``).
        **Supported Task Types:** Temporal

    :ivar timeframe_step: (``int`` | ``"inferred"``) Defines the step size of
        generating time intervals for training table generation
        (*default:* ``"inferred"``).
        **Supported Task Types:** Temporal

    :ivar forecast_length: (``int``) Turns a node regression problem into a
        forecasting problem (*default:* ``1``).
        **Supported Task Types:** Temporal Regression

    :ivar lag_timesteps: (``int``) For forecasting problems,
        leverage the auto-regressive labels as inputs. This parameter controls
        the number of previous values that should be considered as
        auto-regressive labels (*default:* ``0``).
        **Supported Task Types:** Temporal Regression

    :ivar year_over_year: (``bool``) For forecasting problems, integrate
        Year-Over-Year features as inputs to give more attention to the data
        from the previous year when making a prediction. (*default:* ``False``)
    """  # noqa

    # General Options =========================================================

    # Respect resolution order by first trying to map strings to `InferredType`
    split: Union[InferredType, str] = Field(
        default=InferredType.VALUE,
        metadata=Metadata(),
    )
    train_start_offset: Union[int, InferredType, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(valid_query_types=[QueryType.TEMPORAL]),
        ge=0,
    )
    train_end_offset: Union[int, InferredType, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(valid_query_types=[QueryType.TEMPORAL]),
        ge=0,
    )
    timeframe_step: Union[int, InferredType, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(valid_query_types=[QueryType.TEMPORAL]),
        ge=1,
    )

    # Forecasting =============================================================

    forecast_length: Optional[Union[int, MissingType]] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(
            valid_task_types=[TaskType.REGRESSION, TaskType.FORECASTING],
            valid_query_types=[QueryType.TEMPORAL],
        ),
        ge=1,
    )
    lag_timesteps: Optional[Union[int, MissingType]] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(
            valid_task_types=[TaskType.REGRESSION, TaskType.FORECASTING],
            valid_query_types=[QueryType.TEMPORAL],
        ),
        ge=0,
    )
    year_over_year: Union[bool, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(
            valid_task_types=[TaskType.REGRESSION, TaskType.FORECASTING],
            valid_query_types=[QueryType.TEMPORAL],
        ),
    )

    # Entity Candidate Generation =============================================

    entity_candidate: Optional[str] = Field(
        default=None,
        metadata=Metadata(
            hidden=True,
            valid_query_types=[QueryType.TEMPORAL],
        ),
    )
    entity_candidate_aggregation: Optional[str] = Field(
        default=None,
        metadata=Metadata(
            hidden=True,
            valid_query_types=[QueryType.TEMPORAL],
        ),
    )

    # Overriding Predictive Queries ===========================================

    task_path_override: Optional[str] = Field(
        default=None,
        metadata=Metadata(hidden=True),
    )
    train_table_path_override: Optional[str] = Field(
        default=None,
        metadata=Metadata(hidden=True),
    )
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class PredictionTableGenerationPlan(PlanMixin):
    r"""Configuration parameters that define the construction of a Kumo
    prediction table from a predictive query.

    :ivar anchor_time: (``int`` | ``"inferred"`` | ``datetime.datetime``) The
        time that a prediction horizon start time of "zero" refers to. If not
        set, will be inferred to be the latest timestamp in the target fact
        table. Note that this value can either be provided as an integer,
        representing the number of nanoseconds from the Unix epoch, or as a
        ``datetime.datetime`` object (*default:* ``"inferred"``).
        **Supported Task Types:** Temporal
    """

    # General Options =========================================================

    anchor_time: Union[int, InferredType, datetime.datetime] = Field(
        default=InferredType.VALUE,
        metadata=Metadata(valid_query_types=[QueryType.TEMPORAL]),
    )

    @validator('anchor_time')
    def is_nanosecond_timestamp(
        cls,
        value: Union[int, InferredType, datetime.datetime],
    ) -> Union[int, InferredType]:
        r"""Normalizes ``datetime`` anchor times to an integer nanosecond
        timestamp."""
        if isinstance(value, (InferredType, int)):
            return value
        # Convert datetime to timestamp in nanoseconds. `datetime.timestamp()`
        # honors the `datetime`'s timezone information (naive datetimes are
        # interpreted as local time), and is portable across platforms,
        # unlike the previously-used non-standard `strftime('%s')` directive:
        return int(value.timestamp() * 1e9)
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class TrainingJobPlan(PlanMixin):
    r"""Configuration parameters that define the general execution of a Kumo
    AutoML search. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/advanced-operations
    #training-job-plan>`_ for more information.

    :ivar num_experiments: (``int``) The number of experiments to run
        (*default:* ``run_mode``-dependent).
        **Supported Task Types:** All

    :ivar metrics: (``list[str]``) The metrics to compute for the run
        (*default:* ``task_type``-dependent).
        **Supported Task Types:** All

    :ivar tune_metric: (``str``) The metric to judge performance on
        (*default:* ``task_type``-dependent).
        **Supported Task Types:** All

    :ivar refit_trainval: (``bool``) Whether to refit the model after training
        on the training and validation splits (*default:* ``True``).
        **Supported Task Types:** All

    :ivar refit_full: (``bool``) Whether to refit the model after training
        on the training, validation and test splits (*default:* ``False``).
        **Supported Task Types:** All
    """

    # General Options =========================================================

    num_experiments: Union[int, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
        ge=1,
    )
    metrics: Union[List[str], MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
    )
    # Respect resolution order by first trying to map strings to `MissingType`.
    tune_metric: Union[MissingType, str] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
    )
    disable_compilation: bool = Field(
        default=True,
        metadata=Metadata(hidden=True),
    )

    # Refitting ===============================================================

    refit_trainval: bool = Field(
        default=True,
        metadata=Metadata(),
    )
    refit_full: bool = Field(
        default=False,
        metadata=Metadata(),
    )

    # Debugging ===============================================================

    disable_explain: bool = Field(
        default=False,
        metadata=Metadata(hidden=True),
    )
    manual_seed: Optional[int] = Field(
        default=None,
        metadata=Metadata(hidden=True),
    )

    # Deprecated Options ======================================================

    enable_baselines: bool = Field(
        default=False,
        metadata=Metadata(hidden=True),
    )

    # =========================================================================
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class ColumnProcessingPlan(PlanMixin):
    r"""Configuration parameters that define how columns are encoded in the
    training and batch prediction pipelines. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/advanced-operations
    #column-processing>`_ for more information.

    :ivar encoder_overrides: (``dict[str, Encoder] | None``) A dictionary of
        encoder overrides, which maps the ``{table_name}.{column name}`` to an
        :class:`~kumoapi.encoder.Encoder` (*default:* ``None``).
        **Supported Task Types:** All
    """

    # General Options =========================================================

    encoder_overrides: Optional[Dict[str, Union[EncoderType, str]]] = Field(
        default=None,
        metadata=Metadata(),
    )
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class NeighborSamplingPlan(PlanMixin):
    r"""Configuration parameters that define how subgraphs are sampled in the
    training and batch prediction pipelines. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/advanced-operations
    #neighbor-sampling>`_ for more information.

    :ivar num_neighbors: (``list[NumNeighborsConfig]``) Determines the number
        of neighbors to sample for each hop when sampling subgraphs for
        training and prediction (*default:* ``run_mode``-dependent).
        **Supported Task Types:** All

    :ivar sample_from_entity_table: (``bool``) Whether to include the entity
        table in sampling (*default:* ``True``).
        **Supported Task Types:** Static
    """

    # General Options =========================================================

    max_target_neighbors_per_entity: Union[
        conlist(  # type: ignore
            Union[conint(ge=-1, le=512), InferredType],
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(hidden=True, tunable=True),
        )
    num_neighbors: Union[
        conlist(NumNeighborsConfig, min_items=1),  # type: ignore
        conlist(  # type: ignore
            conlist(conint(ge=-1, le=128), max_items=6),
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )
    sample_from_entity_table: Union[bool, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(
            valid_task_types=TaskType.get_node_pred_tasks(),
            valid_query_types=[QueryType.STATIC],
        ),
    )

    # =========================================================================

    def is_valid_option(
        self,
        name: str,
        metadata: Metadata,
        task_type: TaskType,
        query_type: QueryType,
    ) -> bool:
        """
        Whether the option is valid, given its task and query type.

        Overrides the base implementation to special-case
        ``max_target_neighbors_per_entity``, which only applies to temporal
        queries or static link prediction tasks.
        """
        if name == 'max_target_neighbors_per_entity':
            return (query_type == QueryType.TEMPORAL
                    or task_type == TaskType.STATIC_LINK_PREDICTION)
        return super().is_valid_option(name, metadata, task_type, query_type)
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class OptimizationPlan(PlanMixin):
    r"""Configuration parameters that define how the model is optimized in
    the training pipeline. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/advanced-operations
    #optimization>`_ for more information.

    :ivar max_epochs: (``int``) The maximum number of epochs to train a model
        for (*default:* ``run_mode``-dependent).
        **Supported Task Types:** All

    :ivar min_steps_per_epoch: (``int``) The minimum number of steps to be
        included in an epoch; one step corresponds to one forward pass of a
        mini-batch (*default:* ``30``).
        **Supported Task Types:** All

    :ivar max_steps_per_epoch: (``int``) The maximum number of steps to be
        included in an epoch; one step corresponds to one forward pass of a
        mini-batch (*default:* ``run_mode``-dependent).
        **Supported Task Types:** All

    :ivar max_val_steps: (``int``) The maximum number of steps to be included
        in a validation pass; one step corresponds to one forward pass of a
        mini-batch (*default:* ``run_mode``-dependent).
        **Supported Task Types:** All

    :ivar max_test_steps: (``int``) The maximum number of steps to be included
        in a test pass; one step corresponds to one forward pass of a
        mini-batch (*default:* ``run_mode``-dependent).
        **Supported Task Types:** All

    :ivar loss: (``list[str]``) The loss type to use in the model optimizer
        (*default:* ``task_type``-dependent).
        **Supported Task Types:** All

    :ivar base_lr: (``list[float]``) The base learning rate (pre-decay) to be
        used in the model optimizer.
        (*default:* ``[1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2]``).
        **Supported Task Types:** All

    :ivar weight_decay: (``list[float]``) A list of potential weight decay
        options in the model optimizer.
        (*default:* ``[0.0, 5e-8, 5e-7, 5e-6]``).
        **Supported Task Types:** All

    :ivar batch_size: (``list[int]``) The number of examples to be included in
        one mini-batch. (*default:* ``[512, 1024]``).
        **Supported Task Types:** All

    :ivar early_stopping: (``list[EarlyStoppingConfig]``) A list of potential
        early stopping strategies
        :class:`~kumoapi.model_plan.EarlyStoppingConfig` for model optimization
        (*default:* ``[{min_delta=0.0, patience=3}]``).
        **Supported Task Types:** All

    :ivar lr_scheduler: (``list[LRSchedulerConfig]``) A list of potential
        learning rate schedulers
        :class:`~kumoapi.model_plan.LRSchedulerConfig` for model optimization
        (*default:* ``[
        {name="cosine_with_warmup_restarts", interval="step"},
        {name="constant_with_warmup", interval="step"},
        {name="linear_with_warmup", interval="step"},
        {name="cosine_with_warmup", interval="step"}]``).
        **Supported Task Types:** All

    :ivar majority_sampling_ratio: (``list[float | None]``) A ratio to specify
        how examples are sampled from the majority class
        (*default:* ``[None]``).
        **Supported Task Types:** Binary Classification

    :ivar weight_mode: (``list["sample" | "weighted_loss" | "mix"]``) If
        ``majority_sampling_ratio`` is given, this option specifies how to
        weigh majority vs. minority classes during training
        (*default:* ``["sample"]``).
        **Supported Task Types:** Binary Classification
    """

    # General Options =========================================================

    max_epochs: Union[int, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
        ge=1,
    )
    min_steps_per_epoch: Union[int, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
        ge=1,
        le=4000,
    )
    max_steps_per_epoch: Union[int, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
        ge=1,
    )
    max_val_steps: Union[int, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
        ge=1,
    )
    max_test_steps: Union[int, MissingType] = Field(
        default=MissingType.VALUE,
        metadata=Metadata(),
        ge=1,
    )
    loss: Union[
        conlist(  # type: ignore
            Union[LossType, FocalLossConfig, HuberLossConfig],
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )
    base_lr: Union[
        conlist(  # type: ignore
            confloat(gt=0.0),
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )
    weight_decay: Union[
        conlist(  # type: ignore
            confloat(ge=0.0),
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )
    batch_size: Union[
        conlist(  # type: ignore
            conint(ge=1, le=2048),
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )
    early_stopping: Union[
        conlist(  # type: ignore
            Optional[EarlyStoppingConfig],
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )
    lr_scheduler: Union[
        conlist(  # type: ignore
            Optional[LRSchedulerConfig],
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(tunable=True),
        )

    # Task-specific Options ===================================================

    majority_sampling_ratio: Union[
        conlist(  # type: ignore
            Optional[confloat(gt=0.0)],
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(
                tunable=True,
                valid_task_types=[TaskType.BINARY_CLASSIFICATION],
            ),
        )
    weight_mode: Union[
        conlist(  # type: ignore
            WeightMode,
            min_items=1,
        ),
        MissingType] = Field(
            default=MissingType.VALUE,
            metadata=Metadata(
                tunable=True,
                valid_task_types=[TaskType.BINARY_CLASSIFICATION],
            ),
        )
[docs]@dataclass(config=dict(validate_assignment=True), repr=False) # type: ignore
class ModelArchitecturePlan(PlanMixin):
r"""Configuration parameters that define how the Kumo graph neural network
is architected. Please see the
`Kumo documentation <https://docs.kumo.ai/docs/advanced-operations
#model-architecture>`_ for more information.
:ivar channels: (``list[int]``) A list of potential dimension of layers in
the Graph Neural Network model
(*default:* ``[64, 128, 256]``).
**Supported Task Types:** All
:ivar num_pre_message_passing_layers: (``list[int]``) A list of potential
number of multi-layer perceptron layers *before* message passing layers
in the Graph Neural Network model
(*default:* ``[0, 1, 2]``).
**Supported Task Types:** All
:ivar num_pre_message_passing_layers: (``list[int]``) A list of potential
number of multi-layer perceptron layers *after* message passing layers
in the Graph Neural Network model
(*default:* ``[1, 2]``).
**Supported Task Types:** All
:ivar aggregation: (``list[list["sum" | "mean" | "min" | "max" | "std"]]``)
A nested list of aggregation operators in the Graph Neural Network
aggregation process
(*default:* ``[
["sum", "mean", "max"],
["sum", "mean", "min", "max", "std"]]``).
**Supported Task Types:** All
:ivar activation: (``list["relu" | "leaky_relu" | "elu" | "gelu"]``) A list
of activation functions to use during AutoML
(*default:* ``["relu", "leaky_relu", "elu", "gelu"]``).
**Supported Task Types:** All
:ivar normalization: (``list[None | "layer_norm" | "batch_norm"]``) The
normalization layer to apply (*default:* ``["layer_norm"]``).
**Supported Task Types:** All
:ivar module: (``"ranking"`` | ``"embedding"``) The link prediction module
to use (*default:* ``["ranking"]``).
**Supported Task Types:** Link Prediction
:ivar handle_new_target_entities: (``bool``) Whether to make link
prediction models be able to handle predictions on new target entities
at batch prediction time (*default:* ``False``).
**Supported Task Types:** Link Prediction
:ivar target_embedding_mode: (``["lookup" | "feature" | "shallow_feature" |
fusion]``) Specifies how target node embeddings are embedded
(*default:* ``["lookup"]``).
**Supported Task Types:** Link Prediction
:ivar output_embedding_dim: (``[int]``) The output embedding dimension for
link prediction models (*default:* ``[32]``).
**Supported Task Types:** Link Prediction
:ivar ranking_embedding_loss_coeff: (``[float]``) The coefficient of the
embedding loss applied to train ranking-based link prediction models
link prediction models (*default:* ``[0.0]``).
**Supported Task Types:** Temporal Link Prediction
:ivar distance_measure: (``["dot_product" | "cosine"]``) Specifies the
distance measure between node embeddings to use in the final link
prediction calculation (*default:* ``["dot_product"]``).
**Supported Task Types:** Link Prediction
:ivar use_seq_id: (``[bool]``) Specifies whether to use postional encodings
of the sequence order of facts as an additional model feature
(*default:* ``[False]``).
**Supported Task Types:** All
:ivar prediction_time_encodings: (``[bool]``) Specifies whether to encode
the absolute prediction time as an additional model feature
(*default:* ``[False]``).
**Supported Task Types:** Temporal Node Prediction
:ivar past_encoder: (``["decomposed" | "normalized" | "mlp" |
"transformer"]``) Specifies how to encode auto-regressive labels if
present (*default:* ``["decomposed"]``).
**Supported Task Types:** Temporal Regression
:ivar handle_new_entities: (``bool``) Whether to make forecasting models
transductive by learning entity-specific heads. This can improve
performance in case entities stay static over time, but will decrease
performance on new entities arising during batch prediction time
(*default:* ``True``).
**Supported Task Types:** Forecasting
"""
# General Options =========================================================
channels: Union[
conlist( # type: ignore
conint(ge=1, le=512),
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
num_pre_message_passing_layers: Union[
conlist( # type: ignore
conint(ge=0, le=4),
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
num_post_message_passing_layers: Union[
conlist( # type: ignore
conint(ge=1, le=8),
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
aggregation: Union[
conlist( # type: ignore
List[AggregationType],
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
activation: Union[
conlist( # type: ignore
ActivationType,
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
normalization: Union[
conlist( # type: ignore
Optional[NormalizationType],
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
# Link Prediction =========================================================
module: Union[LinkPredOutputType, MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(valid_task_types=TaskType.get_link_pred_tasks()),
)
handle_new_target_entities: Union[bool, MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(valid_task_types=TaskType.get_link_pred_tasks()),
)
target_embedding_mode: Union[
conlist( # type: ignore
RHSEmbeddingMode,
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
tunable=True,
valid_task_types=TaskType.get_link_pred_tasks(),
),
)
output_embedding_dim: Union[
conlist( # type: ignore
conint(ge=1, le=256),
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
tunable=True,
valid_task_types=TaskType.get_link_pred_tasks(),
),
)
ranking_embedding_loss_coeff: Union[
conlist( # type: ignore
confloat(ge=0.0),
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
tunable=True,
valid_task_types=[TaskType.TEMPORAL_LINK_PREDICTION],
),
)
distance_measure: Union[
conlist( # type: ignore
DistanceMeasureType,
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
tunable=True,
valid_task_types=TaskType.get_link_pred_tasks(),
),
)
# Private Preview Options =================================================
use_seq_id: Union[
conlist( # type: ignore
bool,
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(tunable=True),
)
prediction_time_encodings: Union[
conlist( # type: ignore
bool,
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
tunable=True,
valid_task_types=[
TaskType.REGRESSION,
TaskType.FORECASTING,
TaskType.BINARY_CLASSIFICATION,
TaskType.MULTICLASS_CLASSIFICATION,
TaskType.MULTILABEL_CLASSIFICATION,
TaskType.MULTILABEL_RANKING,
],
valid_query_types=[QueryType.TEMPORAL],
),
)
past_encoder: Union[
conlist( # type: ignore
Optional[PastEncoderType],
min_items=1,
),
MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
tunable=True,
valid_task_types=[TaskType.REGRESSION, TaskType.FORECASTING],
valid_query_types=[QueryType.TEMPORAL],
),
)
handle_new_entities: Union[bool, MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
valid_task_types=[TaskType.REGRESSION, TaskType.FORECASTING],
valid_query_types=[QueryType.TEMPORAL],
),
)
# Deprecated Options ======================================================
forecast_type: Union[List[str], MissingType] = Field(
default=MissingType.VALUE,
metadata=Metadata(
hidden=True,
valid_task_types=[TaskType.REGRESSION, TaskType.FORECASTING],
valid_query_types=[QueryType.TEMPORAL],
),
)
# =========================================================================
@validator('channels')
def is_even_channel(
    cls,
    values: Union[List[int], MissingType],
) -> Union[List[int], MissingType]:
    r"""Reject any ``channels`` entry that is not an even number."""
    if not isinstance(values, list):
        # Nothing to validate (e.g. the MissingType sentinel).
        return values
    # Find the first offending entry, if any:
    bad = next((channel for channel in values if channel % 2 != 0), None)
    if bad is not None:
        raise ValueError(f"'channels' requires an even number "
                         f"(got {bad})")
    return values
@dataclass(config=dict(validate_assignment=True), repr=False)  # type: ignore
class ModelPlan:
    r"""A complete definition of a Kumo model plan, encompassing a
    :class:`~kumoapi.model_plan.TrainingJobPlan`,
    :class:`~kumoapi.model_plan.ColumnProcessingPlan`,
    :class:`~kumoapi.model_plan.NeighborSamplingPlan`,
    :class:`~kumoapi.model_plan.OptimizationPlan`, and a
    :class:`~kumoapi.model_plan.ModelArchitecturePlan`. Please see the
    `Kumo documentation <https://docs.kumo.ai/docs/advanced-operations>`_
    for more information."""
    #: The training job plan.
    training_job: TrainingJobPlan = field(default_factory=TrainingJobPlan)
    #: The column processing plan.
    column_processing: ColumnProcessingPlan = field(
        default_factory=ColumnProcessingPlan)
    #: The neighbor sampling plan.
    neighbor_sampling: NeighborSamplingPlan = field(
        default_factory=NeighborSamplingPlan)
    #: The model optimization plan.
    optimization: OptimizationPlan = field(default_factory=OptimizationPlan)
    #: The model architecture plan.
    model_architecture: ModelArchitecturePlan = field(
        default_factory=ModelArchitecturePlan)

    def __repr__(self) -> str:
        r"""Return a multi-line representation listing every sub-plan,
        indented two spaces per level."""
        field_repr = '\n'.join(
            [f'{f.name}={getattr(self, f.name)},' for f in fields(self)])
        reprs = _add_indent(field_repr, num_spaces=2)
        return f'{self.__class__.__name__}(\n{reprs}\n)'
@dataclass
class ModelPlanInfo:
    r"""Bundles a :class:`ModelPlan` together with the task type and query
    type it applies to."""
    #: The model plan.
    model_plan: ModelPlan
    #: The task type associated with the plan.
    task_type: TaskType
    #: The query type associated with the plan.
    query_type: QueryType
# =============================================================================
@dataclass
class SuggestModelPlanRequest:
    r"""A request to suggest a :class:`ModelPlan` for a predictive query
    on a graph."""
    #: The predictive query string.
    query_string: str
    #: The identifier of the graph the query runs against.
    graph_id: str
    #: The run mode to suggest a plan for.
    run_mode: RunMode
@dataclass
class SuggestModelPlanResponse:
    r"""A response containing the suggested :class:`ModelPlan`."""
    #: The suggested model plan.
    model_plan: ModelPlan
def _add_indent(text: str, num_spaces: int) -> str:
lines = text.split('\n')
return '\n'.join([' ' * num_spaces + line for line in lines])