from dataclasses import field, fields
from typing import Dict, List, Union
from pydantic import ConfigDict
from pydantic import Field as PydanticField
from pydantic.dataclasses import dataclass
from kumoapi.model_plan import (
ColumnProcessingPlan,
Metadata,
MissingType,
OptimizationPlan,
PlanMixin,
RunMode,
TrainingJobPlan,
_add_indent,
_SerializableColumnProcessingPlan,
_SerializableOptimizationPlan,
_SerializableTrainingJobPlan,
compat_conlist,
compat_field,
confloat,
conint,
)
from kumoapi.pquery import QueryType
from kumoapi.task import TaskType
from kumoapi.typing import WITH_PYDANTIC_V2, TimeUnit
@dataclass(config={'validate_assignment': True})
class TimeOffset:
    r"""A time offset described by a magnitude and the unit it is measured
    in.

    :ivar value: (``int``) The magnitude of the time offset. Declared with
        a lower bound of ``0`` (``ge=0``) and has no default.
    :ivar unit: (``TimeUnit``) The unit of time that ``value`` is
        expressed in
        (*default:* ``TimeUnit.DAYS``).
    """
    value: int = PydanticField(ge=0)
    unit: TimeUnit = compat_field(
        default=TimeUnit.DAYS,
        metadata=Metadata(),
    )
[docs]
@dataclass(config={'validate_assignment': True})
class DistillationPlan:
    r"""Settings that control which features and interactions are used to
    train the online serving model.

    :ivar embedding_keys: (``list[str]``) Key column(s) of the entity table
        for which embeddings are extracted from the deep model during
        distillation.
        A primary key extracts embeddings from the entity table itself,
        whereas foreign key(s) extract them from the corresponding 1-hop
        neighbor table(s).
        The field is declared with ``min_length=1``.
    :ivar max_embedding_offset: (``TimeOffset``) Maximum staleness of deep
        model embeddings relative to the anchor time, i.e. the upper bound
        on the offset between the embedding seed time and the anchor time.
    :ivar min_embedding_offset: (``TimeOffset``) Minimum staleness of deep
        model embeddings relative to the anchor time, i.e. the lower bound
        on the offset between the embedding seed time and the anchor time.
        Models the latency between the generation of an embedding and its
        availability at serving.
    :ivar real_time_interactions: (``dict[str, int]``) Maps each real-time
        interaction key path to the maximum number of recent interactions
        to incorporate at inference time.
        Entity-level predictions use the format
        ``'entityTable.pkeyCol->interactionTable.fkeyCol'``.
        Fact-level predictions use the format
        ``'factTable.fkeyCol->entityTable.pkeyCol->interactionTable.fkeyCol'``.
        Examples: ``{'users.id->orders.user_id': 32}``,
        ``{'orders.user_id->users.id->views.user_id': 16}``.
        (*default:* ``{}``).
    :ivar real_time_offset: (``TimeOffset``) Minimum offset between the
        anchor time and the most recent real-time interaction available at
        serving.
        Models the end-to-end ingestion latency of the real-time data
        pipeline: interactions that arrive within this window of the anchor
        time are excluded.
    """
    # Unless stated otherwise, an unspecified field holds `MissingType.VALUE`.
    embedding_keys: Union[List[str], MissingType] = compat_field(
        default_factory=lambda: MissingType.VALUE,
        metadata=Metadata(),
        min_length=1,
    )

    max_embedding_offset: Union[TimeOffset, MissingType] = compat_field(
        default=MissingType.VALUE,
        metadata=Metadata(),
    )

    min_embedding_offset: Union[TimeOffset, MissingType] = compat_field(
        default=MissingType.VALUE,
        metadata=Metadata(),
    )

    # The only field with a concrete default: a fresh empty dict per instance.
    real_time_interactions: Dict[str, int] = compat_field(
        default_factory=dict,
        metadata=Metadata(),
    )

    real_time_offset: Union[TimeOffset, MissingType] = compat_field(
        default=MissingType.VALUE,
        metadata=Metadata(),
    )
@dataclass(config={'validate_assignment': True}, repr=False)
class DistillationModelArchitecturePlan(PlanMixin):
    r"""Architecture search space of a distilled model.

    Every attribute is a non-empty list of candidate values and is marked
    as tunable in its field metadata.

    :ivar channels: (``list[int]``) Candidate hidden feature
        dimensionalities for the online serving transformer model.
        Each entry is constrained to ``1 <= value <= 512``
        (*default:* ``[64, 128]``).
    :ivar num_layers: (``list[int]``) Candidate numbers of transformer
        layers. Each entry is constrained to ``1 <= value <= 12``
        (*default:* ``[4, 6]``).
    :ivar num_heads: (``list[int]``) Candidate numbers of attention heads.
        Each entry is constrained to ``1 <= value <= 32``
        (*default:* ``[8]``).
    :ivar emb_dropout: (``list[float]``) Candidate embedding dropout rates,
        i.e. the probability of zeroing out deep model embeddings during
        training. Each entry is constrained to ``0.0 <= value < 1.0``
        (*default:* ``[0.0]``).
    :ivar dropout: (``list[float]``) Candidate dropout rates.
        Each entry is constrained to ``0.0 <= value < 1.0``
        (*default:* ``[0.2]``).
    """
    # At declaration level every field falls back to `MissingType.VALUE`
    # when it is not supplied.
    channels: Union[
        compat_conlist(conint(ge=1, le=512), min_length=1),  # type: ignore
        MissingType,
    ] = compat_field(
        default_factory=lambda: MissingType.VALUE,
        metadata=Metadata(tunable=True),
    )

    num_layers: Union[
        compat_conlist(conint(ge=1, le=12), min_length=1),  # type: ignore
        MissingType,
    ] = compat_field(
        default_factory=lambda: MissingType.VALUE,
        metadata=Metadata(tunable=True),
    )

    num_heads: Union[
        compat_conlist(conint(ge=1, le=32), min_length=1),  # type: ignore
        MissingType,
    ] = compat_field(
        default_factory=lambda: MissingType.VALUE,
        metadata=Metadata(tunable=True),
    )

    emb_dropout: Union[
        compat_conlist(  # type: ignore
            confloat(ge=0.0, lt=1.0), min_length=1),
        MissingType,
    ] = compat_field(
        default_factory=lambda: MissingType.VALUE,
        metadata=Metadata(tunable=True),
    )

    dropout: Union[
        compat_conlist(  # type: ignore
            confloat(ge=0.0, lt=1.0), min_length=1),
        MissingType,
    ] = compat_field(
        default_factory=lambda: MissingType.VALUE,
        metadata=Metadata(tunable=True),
    )
