Spaces:
Sleeping
Sleeping
| """ | |
| Web-based Diamond CSGO AI Player for Hugging Face Spaces | |
| Uses FastAPI + WebSocket for real-time keyboard input and game streaming | |
| """ | |
| # Fix environment variables FIRST, before any other imports | |
| import os | |
| import tempfile | |
# Fix OMP_NUM_THREADS immediately, before PyTorch/NumPy are imported.
# OpenMP only accepts a positive integer here, so an unset, non-numeric or
# zero value is replaced with a safe default. isdecimal() (not isdigit()) is
# used because int() rejects some characters isdigit() accepts, e.g. "²".
_omp_threads = os.environ.get("OMP_NUM_THREADS", "")
if not (_omp_threads.isdecimal() and int(_omp_threads) > 0):
    os.environ["OMP_NUM_THREADS"] = "2"
# Set up cache directories immediately, under the system temp dir
# (presumably because the default home-directory caches may not be writable
# on the hosting platform - TODO confirm).
temp_dir = tempfile.gettempdir()
os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
# Create cache directories
for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
    cache_path = os.environ[cache_var]
    os.makedirs(cache_path, exist_ok=True)
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import logging | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Set, Tuple | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from hydra import compose, initialize | |
| from hydra.utils import instantiate | |
| from omegaconf import DictConfig, OmegaConf | |
| from PIL import Image | |
| # Import your modules | |
| import sys | |
| from pathlib import Path | |
# Put this file's directory at the front of the import path so the local
# `src` package and `config_web` module can be imported below.
project_root = Path(__file__).parent
_project_root_str = str(project_root)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)
| from src.agent import Agent | |
| from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names | |
| from src.envs import WorldModelEnv | |
| from src.game.web_play_env import WebPlayEnv | |
| from src.utils import extract_state_dict | |
| from config_web import web_config | |
# Module-wide logging at INFO level
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object
app = FastAPI(title="Diamond CSGO AI Player")

