Spaces:
Sleeping
Sleeping
| from typing import Any, Dict, List, Tuple | |
| import torch | |
| from torch import Tensor | |
| from data import Dataset | |
class DatasetEnv:
    """Environment-like browser over episodes stored in one or more datasets.

    Nothing is simulated: the object keeps a cursor (dataset, episode,
    timestep) and exposes the recorded observation, action, reward, end and
    truncation values found at that cursor. ``step`` moves the timestep
    cursor, while the ``next_*`` / ``prev_*`` methods move the episode or
    dataset cursor.
    """

    def __init__(self, datasets: List["Dataset"], action_names: List[str]) -> None:
        """Keep the non-empty datasets and load episode 0 of the first one.

        Args:
            datasets: Candidate datasets; those with ``len(d) == 0`` are dropped.
            action_names: Names indexed by the recorded action when the header
                is built in ``step``.

        Raises:
            AssertionError: If no dataset remains after dropping empty ones.
        """
        self.datasets = [d for d in datasets if len(d) > 0]
        assert len(self.datasets) > 0, "DatasetEnv needs at least one non-empty dataset"
        self.action_names = action_names

        # Dataset cursor.
        self.dataset_id = 0
        self.dataset = self.datasets[0]

        # Episode cursor and per-episode statistics (filled by `load_episode`).
        self.episode_id = None
        self.episode = None
        self.ep_return = None
        self.ep_length = None
        self.pos_return = None
        self.neg_return = None

        # Timestep cursor and the values recorded there (filled by `set_timestep`).
        self.t = None
        self.obs = None
        self.act = None
        self.rew = None
        self.end = None
        self.trunc = None

        self.load_episode(0)

    def print_controls(self) -> None:
        """Print the controls available in dataset mode to stdout."""
        print("\nControls (dataset mode):\n")
        print(f"m : datasets ({'/'.join([d.name for d in self.datasets])})")
        print("↑ : next episode")
        print("↓ : prev episode")
        print("→ : next timestep")
        print("← : prev timestep")

    def next_mode(self) -> bool:
        """Switch to the next dataset (cyclically); always returns True."""
        # NOTE(review): the boolean is presumably a "reset needed" flag for the
        # caller -- confirm against the code that consumes it.
        self.switch_dataset()
        return True

    def next_axis_1(self) -> bool:
        """Load the next episode of the current dataset; always returns True."""
        self.load_episode(self.episode_id + 1)
        return True

    def prev_axis_1(self) -> bool:
        """Load the previous episode of the current dataset; always returns True."""
        self.load_episode(self.episode_id - 1)
        return True

    def next_axis_2(self) -> bool:
        """Do nothing and return False."""
        return False

    def prev_axis_2(self) -> bool:
        """Do nothing and return False."""
        return False

    def load_episode(self, episode_id: int) -> None:
        """Load an episode from the current dataset and rewind to timestep 0.

        Args:
            episode_id: Requested index; it is wrapped modulo
                ``self.dataset.num_episodes``, so out-of-range and negative
                values cycle around.
        """
        self.episode_id = episode_id % self.dataset.num_episodes
        self.episode = self.dataset.load_episode(self.episode_id)
        self.set_timestep(0)

        metrics = self.episode.compute_metrics()
        self.ep_return = metrics["return"]
        self.ep_length = metrics["length"]

        # Sum positive and negative rewards separately; the negative part is
        # stored as a magnitude.
        rewards = self.episode.rew
        self.pos_return = rewards[rewards > 0].sum().item()
        self.neg_return = rewards[rewards < 0].sum().abs().item()

    def set_timestep(self, timestep: int) -> None:
        """Move the timestep cursor and cache the values recorded there.

        Args:
            timestep: Requested index; it is wrapped modulo ``len(self.episode)``.
        """
        self.t = timestep % len(self.episode)
        # `unsqueeze(0)` adds a leading dimension of size 1 to the observation.
        self.obs = self.episode.obs[self.t].unsqueeze(0)
        self.act = self.episode.act[self.t]
        self.rew = self.episode.rew[self.t]
        self.end = self.episode.end[self.t]
        self.trunc = self.episode.trunc[self.t]

    def switch_dataset(self) -> None:
        """Select the next dataset (cyclically) and load its episode 0."""
        self.dataset_id = (self.dataset_id + 1) % len(self.datasets)
        self.dataset = self.datasets[self.dataset_id]
        self.load_episode(0)

    def reset(self) -> Tuple[Tensor, None]:
        """Rewind to timestep 0 of the current episode.

        Returns:
            ``(obs, None)`` where ``obs`` is the observation at timestep 0.
        """
        self.set_timestep(0)
        return self.obs, None

    def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]:
        """Move the timestep cursor according to ``act`` and report the new state.

        Args:
            act: ``1`` moves one timestep back, ``2`` one forward, ``3`` ten
                back and ``4`` ten forward (all wrapping around the episode);
                any other value leaves the cursor where it is.

        Returns:
            ``(obs, reward, end, trunc, info)``. ``obs`` is the observation at
            the new cursor; ``reward`` is always ``torch.tensor(0)`` and both
            flags are always ``False`` (the recorded values are only shown in
            the header); ``info["header"]`` holds three lists of text lines.
        """
        # NOTE: an if/elif chain is used instead of `match` so the module also
        # runs on Python 3.8/3.9.
        if act == 1:
            self.set_timestep(self.t - 1)
        elif act == 2:
            self.set_timestep(self.t + 1)
        elif act == 3:
            self.set_timestep(self.t - 10)
        elif act == 4:
            self.set_timestep(self.t + 10)

        info = {"header": self._build_header()}
        return self.obs, torch.tensor(0), False, False, info

    def _build_header(self) -> List[List[str]]:
        """Build the three groups of text lines describing the current cursor."""
        # Pad the timestep to the width of the episode length.
        n_digits = len(str(self.ep_length))
        rewards = self.episode.rew
        split = self.t + 1  # rewards[:split] includes the current timestep

        episode_lines = [
            f"Dataset: {self.dataset.name}",
            f"Episode: {self.episode_id}",
            "--------",
            f"Return (+): +{self.pos_return:4.1f}",
            f"Return (-): -{self.neg_return:4.1f}",
            f"Total : {self.ep_return:4.1f}",
        ]
        transition_lines = [
            f"Action: {self.action_names[self.act]}",
            f"Trunc : {bool(self.trunc)}",
            f"Done : {bool(self.end)}",
            f"Reward: {self.rew.item():.2f}",
            "-------",
            f"To here: {rewards[:split].sum().item():.2f}",
            f"To go : {rewards[split:].sum().item():.2f}",
        ]
        position_lines = [
            f"Timestep: {self.t:{n_digits}d}",
            f"Length : {self.ep_length}",
        ]
        return [episode_lines, transition_lines, position_lines]