# Annotation aliases for the two distillation plan classes. When
# `WITH_PYDANTIC_V2` is true each class is wrapped in `SerializeAsAny`;
# otherwise the alias is simply the class itself.
if WITH_PYDANTIC_V2:
    # Imported inside the branch so the name is only needed when
    # `WITH_PYDANTIC_V2` is true (presumably because it does not exist in
    # pydantic v1 -- TODO confirm).
    from pydantic import SerializeAsAny
    _SerializableDistillationPlan = SerializeAsAny[DistillationPlan]
    _SerializableDistillationModelArchitecturePlan = SerializeAsAny[
        DistillationModelArchitecturePlan]
else:
    # No wrapping: the aliases refer to the plain classes.
    _SerializableDistillationPlan = DistillationPlan
    _SerializableDistillationModelArchitecturePlan = (
        DistillationModelArchitecturePlan)
[docs]
@dataclass(config=ConfigDict(validate_assignment=True), repr=False)
class DistilledModelPlan:
    r"""The full specification of a Kumo distilled model plan.

    It bundles five sub-plans, each of which is default-constructed when
    not supplied: a
    :class:`~kumoapi.model_plan.TrainingJobPlan`,
    :class:`~kumoapi.model_plan.ColumnProcessingPlan`,
    :class:`~kumoapi.model_plan.OptimizationPlan`,
    :class:`~kumoapi.distilled_model_plan.DistillationPlan`,
    and a :class:`~kumoapi.distilled_model_plan\
    .DistillationModelArchitecturePlan`.
    """
    training_job: _SerializableTrainingJobPlan = field(
        default_factory=TrainingJobPlan,
    )

    # Training table features can be keyed as `TRAIN_TABLE.{column_name}`
    # inside the column_processing plan.
    column_processing: _SerializableColumnProcessingPlan = field(
        default_factory=ColumnProcessingPlan,
    )

    optimization: _SerializableOptimizationPlan = field(
        default_factory=OptimizationPlan,
    )

    distillation: _SerializableDistillationPlan = field(
        default_factory=DistillationPlan,
    )

    model_architecture: _SerializableDistillationModelArchitecturePlan = field(
        default_factory=DistillationModelArchitecturePlan,
    )

    def __repr__(self) -> str:
        r"""Return a multi-line representation with one ``name=value,``
        entry per dataclass field, indented by two spaces and wrapped in
        ``ClassName(...)``.
        """
        entries = (f'{spec.name}={getattr(self, spec.name)},'
                   for spec in fields(self))
        body = _add_indent('\n'.join(entries), num_spaces=2)
        cls_name = self.__class__.__name__
        return f'{cls_name}(\n{body}\n)'
@dataclass
class SuggestDistilledModelPlanRequest:
    r"""A request to suggest a default distilled model plan based on
    ``query_string``, ``graph_id``, ``run_mode``, and ``base_model_id``.

    :ivar query_string: (``str``) The query string the suggestion is based
        on.
    :ivar graph_id: (``str``) The graph identifier the suggestion is based
        on.
    :ivar run_mode: (``RunMode``) The run mode the suggestion is based on.
    :ivar base_model_id: (``str``) The base model identifier the suggestion
        is based on.
    :ivar has_train_table_weight_col: (``bool``) Optional flag
        (*default:* ``False``).
        NOTE(review): presumably signals that the training table has a
        weight column -- confirm against the caller.
    """
    query_string: str
    graph_id: str
    run_mode: RunMode
    base_model_id: str
    has_train_table_weight_col: bool = False
@dataclass
class DefaultDistilledModelPlanInfo:
    r"""A response containing the suggested distilled model plan and
    associated metadata.

    :ivar distilled_model_plan: (``DistilledModelPlan``) The suggested
        distilled model plan.
    :ivar task_type: (``TaskType``) The task type returned alongside the
        plan.
    :ivar query_type: (``QueryType``) The query type returned alongside the
        plan.
    :ivar has_train_table_weight_col: (``bool``) Optional flag
        (*default:* ``False``).
        NOTE(review): presumably signals that the training table has a
        weight column -- confirm against the producer of this response.
    """
    distilled_model_plan: DistilledModelPlan
    task_type: TaskType
    query_type: QueryType
    has_train_table_weight_col: bool = False