# Safe defaults for headless CI/Spaces environments (no display / audio device).
# setdefault keeps any value the deployment already provides.
for _env_name, _env_default in (
    ("SDL_VIDEODRIVER", "dummy"),
    ("SDL_AUDIODRIVER", "dummy"),
    ("PYGAME_HIDE_SUPPORT_PROMPT", "1"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default

# WebSockets that receive loading updates and frames from the game loop.
# Clients whose send fails are removed by the game loop.
connected_clients: Set[WebSocket] = set()
class WebKeyMap:
    """Translate browser key codes into the pygame-style key names used for CSGO actions.

    Keys of ``WEB_TO_CSGO`` are browser key-code strings (``'KeyW'``,
    ``'ArrowUp'``, ...). Values are the key names the action layer expects.
    The arrow keys map to synthetic ``camera_*`` names rather than physical keys.
    """

    WEB_TO_CSGO = dict(
        # Movement
        KeyW='w',
        KeyA='a',
        KeyS='s',
        KeyD='d',
        Space='space',
        ControlLeft='left ctrl',
        ShiftLeft='left shift',
        # Weapon slots and reload
        Digit1='1',
        Digit2='2',
        Digit3='3',
        KeyR='r',
        # Camera control via arrow keys
        ArrowUp='camera_up',
        ArrowDown='camera_down',
        ArrowLeft='camera_left',
        ArrowRight='camera_right',
    )
class WebGameEngine:
    """Web-compatible game engine that replaces pygame functionality.

    Holds the play environment, the most recent observation, and the input
    state reported by the browser. It also owns two single-slot asyncio queues:
    one hands inference requests to ``_inference_worker`` and the other hands
    results back to ``step_environment``.
    """

    def __init__(self):
        import time

        self.play_env: Optional[WebPlayEnv] = None
        self.obs = None
        self.running = False
        self.game_started = False
        # Allow runtime tuning via environment variables. Each value is clamped
        # to >= 1 because it is later used as a divisor or modulus; a 0 would
        # make every game-loop iteration raise ZeroDivisionError.
        self.fps = self._positive_int_env("DISPLAY_FPS", 30)  # Display FPS
        # AI inference FPS; can be overridden with the AI_FPS env var
        self.ai_fps = self._positive_int_env("AI_FPS", 15)
        # Send every Nth frame to the browser (1 = send all frames)
        self.send_every = self._positive_int_env("DISPLAY_SKIP", 1)
        self.frame_count = 0
        self.ai_frame_count = 0
        self.last_ai_time = 0
        self.start_time = 0  # Set by start_game(); base for the AI FPS estimate
        self.pressed_keys: Set[str] = set()
        self.mouse_x = 0
        self.mouse_y = 0
        self.l_click = False
        self.r_click = False
        self.should_reset = False
        self.cached_obs = None  # Last observation, re-sent between AI results
        self.first_inference_done = False  # Track if first inference completed
        self.models_ready = False  # True once real or dummy models are usable
        self.download_progress = 0  # Track download progress (0-100)
        self.loading_status = "Initializing..."  # Shown to clients while loading
        self.actor_critic_loaded = False  # True if trained actor_critic weights were loaded
        self.time_module = time
        # Single-slot queues decouple inference from websocket I/O:
        # _in_queue carries (obs, web_state) requests and _out_queue carries
        # (next_obs, reward, done, truncated, info, inference_time) results.
        self._in_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
        self._out_queue: asyncio.Queue = asyncio.Queue(maxsize=1)
        # The worker is spawned lazily by game_loop once models are ready. The
        # task is kept referenced because the event loop holds only weak refs.
        self._worker_started = False
        self._worker_task: Optional[asyncio.Task] = None

    @staticmethod
    def _positive_int_env(name: str, default: int) -> int:
        """Read integer env var ``name`` (default ``default``), clamped to at least 1.

        Raises:
            ValueError: if the variable is set but is not an integer.
        """
        return max(1, int(os.getenv(name, str(default))))

    async def _load_model_from_url_async(self, agent, device):
        """Download the checkpoint from the HF Hub and load it into ``agent``.

        The blocking download and load run in the event loop's default thread
        pool, so the websocket loop keeps serving loading-progress updates.

        Args:
            agent: Agent whose ``load_state_dict`` receives the weights.
            device: ``map_location`` for the downloaded tensors.

        Returns:
            True if the weights were applied; False on error or after the
            5-minute timeout.
        """

        def load_model_weights():
            """Blocking part: fetch the state dict and apply it to ``agent``."""
            try:
                logger.info("Loading model using torch.hub.load_state_dict_from_url...")
                self.loading_status = "Downloading model..."
                self.download_progress = 10
                model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
                # Use torch.hub to download and load state dict with custom cache dir
                logger.info(f"Loading state dict from {model_url}")
                # Custom cache directory that we have write permissions for
                cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
                os.makedirs(cache_dir, exist_ok=True)
                state_dict = torch.hub.load_state_dict_from_url(
                    model_url,
                    map_location=device,
                    model_dir=cache_dir,
                    check_hash=False  # Skip hash check for faster loading
                )
                self.download_progress = 60
                self.loading_status = "Loading model weights into agent..."
                logger.info("State dict loaded, applying to agent...")
                # Check which components the state dict contains (by key prefix)
                has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
                has_denoiser = any(k.startswith('denoiser.') for k in state_dict.keys())
                has_upsampler = any(k.startswith('upsampler.') for k in state_dict.keys())
                logger.info(f"Model components found - actor_critic: {has_actor_critic}, denoiser: {has_denoiser}, upsampler: {has_upsampler}")
                agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
                # Track if actor_critic was actually loaded with trained weights
                self.actor_critic_loaded = has_actor_critic
                # Without actor_critic weights the world model still works for the demo
                if not has_actor_critic:
                    logger.warning("No actor_critic weights found - world model will work but AI won't play")
                    logger.info("Users can still interact and see the world model predictions")
                self.download_progress = 100
                self.loading_status = "Model loaded successfully!"
                logger.info("All model weights loaded successfully!")
                return True
            except Exception as e:
                logger.error(f"Failed to load model: {e}")
                import traceback
                traceback.print_exc()
                return False

        loop = asyncio.get_running_loop()
        try:
            # Use the loop's default executor rather than `with ThreadPoolExecutor()`.
            # Leaving that `with` block calls shutdown(wait=True), which would
            # block the event loop until the download finished and defeat the
            # timeout. On timeout the worker thread keeps running in the background.
            future = loop.run_in_executor(None, load_model_weights)
            return await asyncio.wait_for(future, timeout=300.0)  # 5 minute timeout
        except asyncio.TimeoutError:
            logger.error("Model loading timed out after 5 minutes")
            self.loading_status = "Model loading timed out - using dummy mode"
            return False
        except Exception as e:
            logger.error(f"Error in model loading executor: {e}")
            self.loading_status = f"Model loading error: {str(e)[:50]}..."
            return False

    async def initialize_models(self):
        """Load config and weights, then build the world-model play environment.

        Falls back to dummy mode (a static test pattern) when no checkpoint can
        be loaded or the environment fails to build. Every path returns True
        and leaves ``models_ready`` set, so ``game_loop`` stops showing the
        loading screen.
        """
        try:
            logger.info("Initializing models...")
            # Setup environment and paths
            web_config.setup_environment_variables()
            web_config.create_default_configs()
            config_path = web_config.get_config_path()
            logger.info(f"Using config path: {config_path}")
            # Hydra's initialize() resolves a relative config_path against the
            # directory of the calling file, not the CWD, so make it relative to
            # this file's directory (identical to the CWD when run from here).
            caller_dir = os.path.realpath(os.path.dirname(__file__))
            relative_config_path = os.path.relpath(config_path, start=caller_dir)
            logger.info(f"Relative config path: {relative_config_path}")
            with initialize(version_base="1.3", config_path=relative_config_path):
                cfg = compose(config_name="trainer")
            # Override config for deployment
            cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
            cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
            # Use GPU if available, otherwise fall back to CPU
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {device}")
            # Log GPU availability and CUDA info for debugging
            if torch.cuda.is_available():
                logger.info(f"CUDA available: {torch.cuda.is_available()}")
                logger.info(f"GPU device count: {torch.cuda.device_count()}")
                logger.info(f"Current GPU: {torch.cuda.get_device_name(0)}")
                logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
                logger.info("🚀 GPU acceleration enabled!")
            else:
                logger.info("CUDA not available, using CPU mode")
            # Initialize agent first
            num_actions = cfg.env.num_actions
            agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
            spawn_dir = web_config.get_spawn_dir()
            # Try to load a checkpoint: remote first, then local, else dummy mode
            try:
                logger.info("Loading model from Hugging Face Hub with torch.hub...")
                success = await self._load_model_from_url_async(agent, device)
                if success:
                    logger.info("Successfully loaded checkpoint from HF Hub")
                else:
                    # Fallback to local checkpoint if available
                    logger.error("Failed to load from HF Hub! Check the detailed error above.")
                    checkpoint_path = web_config.get_checkpoint_path()
                    if checkpoint_path.exists():
                        logger.info(f"Loading local checkpoint: {checkpoint_path}")
                        self.loading_status = "Loading local checkpoint..."
                        agent.load(checkpoint_path)
                        logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
                        # Assume local checkpoint has actor_critic weights (may need verification)
                        self.actor_critic_loaded = True
                    else:
                        logger.error(f"No local checkpoint found at: {checkpoint_path}")
                        raise FileNotFoundError("No model checkpoint available (local or remote)")
            except Exception as e:
                logger.error(f"Failed to load any checkpoint: {e}")
                # models_ready must be set here as well; otherwise game_loop
                # keeps broadcasting the loading screen forever.
                self._enter_dummy_mode("Using dummy mode")
                return True
            # Initialize world model environment
            try:
                # Conditioning length: the larger of the denoiser's and the
                # upsampler's num_steps_conditioning
                sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
                if agent.upsampler is not None:
                    sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
                wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
                wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
                                       spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
                self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
                self._configure_control_mode(agent)
                self._prepare_compile_cache_dirs()
                self._maybe_compile_world_model(wm_env, device)
                # Reset environment
                self.obs, _ = self.play_env.reset()
                self.cached_obs = self.obs  # Initialize cache
                logger.info("Models initialized successfully!")
                logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
                self.models_ready = True
                self.loading_status = "Ready!"
                return True
            except Exception as e:
                logger.error(f"Failed to initialize world model environment: {e}")
                self._enter_dummy_mode("Using dummy mode")
                return True
        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            import traceback
            traceback.print_exc()
            self._enter_dummy_mode("Error - using dummy mode")
            return True

    def _enter_dummy_mode(self, status):
        """Fall back to dummy mode and mark the engine ready with ``status``."""
        self._init_dummy_mode()
        self.actor_critic_loaded = False  # No actor_critic in dummy mode
        self.models_ready = True
        self.loading_status = status

    def _configure_control_mode(self, agent):
        """Set AI or human control on ``self.play_env`` depending on the loaded weights."""
        if agent.actor_critic is not None and self.actor_critic_loaded:
            logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
            logger.info(f"Actor-critic device: {agent.actor_critic.device}")
            # Force AI control for web demo
            self.play_env.is_human_player = False
            logger.info("✅ WebPlayEnv set to AI control mode - ready for inference!")
        elif agent.actor_critic is not None and not self.actor_critic_loaded:
            logger.warning("⚠️ Actor-critic model exists but has no trained weights!")
            logger.info("🎮 Demo will work in world-model mode (human input + world simulation)")
            # Still allow human input to drive the world model
            self.play_env.is_human_player = True
            self.play_env.human_input_override = True  # Always use human input
            logger.info("WebPlayEnv set to human control mode (no trained weights)")
        else:
            logger.warning("❌ No actor-critic model found - AI inference will not work!")
            self.play_env.is_human_player = True
            logger.info("WebPlayEnv set to human control mode (fallback)")

    @staticmethod
    def _prepare_compile_cache_dirs():
        """Ensure USER is set and Triton/torch.compile cache dirs exist under the temp dir."""
        import pwd
        try:
            pwd.getpwuid(os.getuid())
        except KeyError:
            # The UID has no passwd entry; presumably some libraries need a
            # user name to resolve, so provide one.
            os.environ["USER"] = "huggingface"
        cache_dir = tempfile.gettempdir()
        os.environ.setdefault("TRITON_CACHE_DIR", os.path.join(cache_dir, "triton"))
        os.environ.setdefault("TORCH_COMPILE_DEBUG_DIR", os.path.join(cache_dir, "torch_compile"))
        for cache_var in ("TRITON_CACHE_DIR", "TORCH_COMPILE_DEBUG_DIR"):
            os.makedirs(os.environ[cache_var], exist_ok=True)

    @staticmethod
    def _maybe_compile_world_model(wm_env, device):
        """Wrap the world model's predict/upsample calls with torch.compile on CUDA.

        Compilation is enabled by default on CUDA; set DISABLE_TORCH_COMPILE=1
        to skip it. Errors raised while wrapping are logged and the models run
        uncompiled.
        """
        disable_compile = os.getenv("DISABLE_TORCH_COMPILE", "0") == "1"
        compile_enabled = (device.type == "cuda" and not disable_compile)
        if compile_enabled:
            logger.info("Compiling models for faster inference (like play.py --compile)...")
            try:
                # NOTE: torch.compile is lazy; most compile errors surface on the
                # first call (in the inference worker), not inside this try.
                wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
                if wm_env.upsample_next_obs is not None:
                    wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
                logger.info("✅ Model compilation enabled successfully!")
            except Exception as e:
                logger.warning(f"⚠️ Model compilation failed: {e}")
                logger.info("Continuing without model compilation...")
        else:
            reason = "DISABLE_TORCH_COMPILE=1 set" if disable_compile else "no CUDA device available"
            logger.info(f"Model compilation disabled ({reason}). Models will run uncompiled.")

    def _init_dummy_mode(self):
        """Switch to dummy mode: no environment, only a static test-pattern observation.

        The pattern is R = x % 256, G = y % 256 and B = (x + y) % 256. It is
        stored as a (1, 3, 150, 600) float tensor scaled to [-1, 1].
        """
        logger.info("Initializing dummy mode...")
        height, width = 150, 600
        ys = np.arange(height)[:, None]  # column vector, broadcasts over x
        xs = np.arange(width)[None, :]  # row vector, broadcasts over y
        img_array = np.zeros((height, width, 3), dtype=np.uint8)
        img_array[..., 0] = xs % 256  # Red gradient
        img_array[..., 1] = ys % 256  # Green gradient
        img_array[..., 2] = (xs + ys) % 256  # Blue pattern
        # Convert to torch tensor in expected format [-1, 1]
        tensor = torch.from_numpy(img_array).float().permute(2, 0, 1)  # CHW format
        tensor = tensor.div(255).mul(2).sub(1)  # Convert to [-1, 1] range
        self.obs = tensor.unsqueeze(0)  # Add batch dimension
        self.play_env = None  # No real environment in dummy mode
        logger.info("Dummy mode initialized with test pattern")

    def _current_ai_fps(self, current_time):
        """Average AI results per second since start_game(), capped at 100 (0 if none yet)."""
        elapsed_time = current_time - self.start_time
        if elapsed_time > 0 and self.ai_frame_count > 0:
            return min(self.ai_frame_count / elapsed_time, 100.0)
        return 0

    def step_environment(self):
        """Advance one display tick, using the latest AI result if one is ready.

        This method does not run inference itself. When at least ``1 / ai_fps``
        seconds have passed since the last result and no request is pending,
        it queues a snapshot of the current input state for
        ``_inference_worker``. If a finished result is waiting, that result is
        returned and cached. Otherwise the cached observation is returned again.

        Returns:
            ``(obs, reward, done, truncated, info)``. Errors are logged and
            reported via ``info["error"]`` rather than raised.
        """
        if self.play_env is None:
            # Dummy mode - just return current observation
            return self.obs, 0.0, False, False, {"mode": "dummy"}
        try:
            if self.should_reset:
                self.reset_environment()
                self.should_reset = False
                self.last_ai_time = self.time_module.time()  # Reset AI timer
                return self.obs, 0.0, False, False, {"reset": True}
            current_time = self.time_module.time()
            should_run_ai = current_time - self.last_ai_time >= 1.0 / self.ai_fps
            if should_run_ai and self._in_queue.empty():
                # Snapshot web input state
                web_state = dict(
                    pressed_keys=set(self.pressed_keys),
                    mouse_x=self.mouse_x,
                    mouse_y=self.mouse_y,
                    l_click=self.l_click,
                    r_click=self.r_click,
                )
                # put_nowait enqueues synchronously. A fire-and-forget
                # create_task(put(...)) would only enqueue later, letting more
                # puts pile up before the first lands; its task is also unreferenced.
                self._in_queue.put_nowait((self.obs, web_state))
            if not self._out_queue.empty():
                (next_obs, reward, done, truncated, info, inference_time) = self._out_queue.get_nowait()
                if not self.first_inference_done:
                    self.first_inference_done = True
                    logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
                # Cache the new observation and update timing
                self.cached_obs = next_obs
                self.last_ai_time = current_time
                self.ai_frame_count += 1
                info = info or {}
                info["ai_inference"] = True
                info["ai_fps"] = self._current_ai_fps(current_time)
                info["inference_time"] = inference_time
                return next_obs, reward, done, truncated, info
            # No new result: re-send the cached observation without AI overhead
            obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
            return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": self._current_ai_fps(current_time)}
        except Exception as e:
            logger.error(f"Error stepping environment: {e}")
            obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
            return obs_to_return, 0.0, False, False, {"error": str(e)}

    def reset_environment(self):
        """Reset the play environment (or rebuild the dummy pattern) and refresh the cache.

        Errors are logged, not raised.
        """
        try:
            if self.play_env is not None:
                self.obs, _ = self.play_env.reset()
                self.cached_obs = self.obs  # Update cache
                logger.info("Environment reset successfully")
            else:
                # Dummy mode - recreate test pattern
                self._init_dummy_mode()
                self.cached_obs = self.obs  # Update cache
                logger.info("Dummy environment reset")
        except Exception as e:
            logger.error(f"Error resetting environment: {e}")

    def request_reset(self):
        """Request an environment reset on the next step_environment() call."""
        self.should_reset = True
        logger.info("Environment reset requested")

    def start_game(self):
        """Start stepping the game and restart the AI FPS measurement."""
        self.game_started = True
        self.start_time = self.time_module.time()  # Reset start time for FPS calculation
        self.ai_frame_count = 0  # Reset AI frame count
        logger.info("Game started")

    def pause_game(self):
        """Pause/stop the game; frames keep being sent but the env is not stepped."""
        self.game_started = False
        logger.info("Game paused")

    def obs_to_base64(self, obs: torch.Tensor) -> str:
        """Encode an observation as a ``data:image/...;base64,`` URL resized to 600x150.

        Args:
            obs: Image tensor with values in [-1, 1] and RGB channels, shaped
                (C, H, W) or (1, C, H, W).

        Returns:
            The data URL. The codec comes from IMG_CODEC (``png``, otherwise
            JPEG). An empty string is returned for ``None`` input or on failure.
        """
        if obs is None:
            return ""
        try:
            frame = obs[0] if obs.ndim == 4 and obs.size(0) == 1 else obs
            # Clamp before the uint8 cast so values slightly outside [-1, 1]
            # saturate instead of wrapping (matches obs_to_bytes' handling).
            img_array = frame.add(1).div(2).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
            img = Image.fromarray(img_array)
            # Resize for web display to match canvas size
            img = img.resize((600, 150), Image.NEAREST)
            codec = os.getenv("IMG_CODEC", "jpeg").lower()
            img_np = np.array(img)[:, :, ::-1]  # RGB -> BGR for OpenCV
            if codec == "png":
                success, encoded_img = cv2.imencode('.png', img_np, [cv2.IMWRITE_PNG_COMPRESSION, 1])
                mime = 'png'
            else:
                # JPEG with quality 70 for speed/size balance
                success, encoded_img = cv2.imencode('.jpg', img_np, [cv2.IMWRITE_JPEG_QUALITY, 70])
                mime = 'jpeg'
            if not success:
                return ""
            img_str = base64.b64encode(encoded_img).decode()
            return f"data:image/{mime};base64,{img_str}"
        except Exception as e:
            logger.error(f"Error converting observation to base64: {e}")
            return ""

    # ------------------------------------------------------------------
    # Faster binary encoder (JPEG/PNG) with OpenCV - no Pillow involved
    # ------------------------------------------------------------------
    def obs_to_bytes(self, obs: torch.Tensor) -> Tuple[bytes, str]:
        """Return encoded image bytes and their MIME type (image/jpeg or image/png).

        Args:
            obs: Image tensor with values in [-1, 1] and RGB channels, shaped
                (C, H, W) or (1, C, H, W).

        Returns:
            ``(bytes, mime)``. On ``None`` input or failure the bytes are empty.
        """
        if obs is None:
            return b"", "image/jpeg"
        try:
            img_tensor = obs[0] if obs.ndim == 4 and obs.size(0) == 1 else obs
            # Resize on the tensor's device before the single host transfer.
            # NOTE(review): 300x75 is half the 600x150 canvas used by the
            # base64 path; presumably the browser upscales - confirm intended.
            img_tensor = torch.nn.functional.interpolate(
                img_tensor.unsqueeze(0), size=(75, 300), mode='nearest'
            ).squeeze(0)
            # OpenCV encoders expect BGR but the observation is RGB, so reverse
            # the channel axis first; otherwise red and blue are swapped.
            img_np = (img_tensor.flip(0).add(1).mul(127.5).clamp(0, 255).byte()
                      .permute(1, 2, 0).contiguous().cpu().numpy())  # HWC uint8, BGR
            codec = os.getenv("IMG_CODEC", "jpeg").lower()
            if codec == "png":
                ok, enc = cv2.imencode('.png', img_np, [cv2.IMWRITE_PNG_COMPRESSION, 1])
                mime = "image/png"
            else:
                ok, enc = cv2.imencode('.jpg', img_np, [cv2.IMWRITE_JPEG_QUALITY, 75])
                mime = "image/jpeg"
            if not ok:
                return b"", mime
            return enc.tobytes(), mime
        except Exception as e:
            logger.error(f"obs_to_bytes error: {e}")
            return b"", "image/jpeg"

    @staticmethod
    def _reward_to_float(reward) -> float:
        """Convert a tensor, number or None reward into a JSON-friendly float."""
        if hasattr(reward, 'item'):
            return float(reward.item())
        if reward is None:
            return 0.0
        return float(reward)

    async def _broadcast(self, *payloads) -> None:
        """Send each payload (``bytes`` as binary, anything else as text) to every client.

        Clients whose send raises are removed from ``connected_clients``.
        ``Exception`` is caught rather than using a bare ``except`` so that
        task cancellation still propagates.
        """
        disconnected = set()
        for client in connected_clients.copy():
            try:
                for payload in payloads:
                    if isinstance(payload, bytes):
                        await client.send_bytes(payload)
                    else:
                        await client.send_text(payload)
            except Exception:
                disconnected.add(client)
        connected_clients.difference_update(disconnected)

    async def game_loop(self):
        """Main loop: broadcast loading status, step the env, and stream frames.

        Runs while ``self.running`` is set. Until models are ready it sends a
        ``loading`` message every 500 ms. After that it aims for ``self.fps``
        iterations per second: it steps the environment only while the game is
        started, and sends the current frame to connected clients every
        ``send_every`` iterations.
        """
        self.running = True
        while self.running:
            loop_start_time = self.time_module.time()
            # Spawn the inference worker lazily, once models are initialized
            if self.models_ready and not self._worker_started:
                self._worker_task = asyncio.create_task(self._inference_worker())
                self._worker_started = True
            try:
                if not self.models_ready:
                    if connected_clients:
                        loading_data = {
                            'type': 'loading',
                            'status': self.loading_status,
                            'progress': self.download_progress,
                            'ready': False
                        }
                        await self._broadcast(json.dumps(loading_data))
                    await asyncio.sleep(0.5)  # Check every 500ms during loading
                    continue
                # Always send frames, but only step environment if game is started
                should_send_frame = True
                if not self.game_started:
                    if self.obs is not None and connected_clients:
                        should_send_frame = True
                    else:
                        should_send_frame = False
                        await asyncio.sleep(0.1)
                else:
                    if self.play_env is None:
                        await asyncio.sleep(0.1)
                        continue
                    next_obs, reward, done, truncated, info = self.step_environment()
                    if done or truncated:
                        # Auto-reset when episode ends
                        self.reset_environment()
                    else:
                        self.obs = next_obs
                if should_send_frame and connected_clients and self.obs is not None and (self.frame_count % self.send_every == 0):
                    if not self.game_started:
                        # Placeholders while waiting; otherwise set by the step above
                        reward = 0.0
                        info = {"waiting": True}
                    # Binary frames behind a feature flag
                    if os.getenv("BINARY_WS", "0") == "1":
                        img_bytes, mime = self.obs_to_bytes(self.obs)
                        meta = {
                            'type': 'frame_meta',
                            'mime': mime,
                            'frame_count': self.frame_count,
                            'reward': self._reward_to_float(reward),
                            'info': str(info) if info else "",
                            'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
                            'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
                        }
                        await self._broadcast(json.dumps(meta), img_bytes)
                    else:
                        # Fallback to base64 JSON
                        image_data = self.obs_to_base64(self.obs)
                        if self.frame_count < 5:
                            logger.info(
                                f"Frame {self.frame_count}: base64_len={len(image_data)} ai={info.get('ai_fps',0):.1f}")
                        frame_data = {
                            'type': 'frame',
                            'image': image_data,
                            'frame_count': self.frame_count,
                            'reward': self._reward_to_float(reward),
                            'info': str(info) if info else "",
                            'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
                            'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
                        }
                        await self._broadcast(json.dumps(frame_data))
                self.frame_count += 1
                # Adaptive sleep: only wait for whatever is left of the frame budget
                loop_elapsed = self.time_module.time() - loop_start_time
                sleep_for = max((1.0 / self.fps) - loop_elapsed, 0)
                if sleep_for:
                    await asyncio.sleep(sleep_for)
            except Exception as e:
                logger.error(f"Error in game loop: {e}")
                await asyncio.sleep(0.1)

    async def _inference_worker(self):
        """Consume inference requests forever, paced to ``self.ai_fps``.

        Each request's ``web_state`` is passed to
        ``play_env.step_from_web_input``. The result, plus the measured
        inference time, is put on ``_out_queue``. The call is synchronous, so
        it blocks the event loop while it runs. On failure, a copy of the
        current observation is returned with ``info["error"]`` set.
        """
        logger.info("Inference worker started")
        next_inference_time = self.time_module.time()
        while True:
            _obs, web_state = await self._in_queue.get()
            # Timing control: maintain steady AI_FPS like play.py's clock.tick()
            now = self.time_module.time()
            if now < next_inference_time:
                await asyncio.sleep(next_inference_time - now)
            # Schedule from the later of the planned slot and "now". Otherwise a
            # slow inference leaves the schedule behind, and later fast ones run
            # back-to-back to "catch up".
            next_inference_time = max(next_inference_time, now) + 1.0 / self.ai_fps
            try:
                start = self.time_module.time()
                # FP16 autocast on CUDA for faster inference; disabled on CPU
                with torch.amp.autocast('cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
                    res = self.play_env.step_from_web_input(**web_state)
                infer_t = self.time_module.time() - start
                await self._out_queue.put((*res, infer_t))
            except Exception as e:
                logger.error(f"Inference worker error: {e}")
                import traceback
                logger.error(f"Full traceback: {traceback.format_exc()}")
                # Return a placeholder result so the display keeps updating
                try:
                    if self.obs is not None and hasattr(self.obs, 'shape') and hasattr(self.obs, 'device'):
                        dummy_obs = self.obs.clone()
                    else:
                        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                        dummy_obs = torch.zeros(1, 3, 150, 600, device=device)
                    await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
                except Exception as e2:
                    logger.error(f"Error creating dummy result: {e2}")
                    # Last resort - create CPU tensor
                    dummy_obs = torch.zeros(1, 3, 150, 600)
                    await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
# Global game engine instance
game_engine = WebGameEngine()
# Strong references to the background startup tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task can be garbage-collected
# mid-run.
_background_tasks: Set[asyncio.Task] = set()


async def startup_event():
    """Start the game loop and kick off model loading, both as background tasks.

    The game loop starts first so that it can broadcast loading progress while
    ``initialize_models`` runs.

    NOTE(review): no FastAPI startup decorator or registration is visible in
    this excerpt; confirm this coroutine is actually registered with ``app``.
    """
    for coro in (game_engine.game_loop(), game_engine.initialize_models()):
        task = asyncio.create_task(coro)
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
async def get_homepage():
    """Serve the single-page game client (HTML, CSS and JavaScript inline).

    The page draws frames pushed over the ``/ws`` WebSocket onto a 600x150
    canvas and streams keyboard/mouse input back to the server. Held keys are
    released when the canvas loses focus or the game is paused, so the server
    never keeps a key stuck "down". Enter/Escape only flash their on-screen
    badge; they are not added to the key state sent to the server.
    """
    html_content = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Physics-informed BEV World Model</title>
    <style>
        body {
            margin: 0;
            padding: 20px;
            background: #1a1a1a;
            color: white;
            font-family: 'Courier New', monospace;
            text-align: center;
        }
        #gameCanvas {
            border: 2px solid #00ff00;
            background: #000;
            margin: 20px auto;
            display: block;
        }
        #controls {
            display: grid;
            grid-template-columns: 1fr 1fr;
            gap: 20px;
            max-width: 800px;
            margin: 20px auto;
        }
        .control-section {
            background: #2a2a2a;
            padding: 15px;
            border-radius: 8px;
            border: 1px solid #444;
        }
        .key-display {
            background: #333;
            border: 1px solid #555;
            padding: 5px 10px;
            margin: 2px;
            border-radius: 4px;
            display: inline-block;
            min-width: 30px;
        }
        .key-pressed {
            background: #00ff00;
            color: #000;
        }
        #status {
            margin: 10px;
            padding: 10px;
            background: #2a2a2a;
            border-radius: 4px;
        }
        .info {
            color: #00ff00;
            margin: 5px 0;
        }
    </style>
