Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/gluonts/model/trivial/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from gluonts.model.trivial.constant import ConstantPredictor
from gluonts.pydantic import PositiveInt

from scipy.stats import skewnorm


class MeanPredictor(RepresentablePredictor):
"""
Expand Down Expand Up @@ -170,3 +172,27 @@ def train(
)

return ConstantPredictor(samples=samples)


class SkewedMeanPredictor(MeanPredictor):
    """
    A :class:`MeanPredictor` variant whose forecast samples are drawn from a
    skew-normal distribution (``scipy.stats.skewnorm``) rather than taken
    directly from the parent predictor.

    Parameters
    ----------
    prediction_length
        Length of the prediction horizon.
    num_samples
        Number of samples to use to construct :class:`SampleForecast` objects
        for every prediction. Must be greater than 1.
    skewness
        Shape parameter ``a`` passed to ``skewnorm.rvs``. Positive values are
        right skewed, negative values are left skewed.
    """

    # NOTE(review): the other predictors in this package decorate
    # ``__init__`` with ``@validated()``; confirm whether this class should
    # do the same.
    def __init__(self, prediction_length, num_samples=20, skewness=10):
        assert num_samples > 1, "num_samples must be set greater than 1"
        self.skewness = skewness
        super().__init__(
            prediction_length=prediction_length, num_samples=num_samples
        )

    def generate_skew(self, target):
        """
        Draw skew-normal samples centred on the statistics of ``target``.

        Parameters
        ----------
        target
            Forecast produced by the parent predictor. If its ``info``
            mapping provides both ``"mean"`` and ``"std"``, those values are
            used as location and scale. Otherwise the mean and standard
            deviation of ``target.samples`` are used instead, so that a
            forecast without ``info`` no longer fails with a ``TypeError``.

        Returns
        -------
        Array of shape ``self.shape`` drawn with ``skewnorm.rvs``.
        """
        # ``info`` may be missing or ``None``; treat both as "no metadata".
        info = getattr(target, "info", None) or {}
        if "mean" in info and "std" in info:
            mean = info["mean"]
            std = info["std"]
        else:
            # Fallback: estimate location/scale from the parent's samples.
            # presumably these were drawn around the context mean -- verify
            # against MeanPredictor.predict_item.
            mean = target.samples.mean()
            std = target.samples.std()

        return skewnorm.rvs(
            a=self.skewness, loc=mean, scale=std, size=self.shape
        )

    def predict_item(self, item):
        """
        Build a :class:`SampleForecast` for ``item`` whose samples come from
        :meth:`generate_skew` applied to the parent predictor's forecast.
        """
        base_forecast = super().predict_item(item)
        return SampleForecast(
            samples=self.generate_skew(base_forecast),
            start_date=forecast_start(item),
            item_id=item.get(FieldName.ITEM_ID),
        )
159 changes: 159 additions & 0 deletions src/gluonts/model/trivial/median.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from typing import Optional

import numpy as np

from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.util import forecast_start
from gluonts.model.estimator import Estimator
from gluonts.model.forecast import SampleForecast
from gluonts.model.predictor import RepresentablePredictor
from gluonts.model.trivial.constant import ConstantPredictor
from gluonts.pydantic import PositiveInt


class MedianPredictor(RepresentablePredictor):
    """
    A :class:`Predictor` whose samples are Gaussian noise scaled by the
    standard deviation and shifted by the median of the last
    `context_length` elements of the input target (NaNs are ignored).

    Parameters
    ----------
    context_length
        Length of the target context used to condition the predictions.
        When ``None``, the whole target is used.
    prediction_length
        Length of the prediction horizon.
    num_samples
        Number of samples to use to construct :class:`SampleForecast` objects
        for every prediction.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        num_samples: int = 100,
        context_length: Optional[int] = None,
    ) -> None:
        super().__init__(prediction_length=prediction_length)
        self.num_samples = num_samples
        self.context_length = context_length
        # One row per sample, one column per forecast step.
        self.shape = (num_samples, self.prediction_length)

    def predict_item(self, item: DataEntry) -> SampleForecast:
        """
        Forecast a single entry by sampling around the context median.
        """
        history = item["target"]
        if self.context_length is not None:
            history = history[-self.context_length :]

        center = np.nanmedian(history)
        spread = np.nanstd(history)
        noise = np.random.standard_normal(self.shape)

        return SampleForecast(
            samples=center + spread * noise,
            start_date=forecast_start(item),
            item_id=item.get(FieldName.ITEM_ID),
        )


