PIWM / src /game /dataset_env.py
musictimer's picture
Fix bug 10
b8159f9
raw
history blame
3.98 kB
from typing import Any, Dict, List, Tuple
import torch
from torch import Tensor
from data import Dataset
class DatasetEnv:
    """Browse recorded episodes through an environment-like interface.

    Rather than simulating anything, ``step`` moves a cursor through the
    timesteps of an episode loaded from one of ``datasets`` and returns the
    stored observation plus a text ``header`` describing the current position.
    Episode and timestep indices wrap around (modulo arithmetic).
    """

    def __init__(self, datasets: List["Dataset"], action_names: List[str]) -> None:
        """Keep the non-empty datasets and open episode 0 of the first one.

        Args:
            datasets: Datasets to browse; any with ``len(d) == 0`` is dropped.
            action_names: Display names, indexed by the stored action values.

        Raises:
            AssertionError: If no dataset is non-empty.
        """
        self.datasets = [d for d in datasets if len(d) > 0]
        assert len(self.datasets) > 0, "DatasetEnv needs at least one non-empty dataset"
        self.action_names = action_names
        self.dataset_id = 0
        self.dataset = self.datasets[0]
        # The per-episode state below is filled in by load_episode().
        self.episode_id = None
        self.episode = None
        self.t = None
        self.ep_return = None
        self.ep_length = None
        self.pos_return = None
        self.neg_return = None
        self.load_episode(0)

    def print_controls(self) -> None:
        """Print the controls available in dataset mode."""
        # Real arrow characters: these labels previously contained mis-decoded
        # UTF-8 bytes (e.g. "โ†‘" instead of "↑").
        print("\nControls (dataset mode):\n")
        print(f"m : datasets ({'/'.join(d.name for d in self.datasets)})")
        print("↑ : next episode")
        print("↓ : prev episode")
        print("→ : next timestep")
        print("← : prev timestep")

    def next_mode(self) -> bool:
        """Cycle to the next dataset; always returns True."""
        self.switch_dataset()
        return True

    def next_axis_1(self) -> bool:
        """Load the next episode (wrapping); always returns True."""
        self.load_episode(self.episode_id + 1)
        return True

    def prev_axis_1(self) -> bool:
        """Load the previous episode (wrapping); always returns True."""
        self.load_episode(self.episode_id - 1)
        return True

    def next_axis_2(self) -> bool:
        """No second axis in dataset mode; always returns False."""
        return False

    def prev_axis_2(self) -> bool:
        """No second axis in dataset mode; always returns False."""
        return False

    def load_episode(self, episode_id: int) -> None:
        """Load an episode from the current dataset and rewind to timestep 0.

        Args:
            episode_id: Episode index; wrapped modulo ``dataset.num_episodes``
                so that stepping past either end cycles around.
        """
        self.episode_id = episode_id % self.dataset.num_episodes
        self.episode = self.dataset.load_episode(self.episode_id)
        self.set_timestep(0)
        metrics = self.episode.compute_metrics()
        self.ep_return = metrics["return"]
        self.ep_length = metrics["length"]
        # Split the return into the sum of positive rewards and the magnitude
        # of the sum of negative rewards, for display in the header.
        self.pos_return = self.episode.rew[self.episode.rew > 0].sum().item()
        self.neg_return = self.episode.rew[self.episode.rew < 0].sum().abs().item()

    def set_timestep(self, timestep: int) -> None:
        """Move the cursor to ``timestep`` (wrapped modulo episode length).

        Caches the observation (with a leading singleton dimension added via
        ``unsqueeze(0)``), action, reward, end and trunc flags at that index.
        """
        self.t = timestep % len(self.episode)
        self.obs = self.episode.obs[self.t].unsqueeze(0)
        self.act = self.episode.act[self.t]
        self.rew = self.episode.rew[self.t]
        self.end = self.episode.end[self.t]
        self.trunc = self.episode.trunc[self.t]

    def switch_dataset(self) -> None:
        """Cycle to the next dataset (wrapping) and open its episode 0."""
        self.dataset_id = (self.dataset_id + 1) % len(self.datasets)
        self.dataset = self.datasets[self.dataset_id]
        self.load_episode(0)

    def reset(self) -> Tuple[Tensor, None]:
        """Rewind the current episode to timestep 0.

        Returns:
            ``(obs, None)``: the first observation and an empty info slot.
        """
        self.set_timestep(0)
        return self.obs, None

    @torch.no_grad()
    def step(self, act: int) -> Tuple[Tensor, Tensor, bool, bool, Dict[str, Any]]:
        """Move the timestep cursor and describe the new position.

        Args:
            act: 1 = back one step, 2 = forward one step, 3 = back ten steps,
                4 = forward ten steps; any other value keeps the position.

        Returns:
            ``(obs, torch.tensor(0), False, False, info)``. The reward and flags
            are constants because nothing is simulated. ``info["header"]`` is a
            list of three columns of display strings (dataset/episode returns,
            current transition, timestep/length).
        """
        # An if/elif chain (not `match`) keeps Python 3.8/3.9 compatibility.
        # A dict lookup is deliberately avoided: `act` may be a tensor, which
        # compares by value with `==` but hashes by identity.
        if act == 1:
            self.set_timestep(self.t - 1)
        elif act == 2:
            self.set_timestep(self.t + 1)
        elif act == 3:
            self.set_timestep(self.t - 10)
        elif act == 4:
            self.set_timestep(self.t + 10)

        # Pad the timestep to the width of the episode length for stable layout.
        n_digits = len(str(self.ep_length))
        header = [
            [
                f"Dataset: {self.dataset.name}",
                f"Episode: {self.episode_id}",
                "--------",
                f"Return (+): +{self.pos_return:4.1f}",
                f"Return (-): -{self.neg_return:4.1f}",
                f"Total : {self.ep_return:4.1f}",
            ],
            [
                f"Action: {self.action_names[self.act]}",
                f"Trunc : {bool(self.trunc)}",
                f"Done : {bool(self.end)}",
                f"Reward: {self.rew.item():.2f}",
                "-------",
                f"To here: {self.episode.rew[:self.t + 1].sum().item():.2f}",
                f"To go : {self.episode.rew[self.t + 1:].sum().item():.2f}",
            ],
            [
                f"Timestep: {self.t:{n_digits}d}",
                f"Length : {self.ep_length}",
            ],
        ]
        info = {"header": header}
        return self.obs, torch.tensor(0), False, False, info