</head>
<body>
    <h1>🎮 Physics-informed BEV World Model</h1>
    <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
    <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>

    <!-- Model Download Progress -->
    <div id="downloadSection" style="display: none; margin: 20px;">
        <p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
        <div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
            <div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
        </div>
        <p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
    </div>

    <canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>

    <div id="status">
        <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
        <div class="info">Game: <span id="gameStatus">Click to Start</span></div>
        <div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
        <div class="info">Reward: <span id="reward">0</span></div>
    </div>

    <div id="controls">
        <div class="control-section">
            <h3>Movement</h3>
            <div>
                <span class="key-display" id="key-w">W</span> Forward<br>
                <span class="key-display" id="key-a">A</span> Left
                <span class="key-display" id="key-s">S</span> Back
                <span class="key-display" id="key-d">D</span> Right<br>
                <span class="key-display" id="key-space">Space</span> Jump
                <span class="key-display" id="key-ctrl">Ctrl</span> Crouch
                <span class="key-display" id="key-shift">Shift</span> Walk
            </div>
        </div>
        <div class="control-section">
            <h3>Actions</h3>
            <div>
                <span class="key-display" id="key-1">1</span> Weapon 1<br>
                <span class="key-display" id="key-2">2</span> Weapon 2
                <span class="key-display" id="key-3">3</span> Weapon 3<br>
                <span class="key-display" id="key-r">R</span> Reload<br>
                <span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
                <span class="key-display" id="key-enter">Enter</span> Reset Game<br>
                <span class="key-display" id="key-esc">Esc</span> Pause/Quit
            </div>
        </div>
    </div>

    <script>
        const canvas = document.getElementById('gameCanvas');
        const ctx = canvas.getContext('2d');
        const statusEl = document.getElementById('connectionStatus');
        const gameStatusEl = document.getElementById('gameStatus');
        const frameEl = document.getElementById('frameCount');
        const aiFpsEl = document.getElementById('aiFps');
        const rewardEl = document.getElementById('reward');
        const loadingEl = document.getElementById('loadingIndicator');
        const downloadSectionEl = document.getElementById('downloadSection');
        const downloadStatusEl = document.getElementById('downloadStatus');
        const progressBarEl = document.getElementById('progressBar');
        const progressTextEl = document.getElementById('progressText');

        let ws = null;
        let pressedKeys = new Set();
        let gameStarted = false;

        // Map KeyboardEvent.code to the id of its on-screen badge
        const keyDisplayMap = {
            'KeyW': 'key-w',
            'KeyA': 'key-a',
            'KeyS': 'key-s',
            'KeyD': 'key-d',
            'Space': 'key-space',
            'ControlLeft': 'key-ctrl',
            'ShiftLeft': 'key-shift',
            'Digit1': 'key-1',
            'Digit2': 'key-2',
            'Digit3': 'key-3',
            'KeyR': 'key-r',
            'ArrowUp': 'key-arrows',
            'ArrowDown': 'key-arrows',
            'ArrowLeft': 'key-arrows',
            'ArrowRight': 'key-arrows',
            'Enter': 'key-enter',
            'Escape': 'key-esc'
        };

        function connectWebSocket() {
            const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUrl = `${protocol}//${window.location.host}/ws`;
            ws = new WebSocket(wsUrl);

            ws.onopen = function(event) {
                statusEl.textContent = 'Connected';
                statusEl.style.color = '#00ff00';
                // If the user clicked to start before the socket was ready, send start now
                if (gameStarted) {
                    ws.send(JSON.stringify({ type: 'start' }));
                }
            };

            ws.onmessage = function(event) {
                const data = JSON.parse(event.data);
                if (data.type === 'loading') {
                    downloadSectionEl.style.display = 'block';
                    downloadStatusEl.textContent = data.status;
                    if (data.progress !== undefined) {
                        progressBarEl.style.width = data.progress + '%';
                        progressTextEl.textContent = data.progress + '% - ' + data.status;
                    } else {
                        progressTextEl.textContent = data.status;
                    }
                    gameStatusEl.textContent = 'Loading Models...';
                    gameStatusEl.style.color = '#ffaa00';
                } else if (data.type === 'frame') {
                    // Hide loading indicators once we get frames
                    downloadSectionEl.style.display = 'none';
                    if (data.image) {
                        const img = new Image();
                        img.onload = function() {
                            ctx.clearRect(0, 0, canvas.width, canvas.height);
                            ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
                        };
                        img.src = data.image;
                    }
                    if (data.frame_count !== undefined && data.frame_count !== null) {
                        frameEl.textContent = data.frame_count;
                    }
                    // Guard: a missing/non-numeric reward would make toFixed throw
                    if (typeof data.reward === 'number') {
                        rewardEl.textContent = data.reward.toFixed(2);
                    }
                    if (data.ai_fps !== undefined && data.ai_fps !== null) {
                        // Clamp to a sane display range
                        const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
                        aiFpsEl.textContent = aiFps.toFixed(1);
                        // Color code AI FPS for performance indication
                        if (aiFps >= 8) {
                            aiFpsEl.style.color = '#00ff00';
                        } else if (aiFps >= 5) {
                            aiFpsEl.style.color = '#ffff00';
                        } else if (aiFps > 0) {
                            aiFpsEl.style.color = '#ff0000';
                        } else {
                            aiFpsEl.style.color = '#888888';
                        }
                        // Hide loading indicator once AI inference starts working
                        if (aiFps > 0 && gameStarted) {
                            loadingEl.style.display = 'none';
                            gameStatusEl.textContent = 'Playing';
                            gameStatusEl.style.color = '#00ff00';
                        }
                    }
                }
            };

            ws.onclose = function(event) {
                statusEl.textContent = 'Disconnected';
                statusEl.style.color = '#ff0000';
                setTimeout(connectWebSocket, 1000);  // Reconnect after 1 second
            };

            ws.onerror = function(event) {
                statusEl.textContent = 'Error';
                statusEl.style.color = '#ff0000';
            };
        }

        function sendKeyState() {
            if (ws && ws.readyState === WebSocket.OPEN) {
                ws.send(JSON.stringify({
                    type: 'keys',
                    keys: Array.from(pressedKeys)
                }));
            }
        }

        function startGame() {
            // Mark the game started locally; the start request goes out now if the
            // socket is open, otherwise from ws.onopen once it connects.
            gameStarted = true;
            gameStatusEl.textContent = 'Starting AI...';
            gameStatusEl.style.color = '#ffff00';
            loadingEl.style.display = 'block';
            if (ws && ws.readyState === WebSocket.OPEN) {
                ws.send(JSON.stringify({ type: 'start' }));
            }
            console.log('Game started');
        }

        function releaseAllKeys() {
            // Drop every held key so nothing stays "down" on the server
            if (pressedKeys.size === 0) return;
            pressedKeys.clear();
            updateKeyDisplay();
            sendKeyState();
        }

        function pauseGame() {
            if (ws && ws.readyState === WebSocket.OPEN) {
                // Keyups arriving while paused would otherwise be ignored
                releaseAllKeys();
                ws.send(JSON.stringify({ type: 'pause' }));
                gameStarted = false;
                gameStatusEl.textContent = 'Paused - Click to Resume';
                gameStatusEl.style.color = '#ffff00';
                console.log('Game paused');
            }
        }

        function updateKeyDisplay() {
            Object.values(keyDisplayMap).forEach(id => {
                const el = document.getElementById(id);
                if (el) el.classList.remove('key-pressed');
            });
            pressedKeys.forEach(key => {
                const displayId = keyDisplayMap[key];
                if (displayId) {
                    const el = document.getElementById(displayId);
                    if (el) el.classList.add('key-pressed');
                }
            });
        }

        function flashKey(code) {
            // Brief visual feedback for command keys without touching the sent key state
            const el = document.getElementById(keyDisplayMap[code]);
            if (!el) return;
            el.classList.add('key-pressed');
            setTimeout(() => el.classList.remove('key-pressed'), 200);
        }

        canvas.addEventListener('click', () => {
            canvas.focus();
            if (!gameStarted) {
                startGame();
            }
        });

        canvas.addEventListener('keydown', (event) => {
            event.preventDefault();
            if (event.code === 'Enter') {
                // Ignore auto-repeat so holding Enter does not spam resets
                if (event.repeat) return;
                if (ws && ws.readyState === WebSocket.OPEN) {
                    ws.send(JSON.stringify({ type: 'reset' }));
                    console.log('Environment reset requested');
                }
                flashKey(event.code);
            } else if (event.code === 'Escape') {
                if (event.repeat) return;
                pauseGame();
                flashKey(event.code);
            } else if (gameStarted && !pressedKeys.has(event.code)) {
                // Only send game keys while playing; auto-repeat of a held key changes nothing
                pressedKeys.add(event.code);
                updateKeyDisplay();
                sendKeyState();
            }
        });

        canvas.addEventListener('keyup', (event) => {
            event.preventDefault();
            // Always release, even while paused, so a key held across a pause is not stuck
            if (pressedKeys.delete(event.code)) {
                updateKeyDisplay();
                sendKeyState();
            }
        });

        // Keyup never reaches the canvas once it loses focus (alt-tab, clicking elsewhere)
        canvas.addEventListener('blur', releaseAllKeys);

        // Suppress the context menu so right-clicks reach the game and their mouseup is not lost
        canvas.addEventListener('contextmenu', (event) => event.preventDefault());

        canvas.addEventListener('mousedown', (event) => {
            if (ws && ws.readyState === WebSocket.OPEN) {
                ws.send(JSON.stringify({
                    type: 'mouse',
                    button: event.button,
                    action: 'down',
                    x: event.offsetX,
                    y: event.offsetY
                }));
            }
        });

        canvas.addEventListener('mouseup', (event) => {
            if (ws && ws.readyState === WebSocket.OPEN) {
                ws.send(JSON.stringify({
                    type: 'mouse',
                    button: event.button,
                    action: 'up',
                    x: event.offsetX,
                    y: event.offsetY
                }));
            }
        });

        // Initialize
        connectWebSocket();
        canvas.focus();
    </script>
