Spaces:
Sleeping
Sleeping
| import argparse | |
| from pathlib import Path | |
| from hydra import compose, initialize | |
| from hydra.utils import instantiate | |
| from omegaconf import DictConfig, OmegaConf | |
| import torch | |
| from agent import Agent | |
| from envs import WorldModelEnv | |
| from game import Game, PlayEnv | |
| from utils import get_path_agent_ckpt | |
# Expose Python's built-in ``eval`` to OmegaConf under the resolver name "eval".
# NOTE(review): presumably required by ``${eval:...}`` interpolations in the Hydra
# configs loaded below — confirm. The resolver evaluates arbitrary Python from the
# config text, so only load trusted config files.
OmegaConf.register_new_resolver("eval", eval)
def parse_args() -> argparse.Namespace:
    """Define the player's command-line interface and parse ``sys.argv``.

    Returns:
        Namespace holding the checkpoint selection, recording switches,
        mouse/window scaling factors, frame rate, compilation and header flags.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Which trained agent to load.
    add("--checkpoint-dir", type=str, help="Path to training output directory")
    add("--epoch", type=int, default=-1, help="Epoch to load, -1 for latest checkpoint")

    # Episode recording and what extra data to keep while recording.
    add("-r", "--record", action="store_true", help="Record episodes in PlayEnv.")
    add("--store-denoising-trajectory", action="store_true", help="Save denoising steps in info.")
    add("--store-original-obs", action="store_true", help="Save original obs (pre resizing) in info.")

    # Input / display scaling.
    add("--mouse-multiplier", type=int, default=10, help="Multiplication factor for the mouse movement.")
    add("--size-multiplier", type=int, default=2, help="Multiplication factor for the screen size.")

    # Runtime behaviour.
    add("--compile", action="store_true", help="Turn on model compilation.")
    add("--fps", type=int, default=15, help="Frame rate.")
    add("--no-header", action="store_true")

    namespace = cli.parse_args()
    return namespace
def check_args(args: argparse.Namespace) -> bool:
    """Check CLI flag combinations before the game starts.

    If ``--store-denoising-trajectory`` or ``--store-original-obs`` is given
    without ``--record``, a warning is printed and the options are ignored.

    Args:
        args: namespace produced by ``parse_args``.

    Returns:
        True when the program should proceed. Every current check only warns,
        so this always returns True; the caller still branches on the result,
        so a stricter check can return False later.
    """
    if not args.record and (args.store_denoising_trajectory or args.store_original_obs):
        print("Warning: not in recording mode, ignoring --store* options")
    return True
# Machine-specific locations this script is hard-wired to.
# NOTE(review): these absolute paths only exist on the author's machine, and the
# ``--checkpoint-dir`` / ``--epoch`` CLI options are currently ignored in favour of
# them — consider deriving these from ``args.checkpoint_dir`` instead.
_TRAINER_CONFIG_PATH = Path("/home/alienware3/Documents/diamond/config/trainer.yaml")
_PREFERRED_CKPT_PATH = Path("/home/alienware3/Documents/diamond/agent_epoch_00206.pt")
_SPAWN_DIR = Path("/home/alienware3/Documents/diamond/csgo/spawn")
# Fallback checkpoint directory, relative to the current working directory.
# Must be a Path (not a str) so that ``/ "agent_versions"`` works.
_FALLBACK_CKPT_DIR = Path("checkpoints")


def _select_device() -> torch.device:
    """Return the first available device among CUDA (index 0), Apple MPS and CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _report_device(device: torch.device) -> None:
    """Print the chosen device and, when CUDA is unavailable, a hint for enabling it."""
    print("----------------------------------------------------------------------")
    print(f"Using {device} for rendering.")
    if not torch.cuda.is_available():
        print("If you have a CUDA GPU available and it is not being used, please follow the instructions at https://pytorch.org/get-started/locally/ to reinstall torch with CUDA support and try again.")
    print("----------------------------------------------------------------------")


def _resolve_checkpoint(preferred: Path, ckpt_dir: Path) -> Path:
    """Return ``preferred`` if it exists, else the last ``*.pt`` file in ``ckpt_dir/agent_versions``.

    Raises:
        FileNotFoundError: if ``preferred`` is missing and the fallback directory
            does not exist or contains no ``*.pt`` file.
    """
    if preferred.exists():
        return preferred
    agent_versions_dir = ckpt_dir / "agent_versions"
    if not agent_versions_dir.exists():
        raise FileNotFoundError(f"Agent checkpoints directory not found: {agent_versions_dir}")
    available_ckpts = sorted(agent_versions_dir.glob("*.pt"))
    if not available_ckpts:
        raise FileNotFoundError("No agent checkpoint files found")
    # Sorted by file name: this is the newest epoch only if names are
    # zero-padded like "agent_epoch_00206.pt".
    return available_ckpts[-1]


def prepare_play_mode(cfg: DictConfig, args: argparse.Namespace) -> PlayEnv:
    """Load the trained agent and wrap it in an interactive ``PlayEnv``.

    ``cfg`` is modified in place: its ``agent``, ``env`` and ``world_model_env``
    entries are replaced by entries read from the training config file.

    Args:
        cfg: Hydra-composed trainer config (mutated, see above).
        args: parsed CLI arguments; ``compile``, ``record``,
            ``store_denoising_trajectory`` and ``store_original_obs`` are used.

    Returns:
        A ``PlayEnv`` around a ``WorldModelEnv`` built from the loaded agent.

    Raises:
        FileNotFoundError: if the training config or every candidate agent
            checkpoint is missing.
        ValueError: if the configured environment id is not ``"csgo"``.
    """
    # TODO: honour args.checkpoint_dir / args.epoch (see the NOTE on the path constants).
    config_path = _TRAINER_CONFIG_PATH
    if not config_path.exists():
        raise FileNotFoundError(f"Training config not found: {config_path}")
    training_cfg = OmegaConf.load(config_path)

    # Override the composed config with the training config's entries.
    # NOTE(review): positional indexing assumes defaults[1], [2] and [3] hold
    # ``env``, ``agent`` and ``world_model_env`` mappings — confirm against
    # trainer.yaml. If they are plain option names, these become strings.
    cfg.agent = training_cfg.defaults[2].agent
    cfg.env = training_cfg.defaults[1].env
    cfg.world_model_env = training_cfg.defaults[3].world_model_env

    device = _select_device()
    _report_device(device)

    # Model checkpoint (TODO: get_path_agent_ckpt(ckpt_dir, args.epoch) once paths are configurable).
    path_ckpt = _resolve_checkpoint(_PREFERRED_CKPT_PATH, _FALLBACK_CKPT_DIR)
    spawn_dir = _SPAWN_DIR

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if cfg.env.train.id != "csgo":
        raise ValueError(f"Play mode only supports the 'csgo' env, got {cfg.env.train.id!r}")
    num_actions = cfg.env.num_actions

    # Models
    agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
    agent.load(path_ckpt)

    # World model environment. ``sl`` is the larger conditioning length of the
    # denoiser and (if any) the upsampler — presumably the number of past
    # steps each model conditions on; confirm against the model configs.
    sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
    if agent.upsampler is not None:
        sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
    wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
    wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model, spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)

    # torch.compile is only enabled on CUDA devices, and only on request.
    if device.type == "cuda" and args.compile:
        print("Compiling models...")
        wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
        if agent.upsampler is not None:
            wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")

    return PlayEnv(
        agent,
        wm_env,
        args.record,
        args.store_denoising_trajectory,
        args.store_original_obs,
    )
def main():
    """Script entry point: parse flags, compose the Hydra config and run the game loop."""
    args = parse_args()
    if not check_args(args):
        return

    with initialize(version_base="1.3", config_path="../config"):
        cfg = compose(config_name="trainer")

    # Window size: env.train.size is either one int (square frame) or an (h, w) pair,
    # scaled up by --size-multiplier.
    frame_size = cfg.env.train.size
    if isinstance(frame_size, int):
        h, w = frame_size, frame_size
    else:
        h, w = frame_size
    scale = args.size_multiplier
    window_size = (h * scale, w * scale)

    play_env = prepare_play_mode(cfg, args)
    Game(play_env, window_size, args.mouse_multiplier, fps=args.fps, verbose=not args.no_header).run()
# Start the player only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()