diff --git a/pytorch_forecasting/models/nbeats/_nbeats.py b/pytorch_forecasting/models/nbeats/_nbeats.py
index e87fa649c..dc4c7be1b 100644
--- a/pytorch_forecasting/models/nbeats/_nbeats.py
+++ b/pytorch_forecasting/models/nbeats/_nbeats.py
@@ -52,54 +52,88 @@ def __init__(
         **kwargs,
     ):
         """
-        Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible.
+        Initialize NBeats Model.
 
-        Based on the article
-        `N-BEATS: Neural basis expansion analysis for interpretable time series
-        forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if used as ensemble) outperformed all
-        other methods
-        including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably
-        the most
-        important benchmark for univariate time series forecasting.
+        The model can be initialized in two ways:
+
+        1. Using the :py:meth:`~from_dataset` classmethod
+           (recommended for standard time series forecasting)
+        2. Direct initialization with required parameters (for custom use cases)
 
-        The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform
-        N-BEATS.
+        Based on the article `N-BEATS: Neural basis expansion analysis for
+        interpretable time series forecasting <http://arxiv.org/abs/1905.10437>`_.
+        The network has (if used as an ensemble) outperformed all other methods,
+        including ensembles of traditional statistical methods, in the M4
+        competition.
+
+        The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
+        been shown to consistently outperform N-BEATS.
 
         Args:
-            stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings
-                of length 1 or ‘num_stacks’. Default and recommended value
-                for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”]
-            num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’.
-                Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3]
-            num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length
-                1 or ‘num_stacks’.
-                Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4]
-            width: Widths of the fully connected layers with ReLu activation in the blocks.
-                A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512]
-                Recommended value for interpretable mode: [256, 2048]
-            sharing: Whether the weights are shared with the other blocks per stack.
-                A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False]
-                Recommended value for interpretable mode: [True]
-            expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion
-                coefficient.
-                If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S”
-                (seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep.
-                A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for
-                interpretable mode: [3]
-            prediction_length: Length of the prediction. Also known as 'horizon'.
-            context_length: Number of time units that condition the predictions. Also known as 'lookback period'.
-                Should be between 1-10 times the prediction length.
-            backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
-                A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and
-                forecast lengths). Defaults to 0.0, i.e. no weight.
-            loss: loss to optimize. Defaults to MASE().
-            log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training
-                failures
-            reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
-            logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
-                Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
-            **kwargs: additional arguments to :py:class:`~BaseModel`.
-        """  # noqa: E501
+            stack_types: One of the following values: "generic", "seasonality" or
+                "trend".
+                A list of strings of length 1 or 'num_stacks'.
+                Default and recommended value for generic mode: ["generic"].
+                Recommended value for interpretable mode: ["trend", "seasonality"]
+            num_blocks: The number of blocks per stack. A list of ints of length 1 or
+                'num_stacks'. Default and recommended value for generic mode: [1].
+                Recommended value for interpretable mode: [3]
+            num_block_layers: Number of fully connected layers with ReLU activation
+                per block. A list of ints of length 1 or 'num_stacks'.
+                Default and recommended value for generic mode: [4].
+                Recommended value for interpretable mode: [4]
+            widths: Widths of the fully connected layers with ReLU activation.
+                A list of ints (length = 'num_stacks').
+                Default generic mode: [512]
+                Default interpretable mode: [256, 2048]
+            sharing: Share weights between blocks per stack.
+                A list of bools (length = 'num_stacks').
+                Default generic mode: [False]
+                Default interpretable mode: [True]
+            expansion_coefficient_lengths: Configures each stack type:
+                - "generic": expansion coefficient length
+                - "trend": polynomial degree
+                - "seasonality": minimum period allowed, e.g. 2 for changes
+                  every timestep
+                A list of ints (length = 'num_stacks').
+                Default generic mode: [32]
+                Default interpretable mode: [3]
+            prediction_length: Length of the prediction horizon
+            context_length: Number of time units that condition the predictions
+                (lookback period). Should be 1-10x prediction_length.
+            dropout: Dropout rate (0.0 to 1.0)
+            learning_rate: Initial learning rate
+            log_interval: Batches after which predictions are logged; disabled if
+                0 or negative
+            log_gradient_flow: Whether to log gradient flow; this takes time and
+                should only be done to diagnose training failures
+            log_val_interval: Batches after which validation predictions are logged
+            weight_decay: L2 regularization factor
+            backcast_loss_ratio: Weight of backcast loss relative to forecast loss;
+                1.0 weights both equally, 0.0 (default) uses no backcast loss.
+            loss: PyTorch metric to optimize. Defaults to MASE()
+            reduce_on_plateau_patience: Patience after which learning rate is
+                reduced by a factor of 10
+            logging_metrics: List of metrics logged during training. Defaults to
+                nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
+            **kwargs: Additional arguments to :py:class:`~BaseModel`.
+
+        Example:
+            Direct initialization:
+
+            >>> from pytorch_forecasting.models import NBeats
+            >>> model = NBeats(
+            ...     stack_types=["trend", "seasonality"],
+            ...     num_blocks=[3, 3],
+            ...     num_block_layers=[3, 3],
+            ...     widths=[32, 512],
+            ...     sharing=[True, True],
+            ...     expansion_coefficient_lengths=[3, 7],
+            ...     prediction_length=24,
+            ...     context_length=72,
+            ... )
+
+            Initialization from dataset (recommended):
+
+            >>> from pytorch_forecasting import TimeSeriesDataSet, NBeats
+            >>> dataset = TimeSeriesDataSet(...)
+            >>> model = NBeats.from_dataset(dataset)
+        """
         if expansion_coefficient_lengths is None:
             expansion_coefficient_lengths = [3, 7]
         if sharing is None:
@@ -116,6 +150,32 @@ def __init__(
             logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
         if loss is None:
             loss = MASE()
+
+        # Validate parameters
+        if not isinstance(prediction_length, int) or prediction_length < 1:
+            raise ValueError("prediction_length must be a positive integer")
+        if not isinstance(context_length, int) or context_length < 1:
+            raise ValueError("context_length must be a positive integer")
+        if not all(s in ["generic", "seasonality", "trend"] for s in stack_types):
+            raise ValueError(
+                "stack_types must contain only 'generic', 'seasonality', or 'trend'"
+            )
+
+        # Validate list lengths
+        n_stacks = len(stack_types)
+        for param_name, param_value in [
+            ("num_blocks", num_blocks),
+            ("num_block_layers", num_block_layers),
+            ("widths", widths),
+            ("sharing", sharing),
+            ("expansion_coefficient_lengths", expansion_coefficient_lengths),
+        ]:
+            if len(param_value) != n_stacks:
+                raise ValueError(
+                    f"Length of {param_name} ({len(param_value)}) must match "
+                    f"length of stack_types ({n_stacks})"
+                )
+
         self.save_hyperparameters()
         super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)
 
@@ -232,15 +292,22 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
     @classmethod
     def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
         """
-        Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
+        Create an NBeats model from a TimeSeriesDataSet.
+
+        This is the recommended way to create an NBeats model for standard
+        time series forecasting. For custom use cases where the dataset
+        constraints do not fit, initialize the model directly via the constructor.
 
         Args:
             dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
             **kwargs: additional arguments to be passed to ``__init__`` method.
 
         Returns:
-            NBeats
-        """  # noqa: E501
+            NBeats: initialized model
+
+        Raises:
+            AssertionError: if dataset constraints are not met
+        """
         new_kwargs = {
             "prediction_length": dataset.max_prediction_length,
             "context_length": dataset.max_encoder_length,
@@ -361,17 +428,18 @@ def plot_interpretation(
     ):
         """
         Plot interpretation.
 
-        Plot two pannels: prediction and backcast vs actuals and
-        decomposition of prediction into trend, seasonality and generic forecast.
+        Plot two panels: prediction and backcast vs actuals, and decomposition of
+        the prediction into trend, seasonality and generic forecast.
 
         Args:
             x (Dict[str, torch.Tensor]): network input
             output (Dict[str, torch.Tensor]): network output
             idx (int): index of sample for which to plot the interpretation.
-            ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation.
-                Defaults to None.
-            plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and
-                generic forecast on secondary axis in second panel. Defaults to False.
+            ax (List[matplotlib axes], optional): list of two matplotlib axes onto
+                which to plot the interpretation. Defaults to None.
+            plot_seasonality_and_generic_on_secondary_axis (bool, optional): whether
+                to plot seasonality and generic forecast on a secondary axis in the
+                second panel. Defaults to False.
 
         Returns:
             plt.Figure: matplotlib figure
diff --git a/tests/test_models/test_nbeats.py b/tests/test_models/test_nbeats.py
index c3379fbf1..b6af5edfe 100644
--- a/tests/test_models/test_nbeats.py
+++ b/tests/test_models/test_nbeats.py
@@ -36,7 +36,7 @@ def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
         train_dataloader.dataset,
         learning_rate=0.15,
         log_gradient_flow=True,
-        widths=[4, 4, 4],
+        widths=[4, 4],
         log_interval=1000,
         backcast_loss_ratio=1.0,
     )
@@ -77,7 +77,7 @@ def model(dataloaders_fixed_window_without_covariates):
         dataset,
         learning_rate=0.15,
         log_gradient_flow=True,
-        widths=[4, 4, 4],
+        widths=[4, 4],
         log_interval=1000,
         backcast_loss_ratio=1.0,
     )
@@ -101,3 +101,39 @@ def test_interpretation(model, dataloaders_fixed_window_without_covariates):
         fast_dev_run=True,
     )
     model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)
+
+
+def test_direct_initialization():
+    # Test that the model can be initialized directly without from_dataset
+    net = NBeats(
+        stack_types=["trend", "seasonality"],
+        num_blocks=[3, 3],
+        num_block_layers=[3, 3],
+        widths=[32, 512],
+        sharing=[True, True],
+        expansion_coefficient_lengths=[3, 7],
+        prediction_length=24,
+        context_length=72,
+    )
+    assert len(net.net_blocks) == 6  # 2 stacks * 3 blocks each
+    assert net.hparams.prediction_length == 24
+    assert net.hparams.context_length == 72
+
+    # Test validation of parameters
+    with pytest.raises(ValueError, match="stack_types must contain only"):
+        NBeats(stack_types=["invalid_type"])
+
+    with pytest.raises(ValueError, match="Length of num_blocks"):
+        NBeats(
+            stack_types=["trend", "seasonality"],
+            num_blocks=[3],  # Should be length 2
+            prediction_length=24,
+            context_length=72,
+        )
+
+    with pytest.raises(ValueError, match="prediction_length must be"):
+        NBeats(
+            stack_types=["trend", "seasonality"],
+            prediction_length=0,  # Invalid
+            context_length=72,
+        )
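
Note for reviewers: a minimal sketch of how the validation introduced in this patch surfaces to users, assuming only the constructor behaviour shown above (the variable names and the try/except pattern are illustrative, not part of the patch):

    from pytorch_forecasting.models import NBeats

    # Direct initialization now fails fast with a descriptive ValueError when the
    # per-stack lists do not line up with stack_types.
    try:
        NBeats(
            stack_types=["trend", "seasonality"],
            num_blocks=[3],  # should have length 2 to match stack_types
            prediction_length=24,
            context_length=72,
        )
    except ValueError as err:
        print(err)  # Length of num_blocks (1) must match length of stack_types (2)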