""" Web-compatible PlayEnv that handles web input and AI inference """ from typing import Any, Dict, List, Set, Tuple import torch from torch import Tensor from torch.distributions.categorical import Categorical import torch.nn as nn import torch.nn.functional as F from ..agent import Agent from ..envs import WorldModelEnv from ..csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action from .play_env import PlayEnv class WebPlayEnv(PlayEnv): """Web-compatible version of PlayEnv that handles web input and AI inference""" def __init__( self, agent: Agent, wm_env: WorldModelEnv, recording_mode: bool, store_denoising_trajectory: bool, store_original_obs: bool, ) -> None: super().__init__(agent, wm_env, recording_mode, store_denoising_trajectory, store_original_obs) # For web demo, we want AI control by default self.is_human_player = False # AI controls the actions self.human_input_override = False # Can be set to True to allow human input # Initialize LSTM hidden states for actor-critic (only if actor_critic exists) if agent.actor_critic is not None: self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device) self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device) else: self.hx = None self.cx = None def switch_controller(self) -> None: """Switch between AI and human control""" self.is_human_player = not self.is_human_player print(f"Switched to {'human' if self.is_human_player else 'AI'} control") def str_control(self) -> str: """Return control mode string""" if self.human_input_override: return "Human Override" return "Human" if self.is_human_player else "AI" @torch.no_grad() def step_from_web_input( self, pressed_keys: Set[str], mouse_x: float, mouse_y: float, l_click: bool, r_click: bool, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Dict[str, Any]]: """ Step the environment with web input. If AI mode is enabled, use AI inference. If human mode or override, use human input. """ # Convert web keys to action names action_names = web_keys_to_csgo_action_names(pressed_keys) # Create web CSGO action from input web_action = WebCSGOAction( key_names=action_names, mouse_x=mouse_x, mouse_y=mouse_y, l_click=l_click, r_click=r_click ) # Ensure we have a valid observation; if not, reset the environment if self.obs is None: try: self.obs, _ = self.reset() except Exception: # If reset fails, fall back to human input below pass # If we have human input override or in human mode, use human input if self.human_input_override or self.is_human_player: # Encode the web action to tensor format action = encode_web_csgo_action(web_action, device=self.agent.device) else: # AI mode - use the agent's actor-critic to predict the action try: # Get current observation (ensure it has batch dimension) obs = self.obs if obs.ndim == 3: # CHW -> BCHW obs = obs.unsqueeze(0) # Ensure obs is on the same device as the models if obs.device != self.agent.device: obs = obs.to(self.agent.device, non_blocking=True) # Detach hidden states to prevent gradient tracking (only if they exist) if self.hx is not None: self.hx = self.hx.detach() if self.cx is not None: self.cx = self.cx.detach() # Resize observation to match actor-critic expected encoder/LSTM input # Count how many MaxPool2d layers are in the encoder to infer downsampling factor if hasattr(self.agent, "actor_critic") and self.agent.actor_critic is not None: try: n_pools = sum( 1 for m in self.agent.actor_critic.encoder.encoder if isinstance(m, nn.MaxPool2d) ) # We want the spatial size after the encoder to be 1x1 so that # flattening matches the LSTM input size configured at init time. # With n_pools halvings, input size must be 2**n_pools. target_hw = 2 ** n_pools if n_pools > 0 else min(int(obs.size(-2)), int(obs.size(-1))) if obs.size(-2) != target_hw or obs.size(-1) != target_hw: obs = F.interpolate( obs, size=(target_hw, target_hw), mode="bilinear", align_corners=False ) except Exception: # If anything goes wrong in the shape logic, fall back without resizing pass # Get action logits and value from actor-critic logits_act, value, (self.hx, self.cx) = self.agent.actor_critic.predict_act_value(obs, (self.hx, self.cx)) # Sample action from logits action_dist = Categorical(logits=logits_act) action = action_dist.sample() # Convert to proper shape (remove batch dimension if needed) if action.ndim > 0 and action.size(0) == 1: action = action.squeeze(0) except Exception as e: print(f"AI inference failed: {e}") import traceback traceback.print_exc() # Fallback to human input if AI fails action = encode_web_csgo_action(web_action, device=self.agent.device) # Step the environment with the chosen action next_obs, rew, end, trunc, env_info = self.env.step(action) # Update internal state self.obs = next_obs self.t += 1 # Reset hidden states on episode end (only if they exist) if end.any() or trunc.any(): if self.hx is not None: self.hx.zero_() if self.cx is not None: self.cx.zero_() # Return the step results return next_obs, rew, end, trunc, env_info