| from typing import Any, List, Dict | |
| import torch | |
| from chronos import ChronosPipeline | |
class EndpointHandler:
    """Serve probabilistic forecasts from a Chronos model behind an endpoint.

    The ``EndpointHandler`` name and the ``__init__(path)`` / ``__call__(data)``
    pair follow the Hugging Face Inference Endpoints custom-handler contract.

    Request payload::

        {"inputs": {"context": [...]}, "parameters": {"prediction_length": 5}}

    ``parameters`` (and ``prediction_length`` inside it) is optional.
    """

    # Hub id loaded when the caller does not supply another ``model_id``.
    DEFAULT_MODEL_ID = "amazon/chronos-t5-tiny"
    # Forecast horizon used when the request does not specify one.
    DEFAULT_PREDICTION_LENGTH = 5

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID) -> None:
        """Load the Chronos pipeline once, when the handler is constructed.

        Args:
            path: Model directory handed in by the serving toolkit, which
                constructs custom handlers as ``EndpointHandler(path)``.
                Accepted so that call does not fail; it is not used to locate
                weights.
            model_id: Hub id or local path passed to
                ``ChronosPipeline.from_pretrained``.
        """
        # NOTE(review): if this handler ships inside the model repo itself,
        # loading from ``path`` instead of the Hub id would avoid a download.
        self.pipeline = ChronosPipeline.from_pretrained(model_id)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Forecast the series given in ``data["inputs"]["context"]``.

        Args:
            data: Request payload. It must contain ``"inputs"`` with a
                ``"context"`` entry and may contain ``"parameters"``. It is
                not mutated.

        Returns:
            ``{"forecast": nested list}``. This is the tensor returned by
            ``ChronosPipeline.predict``, converted with ``tolist()``. Per the
            Chronos docs, it is shaped (num_series, num_samples,
            prediction_length).

        Raises:
            KeyError: ``"inputs"`` or ``"context"`` is missing.
            ValueError: ``prediction_length`` converts to an int below 1, or
                is a non-numeric string.
        """
        inputs = data["inputs"]
        prediction_length = self._prediction_length(data)
        # Force a real-valued tensor. An all-integer context would otherwise
        # become int64, while the model consumes float series.
        context = torch.tensor(inputs["context"], dtype=torch.float32)
        forecast = self.pipeline.predict(context, prediction_length=prediction_length)
        return {"forecast": forecast.tolist()}

    @staticmethod
    def _prediction_length(data: Dict[str, Any]) -> int:
        """Read ``parameters.prediction_length`` from the payload (default 5)."""
        # ``or {}`` also covers an explicit ``"parameters": null`` in JSON.
        parameters = data.get("parameters") or {}
        raw = parameters.get(
            "prediction_length", EndpointHandler.DEFAULT_PREDICTION_LENGTH
        )
        prediction_length = int(raw)
        if prediction_length < 1:
            raise ValueError(
                f"prediction_length must be a positive integer, got {raw!r}"
            )
        return prediction_length