class MovingMedianPredictor(RepresentablePredictor):
    """
    A :class:`Predictor` that forecasts with a moving median over the last
    `context_length` elements of the input target.

    The forecast is built one step at a time: each step takes the median
    (ignoring NaNs) of the current window and appends it to the series, so
    later steps may be computed partly -- or, once `prediction_length`
    exceeds `context_length`, entirely -- from earlier forecasted medians.
    With `prediction_length` = 1 the output is simply the median of the last
    `context_length` observations.

    Parameters
    ----------
    context_length
        Length of the target context used to condition the predictions.
        When ``None``, the whole (growing) series is used as the window.
    prediction_length
        Length of the prediction horizon.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        context_length: Optional[int] = None,
    ) -> None:
        super().__init__(prediction_length=prediction_length)

        assert (
            context_length is None or context_length >= 1
        ), "The value of 'context_length' should be >= 1 or None"

        self.context_length = context_length

    def predict_item(self, item: DataEntry) -> SampleForecast:
        """
        Forecast a single entry; the result holds exactly one sample path.
        """
        series = item["target"].tolist()
        horizon = self.prediction_length
        lookback = self.context_length

        for _ in range(horizon):
            window = series if lookback is None else series[-lookback:]
            series.append(np.nanmedian(window))

        # Wrap the tail in an outer list so the samples array has a leading
        # axis of length one.
        path = series[-horizon:]
        return SampleForecast(
            samples=np.array([path]),
            start_date=forecast_start(item),
            item_id=item.get(FieldName.ITEM_ID),
        )


class MedianEstimator(Estimator):
    """
    An `Estimator` that takes the trailing `prediction_length` observations
    of every series in the training data, computes their median across
    series (ignoring NaNs), and produces a `ConstantPredictor` that always
    predicts those median values.

    Parameters
    ----------
    prediction_length
        Prediction horizon.
    num_samples
        Number of samples to include in the forecasts. Note that the samples
        produced by this predictor will all be identical.
    """

    @validated()
    def __init__(
        self,
        prediction_length: PositiveInt,
        num_samples: PositiveInt,
    ) -> None:
        super().__init__()
        self.num_samples = num_samples
        self.prediction_length = prediction_length

    def train(
        self,
        training_data: Dataset,
        validation_dataset: Optional[Dataset] = None,
    ) -> ConstantPredictor:
        """
        Build the constant predictor from `training_data`.

        `validation_dataset` is accepted but not used.
        """
        horizon = self.prediction_length
        trailing_windows = [
            entry["target"][-horizon:] for entry in training_data
        ]

        # Median over the series axis gives one value per forecast step.
        step_medians = np.nanmedian(np.array(trailing_windows), axis=0)

        # Repeat the same path for every requested sample.
        samples = np.broadcast_to(
            array=step_medians,
            shape=(self.num_samples, horizon),
        )

        return ConstantPredictor(samples=samples)
141 changes: 141 additions & 0 deletions src/gluonts/model/trivial/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from typing import Iterator, Tuple

import numpy as np
import pandas as pd

from gluonts.core.component import validated
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.model.forecast import Forecast
from gluonts.model.predictor import Predictor
from gluonts.model.forecast import SampleForecast
from gluonts.model.predictor import Predictor
from gluonts.dataset.util import period_index
from gluonts.dataset.split import split
from gluonts.dataset.util import forecast_start


def _to_dataframe(input_label: Tuple[DataEntry, DataEntry]) -> pd.DataFrame:
    """
    Stitch a pair of consecutive (in time) data entries into one dataframe.

    The targets of both entries are concatenated along their last axis and
    indexed by periods starting at the first entry's start.
    """
    series_start = input_label[0][FieldName.START]
    stitched = np.concatenate(
        [part[FieldName.TARGET] for part in input_label], axis=-1
    )
    time_index = period_index(
        {FieldName.START: series_start, FieldName.TARGET: stitched}
    )
    # Transposed so that the concatenated (last) axis becomes the row axis.
    return pd.DataFrame(stitched.transpose(), index=time_index)


def make_oracle_predictions(
    dataset: Dataset,
    predictor: Predictor,
    num_samples: int = 100,
) -> Tuple[Iterator[Forecast], Iterator[pd.Series]]:
    """
    !!! CAN ONLY BE USED WITH ORACLE STYLE PREDICTORS !!!

    Oracle predictors are allowed to look at the future values of the target
    at prediction time. This function evaluates such predictors by handing
    them the held-out future values through the ``ground_truth`` keyword of
    ``predictor.predict``.

    Parameters
    ----------
    dataset
        Dataset where the evaluation will happen. The trailing window of
        ``prediction_length`` (plus ``lead_time``, if the predictor defines
        one) steps is split off; the remainder is the prediction input and
        the window itself is passed along as ground truth.
    predictor
        Model used to draw predictions.
    num_samples
        Number of samples to draw on the model when evaluating. Only
        sampling-based models will use this.

    Returns
    -------
    Tuple[Iterator[Forecast], Iterator[pd.Series]]
        A pair of iterators, the first one yielding the forecasts, and the
        second one yielding the corresponding ground truth, each element
        being the dataframe built by ``_to_dataframe`` from the input and
        label parts.
    """
    horizon = predictor.prediction_length
    horizon += getattr(predictor, "lead_time", 0)

    _, template = split(dataset, offset=-horizon)
    instances = template.generate_instances(horizon)

    forecasts = predictor.predict(
        instances.input,
        num_samples=num_samples,
        ground_truth=instances.label,
    )
    truths = map(_to_dataframe, instances)
    return forecasts, truths


class TrueOraclePredictor(Predictor):
    """
    Oracle predictor whose forecast is the ground truth itself: every sample
    path is a copy of the last `prediction_length` values of the label.

    Parameters
    ----------
    prediction_length
        Length of the prediction horizon.
    num_samples
        Default number of identical sample paths per forecast; used by
        :meth:`predict` when the caller does not pass ``num_samples``.
    """

    # NOTE(review): ``Predictor.__init__`` is not invoked here; confirm that
    # nothing downstream relies on attributes it would set.
    @validated()
    def __init__(self, prediction_length: int, num_samples: int) -> None:
        self.prediction_length = prediction_length
        self.num_samples = num_samples

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
        """
        Yield one forecast per entry of `dataset`.

        Parameters
        ----------
        dataset
            Entries to forecast; iterated in lockstep with ``ground_truth``.
        **kwargs
            Must contain ``ground_truth``, an iterable of label entries.
            ``num_samples`` is optional and defaults to the value given to
            the constructor. Remaining keywords go to :meth:`predict_item`.

        Raises
        ------
        ValueError
            If ``ground_truth`` is missing. Raised when `predict` is called,
            not when the returned iterator is first consumed.
        """
        ground_truth = kwargs.pop("ground_truth", None)
        if ground_truth is None:
            raise ValueError(
                "Ground truth is required for TrueOraclePredictor"
            )
        if kwargs.get("num_samples") is None:
            kwargs["num_samples"] = self.num_samples
        # The loop lives in a separate generator so that the validation
        # above runs eagerly instead of on the first ``next()``.
        return self._predict_pairs(dataset, ground_truth, kwargs)

    def _predict_pairs(self, dataset, ground_truth, item_kwargs):
        """Lazily forecast each (input, label) pair."""
        for item, label in zip(dataset, ground_truth):
            yield self.predict_item(
                item, ground_truth=label[FieldName.TARGET], **item_kwargs
            )

    def predict_item(
        self, item: DataEntry, num_samples: int, ground_truth=None
    ) -> Forecast:
        """
        Forecast a single entry by tiling its ground truth `num_samples`
        times.

        Raises
        ------
        ValueError
            If ``ground_truth`` is ``None``.
        """
        if ground_truth is None:
            raise ValueError(
                "Ground truth is required for TrueOraclePredictor"
            )
        samples = np.tile(
            ground_truth[-self.prediction_length :], (num_samples, 1)
        )
        return SampleForecast(
            samples=samples,
            start_date=forecast_start(item),
            item_id=item.get("item_id", None),
        )


class OffsetOraclePredictor(Predictor):
    """
    Oracle predictor whose forecast is the ground truth shifted in time:
    every sample path is the last `prediction_length` values of the label,
    rolled by ``offset`` positions.

    Parameters
    ----------
    prediction_length
        Length of the prediction horizon.
    num_samples
        Default number of identical sample paths per forecast; used by
        :meth:`predict` when the caller does not pass ``num_samples``.
    """

    # NOTE(review): ``Predictor.__init__`` is not invoked here; confirm that
    # nothing downstream relies on attributes it would set.
    @validated()
    def __init__(self, prediction_length: int, num_samples: int) -> None:
        self.prediction_length = prediction_length
        self.num_samples = num_samples

    def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
        """
        Yield one forecast per entry of `dataset`.

        Parameters
        ----------
        dataset
            Entries to forecast; iterated in lockstep with ``ground_truth``.
        **kwargs
            Must contain ``ground_truth``, an iterable of label entries.
            ``num_samples`` is optional and defaults to the value given to
            the constructor. Remaining keywords (e.g. ``offset``) go to
            :meth:`predict_item`.

        Raises
        ------
        ValueError
            If ``ground_truth`` is missing. Raised when `predict` is called,
            not when the returned iterator is first consumed.
        """
        ground_truth = kwargs.pop("ground_truth", None)
        if ground_truth is None:
            raise ValueError(
                "Ground truth is required for OffsetOraclePredictor"
            )
        if kwargs.get("num_samples") is None:
            kwargs["num_samples"] = self.num_samples
        # The loop lives in a separate generator so that the validation
        # above runs eagerly instead of on the first ``next()``.
        return self._predict_pairs(dataset, ground_truth, kwargs)

    def _predict_pairs(self, dataset, ground_truth, item_kwargs):
        """Lazily forecast each (input, label) pair."""
        for item, label in zip(dataset, ground_truth):
            yield self.predict_item(
                item, ground_truth=label[FieldName.TARGET], **item_kwargs
            )

    def predict_item(
        self,
        item: DataEntry,
        num_samples: int,
        ground_truth=None,
        offset: int = 1,
    ) -> Forecast:
        """
        Forecast a single entry by tiling its rolled ground truth
        `num_samples` times.

        ``np.roll`` is circular: with a positive ``offset`` the values pushed
        past the end of the window re-enter at its start, so the first
        ``offset`` forecast steps hold the last values of the window.

        Raises
        ------
        ValueError
            If ``ground_truth`` is ``None``.
        """
        if ground_truth is None:
            raise ValueError(
                "Ground truth is required for OffsetOraclePredictor"
            )
        offset_ground_truth = np.roll(
            ground_truth[-self.prediction_length :], shift=offset
        )
        samples = np.tile(offset_ground_truth, (num_samples, 1))
        return SampleForecast(
            samples=samples,
            start_date=forecast_start(item),
            item_id=item.get("item_id", None),
        )
2 changes: 1 addition & 1 deletion src/gluonts/mx/block/snmlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


def jacobian_sn_mlp_block_bf(
layers: List[Tuple[mx.gluon.HybridBlock, Tensor]]
layers: List[Tuple[mx.gluon.HybridBlock, Tensor]],
) -> Tensor:
"""
Brute force computation of the jacobian of a SNMlpBlock jac is of shape
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def union_dicts(
dicts: List[Dict[str, ModelConfig]]
dicts: List[Dict[str, ModelConfig]],
) -> Dict[str, List[ModelConfig]]:
"""
Merges the dicts by aggregating model configurations with the same key into
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/torch/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def resolve_device(
device: Union[str, torch.device]
device: Union[str, torch.device],
) -> Union[str, torch.device]:
"""
Resolves a torch device to the most appropriate one.
Expand Down
Loading