</body>
</html>
"""
    return html_content
async def websocket_endpoint(websocket: WebSocket):
    """Handle one browser's WebSocket connection for real-time game control.

    Each text frame is expected to be a JSON object with a ``type`` field:

    * ``keys``  -- ``{"keys": [...]}`` replaces the engine's pressed-key set.
    * ``reset`` -- asks the engine to reset the environment.
    * ``start`` / ``pause`` -- starts or pauses the game.
    * ``mouse`` -- ``{"button", "action": "down"|"up", "x", "y"}`` updates the
      click flags and the cursor offset from the canvas centre.

    Malformed frames are logged and skipped instead of tearing down the
    connection. When the last client disconnects, held keys and mouse buttons
    are released so the engine does not keep acting on stale input.
    """
    await websocket.accept()
    connected_clients.add(websocket)
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                message = json.loads(raw)
            except json.JSONDecodeError:
                logger.warning("Ignoring non-JSON WebSocket message: %.200r", raw)
                continue
            if not isinstance(message, dict):
                logger.warning("Ignoring non-object WebSocket message: %.200r", raw)
                continue
            _apply_client_message(game_engine, message)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
    finally:
        connected_clients.discard(websocket)
        if not connected_clients:
            _release_all_inputs(game_engine)


# Centre of the 600x150 game canvas served by the homepage; mouse coordinates
# arrive relative to the canvas's top-left corner and are stored as offsets
# from this point.
_CANVAS_CENTER_X = 600 // 2
_CANVAS_CENTER_Y = 150 // 2


def _apply_client_message(engine, message: dict) -> None:
    """Apply one decoded client message to *engine*; unknown types are ignored."""
    msg_type = message.get('type')
    if msg_type == 'keys':
        keys = message.get('keys')
        if isinstance(keys, list):
            # Keep only strings so a bad payload cannot put unhashable values in the set.
            engine.pressed_keys = {key for key in keys if isinstance(key, str)}
    elif msg_type == 'reset':
        engine.request_reset()
    elif msg_type == 'start':
        engine.start_game()
    elif msg_type == 'pause':
        engine.pause_game()
    elif msg_type == 'mouse':
        _apply_mouse_message(engine, message)


def _apply_mouse_message(engine, message: dict) -> None:
    """Update click flags and the cursor offset from a ``mouse`` message."""
    action = message.get('action')
    if action in ('down', 'up'):
        is_down = action == 'down'
        button = message.get('button')
        if button == 0:  # left button
            engine.l_click = is_down
        elif button == 2:  # right button
            engine.r_click = is_down
    x = message.get('x', 0)
    y = message.get('y', 0)
    if isinstance(x, (int, float)) and isinstance(y, (int, float)):
        engine.mouse_x = x - _CANVAS_CENTER_X
        engine.mouse_y = y - _CANVAS_CENTER_Y


def _release_all_inputs(engine) -> None:
    """Clear held keys and mouse buttons (used once no client remains connected)."""
    engine.pressed_keys = set()
    engine.l_click = False
    engine.r_click = False
if __name__ == "__main__":
    # Local development server; reload=True restarts it whenever the source changes.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )