# Source: PIWM / src/play_trained.py (5.44 kB)
# Last commit: "Initial Diamond CSGO AI deployment" (c64c726), by musictimer
import argparse
from pathlib import Path
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import torch
from agent import Agent
from envs import WorldModelEnv
from game import Game, PlayEnv
from utils import get_path_agent_ckpt
# Register a custom OmegaConf resolver named "eval" so that config values can use
# ${eval:...} interpolations, which are evaluated with Python's built-in eval().
# NOTE(review): eval() executes arbitrary expressions, so only trusted config files
# should be loaded through this script.
OmegaConf.register_new_resolver("eval", eval)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the play script.

    Options are declared as a table of ``(flags, keyword arguments)`` pairs and
    registered in order, so the generated ``--help`` output keeps the same layout.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    option_specs = (
        (("--checkpoint-dir",), {"type": str, "help": "Path to training output directory"}),
        (("--epoch",), {"type": int, "default": -1, "help": "Epoch to load, -1 for latest checkpoint"}),
        (("-r", "--record"), {"action": "store_true", "help": "Record episodes in PlayEnv."}),
        (("--store-denoising-trajectory",), {"action": "store_true", "help": "Save denoising steps in info."}),
        (("--store-original-obs",), {"action": "store_true", "help": "Save original obs (pre resizing) in info."}),
        (("--mouse-multiplier",), {"type": int, "default": 10, "help": "Multiplication factor for the mouse movement."}),
        (("--size-multiplier",), {"type": int, "default": 2, "help": "Multiplication factor for the screen size."}),
        (("--compile",), {"action": "store_true", "help": "Turn on model compilation."}),
        (("--fps",), {"type": int, "default": 15, "help": "Frame rate."}),
        (("--no-header",), {"action": "store_true"}),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def check_args(args: argparse.Namespace) -> bool:
    """Sanity-check the combination of parsed command-line options.

    Prints a warning when ``--store-denoising-trajectory`` or
    ``--store-original-obs`` is given without ``--record``.

    Args:
        args: Namespace with ``record``, ``store_denoising_trajectory`` and
            ``store_original_obs`` attributes.

    Returns:
        Always ``True``: the options are never rejected, only warned about.
        The caller uses the value to decide whether to continue.
    """
    wants_stored_extras = args.store_denoising_trajectory or args.store_original_obs
    if wants_stored_extras and not args.record:
        print("Warning: not in recording mode, ignoring --store* options")
    return True
def prepare_play_mode(cfg: DictConfig, args: argparse.Namespace) -> PlayEnv:
    """Build the interactive play environment around a trained agent.

    Loads the training config, copies its agent / env / world-model entries into
    ``cfg``, selects a torch device, loads the agent checkpoint and wraps the
    agent's models in a ``WorldModelEnv`` that is handed to a ``PlayEnv``.

    Args:
        cfg: Composed config; its ``agent``, ``env`` and ``world_model_env``
            entries are overwritten in place.
        args: Parsed command-line options. Only ``compile``, ``record``,
            ``store_denoising_trajectory`` and ``store_original_obs`` are read.

    Returns:
        The ``PlayEnv`` wrapping the loaded agent and the world-model environment.

    Raises:
        FileNotFoundError: If the training config is missing, or if the pinned
            checkpoint is missing and no fallback ``*.pt`` file can be found.
        AssertionError: If the configured training environment is not ``"csgo"``.
    """
    # NOTE(review): the config, checkpoint and spawn paths below are absolute paths
    # on one machine; --checkpoint-dir and --epoch are not consulted here.

    # Load training config
    config_path = Path("/home/alienware3/Documents/diamond/config/trainer.yaml")
    if not config_path.exists():
        raise FileNotFoundError(f"Training config not found: {config_path}")
    training_cfg = OmegaConf.load(config_path)

    # Override config
    # Assumes the `defaults` list of the training config holds the env / agent /
    # world_model_env entries at indices 1 / 2 / 3 -- TODO confirm against trainer.yaml.
    cfg.agent = training_cfg.defaults[2].agent
    cfg.env = training_cfg.defaults[1].env
    cfg.world_model_env = training_cfg.defaults[3].world_model_env

    # Prefer CUDA, then Apple MPS, then fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    print("----------------------------------------------------------------------")
    print(f"Using {device} for rendering.")
    if not torch.cuda.is_available():
        print("If you have a CUDA GPU available and it is not being used, please follow the instructions at https://pytorch.org/get-started/locally/ to reinstall torch with CUDA support and try again.")
    print("----------------------------------------------------------------------")

    # Get model checkpoint path
    # ckpt_dir must be a Path: it is joined with "/" below, which raises TypeError
    # on a plain string and would make the fallback search unreachable.
    ckpt_dir = Path("checkpoints")
    path_ckpt = Path("/home/alienware3/Documents/diamond/agent_epoch_00206.pt")
    if not path_ckpt.exists():
        # Fall back to the last checkpoint (by sorted file name) in agent_versions.
        agent_versions_dir = ckpt_dir / "agent_versions"
        if not agent_versions_dir.exists():
            raise FileNotFoundError(f"Agent checkpoints directory not found: {agent_versions_dir}")
        available_ckpts = sorted(agent_versions_dir.glob("*.pt"))
        if not available_ckpts:
            raise FileNotFoundError("No agent checkpoint files found")
        path_ckpt = available_ckpts[-1]

    spawn_dir = Path("/home/alienware3/Documents/diamond/csgo/spawn")

    assert cfg.env.train.id == "csgo"
    num_actions = cfg.env.num_actions

    # Models
    agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
    agent.load(path_ckpt)

    # World model environment
    # Use the larger of the denoiser's and (if present) the upsampler's
    # num_steps_conditioning settings.
    sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
    if agent.upsampler is not None:
        sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
    wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
    wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model, spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)

    # Compilation is opt-in and only attempted on CUDA devices.
    if device.type == "cuda" and args.compile:
        print("Compiling models...")
        wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
        if agent.upsampler is not None:
            wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")

    play_env = PlayEnv(
        agent,
        wm_env,
        args.record,
        args.store_denoising_trajectory,
        args.store_original_obs,
    )

    return play_env
@torch.no_grad()
def main():
    """Entry point: parse options, compose the config and run the game loop.

    The whole session runs with gradient tracking disabled (see the decorator).
    Returns early, without starting the game, if the argument check fails.
    """
    args = parse_args()
    if not check_args(args):
        return

    with initialize(version_base="1.3", config_path="../config"):
        cfg = compose(config_name="trainer")

    # Window size: a single int means a square frame, otherwise a (height, width) pair.
    if isinstance(cfg.env.train.size, int):
        h = w = cfg.env.train.size
    else:
        h, w = cfg.env.train.size
    scale = args.size_multiplier
    window_size = (h * scale, w * scale)

    play_env = prepare_play_mode(cfg, args)
    Game(play_env, window_size, args.mouse_multiplier, fps=args.fps, verbose=not args.no_header).run()
# Start the interactive session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()