| """ | |
| Web-based Diamond CSGO AI Player for Hugging Face Spaces | |
| Uses FastAPI + WebSocket for real-time keyboard input and game streaming | |
| """ | |
| import asyncio | |
| import base64 | |
| import io | |
| import json | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Set | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from hydra import compose, initialize | |
| from hydra.utils import instantiate | |
| from omegaconf import DictConfig, OmegaConf | |
| from PIL import Image | |
| # Import your modules | |
| from src.agent import Agent | |
| from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names | |
| from src.envs import WorldModelEnv | |
| from src.game.web_play_env import WebPlayEnv | |
| from config_web import web_config | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Global variables | |
| app = FastAPI(title="Diamond CSGO AI Player") | |
| connected_clients: Set[WebSocket] = set() | |
class WebKeyMap:
    """Map web key codes to pygame-like keys for CSGO actions"""
    WEB_TO_CSGO = {
        'KeyW': 'w',
        'KeyA': 'a',
        'KeyS': 's',
        'KeyD': 'd',
        'Space': 'space',
        'ControlLeft': 'left ctrl',
        'ShiftLeft': 'left shift',
        'Digit1': '1',
        'Digit2': '2',
        'Digit3': '3',
        'KeyR': 'r',
        'ArrowUp': 'camera_up',
        'ArrowDown': 'camera_down',
        'ArrowLeft': 'camera_left',
        'ArrowRight': 'camera_right'
    }
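
# Illustrative helper (an assumption for clarity, not used elsewhere in this file):
# translate a set of browser KeyboardEvent.code values into the CSGO key names
# defined above, silently dropping unmapped codes.
def web_codes_to_csgo_keys(codes: Set[str]) -> List[str]:
    return [WebKeyMap.WEB_TO_CSGO[c] for c in codes if c in WebKeyMap.WEB_TO_CSGO]
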
class WebGameEngine:
    """Web-compatible game engine that replaces pygame functionality"""

    def __init__(self):
        self.play_env: Optional[WebPlayEnv] = None
        self.obs = None
        self.running = False
        self.game_started = False
        self.fps = 30      # Display FPS
        self.ai_fps = 10   # AI inference FPS (slower than display for efficiency)
        self.frame_count = 0
        self.ai_frame_count = 0
        self.last_ai_time = 0
        self.start_time = 0  # Track when the AI started, for proper FPS calculation
        self.pressed_keys: Set[str] = set()
        self.mouse_x = 0
        self.mouse_y = 0
        self.l_click = False
        self.r_click = False
        self.should_reset = False
        self.cached_obs = None             # Cache the last observation for frame skipping
        self.first_inference_done = False  # Track whether the first inference has completed
        self.models_ready = False          # Track whether the models are loaded
        self.download_progress = 0         # Download progress (0-100)
        self.loading_status = "Initializing..."  # Loading status message
        import time
        self.time_module = time
| async def _download_model_async(self, url, filepath): | |
| """Download model asynchronously with progress tracking""" | |
| import asyncio | |
| import concurrent.futures | |
| import urllib.request | |
| import os | |
| def download_with_progress(): | |
| """Download function that runs in thread pool""" | |
| def progress_hook(block_num, block_size, total_size): | |
| if total_size > 0: | |
| progress = min(100, (block_num * block_size * 100) / total_size) | |
| self.download_progress = int(progress) | |
| if progress % 10 == 0: # Log every 10% | |
| logger.info(f"Download progress: {self.download_progress}%") | |
| urllib.request.urlretrieve(url, filepath, reporthook=progress_hook) | |
| self.download_progress = 100 | |
| # Run download in thread pool to avoid blocking | |
| loop = asyncio.get_event_loop() | |
| with concurrent.futures.ThreadPoolExecutor() as executor: | |
| await loop.run_in_executor(executor, download_with_progress) | |
| logger.info("Model download completed!") | |
    async def initialize_models(self):
        """Initialize the AI models and environment"""
        try:
            import torch
            logger.info("Initializing models...")

            # Set up environment and paths
            web_config.setup_environment_variables()
            web_config.create_default_configs()
            config_path = web_config.get_config_path()
            logger.info(f"Using config path: {config_path}")

            # Convert to a relative path for Hydra
            import os
            relative_config_path = os.path.relpath(config_path)
            logger.info(f"Relative config path: {relative_config_path}")

            with initialize(version_base="1.3", config_path=relative_config_path):
                cfg = compose(config_name="trainer")

            # Override config for deployment
            cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
            cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")

            # Use CPU if no GPU is available (for free HF Spaces)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {device}")

            # Load model checkpoint
            checkpoint_path = web_config.get_checkpoint_path()
            if not checkpoint_path.exists():
                logger.warning(f"No checkpoint found at {checkpoint_path} - using dummy mode")
                self._init_dummy_mode()
                self.models_ready = True
                self.loading_status = "Using dummy mode"
                return True

            # Get spawn directory
            spawn_dir = web_config.get_spawn_dir()

            # Initialize agent
            num_actions = cfg.env.num_actions
            agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()

            # Try to load a checkpoint (remote or local)
            try:
                # First try to download from the Hugging Face Hub using a direct URL
                try:
                    import torch.hub
                    import os
                    logger.info("Downloading model from Hugging Face Hub...")
                    # Direct download URL (use 'resolve' instead of 'blob' for a direct download)
                    model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"

                    # Download to the cache directory
                    cache_dir = "./cache"
                    os.makedirs(cache_dir, exist_ok=True)
                    model_cache_path = os.path.join(cache_dir, "agent_epoch_00003.pt")

                    # Download if not cached
                    if not os.path.exists(model_cache_path):
                        logger.info(f"Downloading 1.53GB model to {model_cache_path}...")
                        self.loading_status = "Downloading AI model from Hugging Face Hub..."
                        # Download with progress tracking in a separate thread
                        await self._download_model_async(model_url, model_cache_path)
                    else:
                        logger.info(f"Using cached model from {model_cache_path}")
                        self.loading_status = "Loading cached model..."

                    # Use the agent's load method, which expects a file path
                    self.loading_status = "Loading model weights..."
                    agent.load(model_cache_path)
                    logger.info("Successfully loaded checkpoint from HF Hub")
                except Exception as hub_error:
                    logger.warning(f"Failed to download from HF Hub: {hub_error}")
                    # Fall back to a local checkpoint if available
                    if checkpoint_path.exists():
                        logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
                        agent.load(checkpoint_path)
                        logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
                    else:
                        raise FileNotFoundError("No model checkpoint available (local or remote)")
            except Exception as e:
                logger.error(f"Failed to load any checkpoint: {e}")
                self._init_dummy_mode()
                self.models_ready = True
                self.loading_status = "Using dummy mode"
                return True

            # Initialize the world model environment
            try:
                sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
                if agent.upsampler is not None:
                    sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
                wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
                wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
                                       spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)

                # Create the play environment
                self.play_env = WebPlayEnv(agent, wm_env, False, False, False)

                # Model compilation causes a 10-30s delay on the first inference, so it is optional.
                # Enable it by setting the ENABLE_TORCH_COMPILE=1 environment variable.
                import os
                if device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "0") == "1":
                    logger.info("Compiling models for faster inference (will cause a delay on first inference)...")
                    try:
                        wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
                        if wm_env.upsample_next_obs is not None:
                            wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
                        logger.info("Model compilation enabled successfully!")
                    except Exception as e:
                        logger.warning(f"Model compilation failed: {e}")
                else:
                    logger.info("Model compilation disabled (faster startup). Set ENABLE_TORCH_COMPILE=1 to enable.")

                # Reset the environment
                self.obs, _ = self.play_env.reset()
                self.cached_obs = self.obs  # Initialize the cache
                logger.info("Models initialized successfully!")
                logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
                self.models_ready = True
                self.loading_status = "Ready!"
                return True
            except Exception as e:
                logger.error(f"Failed to initialize world model environment: {e}")
                self._init_dummy_mode()
                self.models_ready = True
                self.loading_status = "Using dummy mode"
                return True

        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            import traceback
            traceback.print_exc()
            self._init_dummy_mode()
            self.models_ready = True
            self.loading_status = "Error - using dummy mode"
            return True
    def _init_dummy_mode(self):
        """Initialize dummy mode for testing without models"""
        logger.info("Initializing dummy mode...")
        # Create a test observation
        height, width = 150, 600
        img_array = np.zeros((height, width, 3), dtype=np.uint8)
        # Add a test pattern
        for y in range(height):
            for x in range(width):
                img_array[y, x, 0] = (x % 256)        # Red gradient
                img_array[y, x, 1] = (y % 256)        # Green gradient
                img_array[y, x, 2] = ((x + y) % 256)  # Blue pattern
        # Convert to a torch tensor in the expected [-1, 1] format
        tensor = torch.from_numpy(img_array).float().permute(2, 0, 1)  # CHW format
        tensor = tensor.div(255).mul(2).sub(1)  # Convert to the [-1, 1] range
        tensor = tensor.unsqueeze(0)  # Add a batch dimension
        self.obs = tensor
        self.play_env = None  # No real environment in dummy mode
        logger.info("Dummy mode initialized with test pattern")
    def step_environment(self):
        """Step the environment with the current input state (with intelligent frame skipping)"""
        if self.play_env is None:
            # Dummy mode - just return the current observation
            return self.obs, 0.0, False, False, {"mode": "dummy"}

        try:
            # Check whether a reset was requested
            if self.should_reset:
                self.reset_environment()
                self.should_reset = False
                self.last_ai_time = self.time_module.time()  # Reset the AI timer
                return self.obs, 0.0, False, False, {"reset": True}

            # Intelligent frame skipping: only run AI inference at the target FPS
            current_time = self.time_module.time()
            time_since_last_ai = current_time - self.last_ai_time
            should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)

            if should_run_ai:
                # Log a notice for the first inference (it can be slow)
                if not self.first_inference_done:
                    logger.info("Running first AI inference (may take 5-15 seconds)...")

                # Run AI inference
                inference_start = self.time_module.time()
                next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
                    pressed_keys=self.pressed_keys,
                    mouse_x=self.mouse_x,
                    mouse_y=self.mouse_y,
                    l_click=self.l_click,
                    r_click=self.r_click
                )
                inference_time = self.time_module.time() - inference_start

                # Log first inference completion
                if not self.first_inference_done:
                    self.first_inference_done = True
                    logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")

                # Cache the new observation and update timing
                self.cached_obs = next_obs
                self.last_ai_time = current_time
                self.ai_frame_count += 1

                # Add AI performance info
                info = info or {}
                info["ai_inference"] = True
                # Proper AI FPS: frames / elapsed time since start
                elapsed_time = current_time - self.start_time
                if elapsed_time > 0 and self.ai_frame_count > 0:
                    ai_fps = self.ai_frame_count / elapsed_time
                    # Cap at a reasonable maximum (AI inference shouldn't exceed 100 FPS)
                    info["ai_fps"] = min(ai_fps, 100.0)
                else:
                    info["ai_fps"] = 0
                info["inference_time"] = inference_time
                return next_obs, reward, done, truncated, info
            else:
                # Use the cached observation for a smoother display without AI overhead
                obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
                # Calculate AI FPS for cached frames too
                elapsed_time = current_time - self.start_time
                if elapsed_time > 0 and self.ai_frame_count > 0:
                    ai_fps = min(self.ai_frame_count / elapsed_time, 100.0)  # Cap at 100 FPS
                else:
                    ai_fps = 0
                return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}

        except Exception as e:
            logger.error(f"Error stepping environment: {e}")
            obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
            return obs_to_return, 0.0, False, False, {"error": str(e)}
    def reset_environment(self):
        """Reset the environment"""
        try:
            if self.play_env is not None:
                self.obs, _ = self.play_env.reset()
                self.cached_obs = self.obs  # Update the cache
                logger.info("Environment reset successfully")
            else:
                # Dummy mode - recreate the test pattern
                self._init_dummy_mode()
                self.cached_obs = self.obs  # Update the cache
                logger.info("Dummy environment reset")
        except Exception as e:
            logger.error(f"Error resetting environment: {e}")

    def request_reset(self):
        """Request an environment reset on the next step"""
        self.should_reset = True
        logger.info("Environment reset requested")

    def start_game(self):
        """Start the game"""
        self.game_started = True
        self.start_time = self.time_module.time()  # Reset the start time for FPS calculation
        self.ai_frame_count = 0  # Reset the AI frame count
        logger.info("Game started")

    def pause_game(self):
        """Pause/stop the game"""
        self.game_started = False
        logger.info("Game paused")
    def obs_to_base64(self, obs: torch.Tensor) -> str:
        """Convert an observation tensor to a base64 image for web display"""
        if obs is None:
            return ""
        try:
            # Convert the tensor to a PIL Image
            if obs.ndim == 4 and obs.size(0) == 1:
                img_array = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
            else:
                img_array = obs.add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
            img = Image.fromarray(img_array)
            # Resize for web display to match the canvas size
            img = img.resize((600, 150), Image.NEAREST)  # NEAREST is faster than BICUBIC
            # Encode as JPEG for better compression/speed than PNG
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=85, optimize=True)
            img_str = base64.b64encode(buffer.getvalue()).decode()
            return f"data:image/jpeg;base64,{img_str}"
        except Exception as e:
            logger.error(f"Error converting observation to base64: {e}")
            return ""
    async def game_loop(self):
        """Main game loop that runs continuously"""
        self.running = True
        while self.running:
            try:
                # Check whether the models are ready
                if not self.models_ready:
                    # Send the loading status to clients
                    if connected_clients:
                        loading_data = {
                            'type': 'loading',
                            'status': self.loading_status,
                            'progress': self.download_progress,
                            'ready': False
                        }
                        disconnected = set()
                        for client in connected_clients.copy():
                            try:
                                await client.send_text(json.dumps(loading_data))
                            except Exception:
                                disconnected.add(client)
                        connected_clients.difference_update(disconnected)
                    await asyncio.sleep(0.5)  # Check every 500ms during loading
                    continue

                # Always send frames, but only step the environment if the game is started
                should_send_frame = True
                if not self.game_started:
                    # Game not started - just send the current observation without stepping
                    if self.obs is not None and connected_clients:
                        should_send_frame = True
                    else:
                        should_send_frame = False
                        await asyncio.sleep(0.1)
                else:
                    # Game is started - step the environment
                    if self.play_env is None:
                        await asyncio.sleep(0.1)
                        continue
                    # Step the environment with the current input state
                    next_obs, reward, done, truncated, info = self.step_environment()
                    if done or truncated:
                        # Auto-reset when the episode ends
                        self.reset_environment()
                    else:
                        self.obs = next_obs

                # Send the frame to all connected clients (regardless of game state)
                if should_send_frame and connected_clients and self.obs is not None:
                    # Default values for when the game isn't running
                    if not self.game_started:
                        reward = 0.0
                        info = {"waiting": True}
                    # If the game is started, reward and info were set above

                    # Convert the observation to base64
                    image_data = self.obs_to_base64(self.obs)

                    # Debug logging for the first few frames
                    if self.frame_count < 5:
                        logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
                                    f"image_data_length={len(image_data) if image_data else 0}, "
                                    f"game_started={self.game_started}")

                    frame_data = {
                        'type': 'frame',
                        'image': image_data,
                        'frame_count': self.frame_count,
                        'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
                        'info': str(info) if info else "",
                        'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
                        'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
                    }

                    # Send to all connected clients
                    disconnected = set()
                    for client in connected_clients.copy():
                        try:
                            await client.send_text(json.dumps(frame_data))
                        except Exception:
                            disconnected.add(client)
                    # Remove disconnected clients
                    connected_clients.difference_update(disconnected)
                    self.frame_count += 1

                await asyncio.sleep(1.0 / self.fps)  # Control the FPS
            except Exception as e:
                logger.error(f"Error in game loop: {e}")
                await asyncio.sleep(0.1)

# Global game engine instance
game_engine = WebGameEngine()


@app.on_event("startup")
async def startup_event():
    """Initialize models when the app starts"""
    # Start the game loop immediately (it will handle the loading state)
    asyncio.create_task(game_engine.game_loop())
    # Initialize models in the background (non-blocking)
    asyncio.create_task(game_engine.initialize_models())

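
# NOTE (illustrative): newer FastAPI releases prefer a lifespan context manager
# over @app.on_event("startup"); an equivalent sketch would be:
#     from contextlib import asynccontextmanager
#     @asynccontextmanager
#     async def lifespan(app):
#         asyncio.create_task(game_engine.game_loop())
#         asyncio.create_task(game_engine.initialize_models())
#         yield
#     app = FastAPI(title="Diamond CSGO AI Player", lifespan=lifespan)
# The on_event hook above is kept as-is for this app.
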
@app.get("/", response_class=HTMLResponse)
async def get_homepage():
    """Serve the main game interface"""
    html_content = """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Diamond CSGO AI Player</title>
        <style>
            body {
                margin: 0;
                padding: 20px;
                background: #1a1a1a;
                color: white;
                font-family: 'Courier New', monospace;
                text-align: center;
            }
            #gameCanvas {
                border: 2px solid #00ff00;
                background: #000;
                margin: 20px auto;
                display: block;
            }
            #controls {
                display: grid;
                grid-template-columns: 1fr 1fr;
                gap: 20px;
                max-width: 800px;
                margin: 20px auto;
            }
            .control-section {
                background: #2a2a2a;
                padding: 15px;
                border-radius: 8px;
                border: 1px solid #444;
            }
            .key-display {
                background: #333;
                border: 1px solid #555;
                padding: 5px 10px;
                margin: 2px;
                border-radius: 4px;
                display: inline-block;
                min-width: 30px;
            }
            .key-pressed {
                background: #00ff00;
                color: #000;
            }
            #status {
                margin: 10px;
                padding: 10px;
                background: #2a2a2a;
                border-radius: 4px;
            }
            .info {
                color: #00ff00;
                margin: 5px 0;
            }
        </style>
    </head>
    <body>
        <h1>🎮 Diamond CSGO AI Player</h1>
        <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset the environment.</p>
        <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>

        <!-- Model download progress -->
        <div id="downloadSection" style="display: none; margin: 20px;">
            <p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
            <div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
                <div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
            </div>
            <p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
        </div>

        <canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>

        <div id="status">
            <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
            <div class="info">Game: <span id="gameStatus">Click to Start</span></div>
            <div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
            <div class="info">Reward: <span id="reward">0</span></div>
        </div>

        <div id="controls">
            <div class="control-section">
                <h3>Movement</h3>
                <div>
                    <span class="key-display" id="key-w">W</span> Forward<br>
                    <span class="key-display" id="key-a">A</span> Left
                    <span class="key-display" id="key-s">S</span> Back
                    <span class="key-display" id="key-d">D</span> Right<br>
                    <span class="key-display" id="key-space">Space</span> Jump
                    <span class="key-display" id="key-ctrl">Ctrl</span> Crouch
                    <span class="key-display" id="key-shift">Shift</span> Walk
                </div>
            </div>
            <div class="control-section">
                <h3>Actions</h3>
                <div>
                    <span class="key-display" id="key-1">1</span> Weapon 1<br>
                    <span class="key-display" id="key-2">2</span> Weapon 2
                    <span class="key-display" id="key-3">3</span> Weapon 3<br>
                    <span class="key-display" id="key-r">R</span> Reload<br>
                    <span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
                    <span class="key-display" id="key-enter">Enter</span> Reset Game<br>
                    <span class="key-display" id="key-esc">Esc</span> Pause/Quit
                </div>
            </div>
        </div>

        <script>
            const canvas = document.getElementById('gameCanvas');
            const ctx = canvas.getContext('2d');
            const statusEl = document.getElementById('connectionStatus');
            const gameStatusEl = document.getElementById('gameStatus');
            const frameEl = document.getElementById('frameCount');
            const aiFpsEl = document.getElementById('aiFps');
            const rewardEl = document.getElementById('reward');
            const loadingEl = document.getElementById('loadingIndicator');
            const downloadSectionEl = document.getElementById('downloadSection');
            const downloadStatusEl = document.getElementById('downloadStatus');
            const progressBarEl = document.getElementById('progressBar');
            const progressTextEl = document.getElementById('progressText');

            let ws = null;
            let pressedKeys = new Set();
            let gameStarted = false;

            // Key mapping
            const keyDisplayMap = {
                'KeyW': 'key-w',
                'KeyA': 'key-a',
                'KeyS': 'key-s',
                'KeyD': 'key-d',
                'Space': 'key-space',
                'ControlLeft': 'key-ctrl',
                'ShiftLeft': 'key-shift',
                'Digit1': 'key-1',
                'Digit2': 'key-2',
                'Digit3': 'key-3',
                'KeyR': 'key-r',
                'ArrowUp': 'key-arrows',
                'ArrowDown': 'key-arrows',
                'ArrowLeft': 'key-arrows',
                'ArrowRight': 'key-arrows',
                'Enter': 'key-enter',
                'Escape': 'key-esc'
            };

            function connectWebSocket() {
                const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
                const wsUrl = `${protocol}//${window.location.host}/ws`;
                ws = new WebSocket(wsUrl);

                ws.onopen = function(event) {
                    statusEl.textContent = 'Connected';
                    statusEl.style.color = '#00ff00';
                };

                ws.onmessage = function(event) {
                    const data = JSON.parse(event.data);
                    if (data.type === 'loading') {
                        // Handle loading status
                        downloadSectionEl.style.display = 'block';
                        downloadStatusEl.textContent = data.status;
                        if (data.progress !== undefined) {
                            progressBarEl.style.width = data.progress + '%';
                            progressTextEl.textContent = data.progress + '% - ' + data.status;
                        } else {
                            progressTextEl.textContent = data.status;
                        }
                        gameStatusEl.textContent = 'Loading Models...';
                        gameStatusEl.style.color = '#ffaa00';
                    } else if (data.type === 'frame') {
                        // Hide loading indicators once frames arrive
                        downloadSectionEl.style.display = 'none';
                        // Update the frame display
                        if (data.image) {
                            const img = new Image();
                            img.onload = function() {
                                ctx.clearRect(0, 0, canvas.width, canvas.height);
                                ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
                            };
                            img.src = data.image;
                        }
                        frameEl.textContent = data.frame_count;
                        rewardEl.textContent = data.reward.toFixed(2);
                        // Update the AI FPS display and hide the loading indicator once AI starts
                        if (data.ai_fps !== undefined && data.ai_fps !== null) {
                            // Ensure the FPS value is reasonable
                            const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
                            aiFpsEl.textContent = aiFps.toFixed(1);
                            // Color-code AI FPS as a performance indicator
                            if (aiFps >= 8) {
                                aiFpsEl.style.color = '#00ff00'; // Green: good performance
                            } else if (aiFps >= 5) {
                                aiFpsEl.style.color = '#ffff00'; // Yellow: moderate performance
                            } else if (aiFps > 0) {
                                aiFpsEl.style.color = '#ff0000'; // Red: poor performance
                            } else {
                                aiFpsEl.style.color = '#888888'; // Gray: inactive
                            }
                            // Hide the loading indicator once AI inference starts working
                            if (aiFps > 0 && gameStarted) {
                                loadingEl.style.display = 'none';
                                gameStatusEl.textContent = 'Playing';
                                gameStatusEl.style.color = '#00ff00';
                            }
                        }
                    }
                };

                ws.onclose = function(event) {
                    statusEl.textContent = 'Disconnected';
                    statusEl.style.color = '#ff0000';
                    setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
                };

                ws.onerror = function(event) {
                    statusEl.textContent = 'Error';
                    statusEl.style.color = '#ff0000';
                };
            }

            function sendKeyState() {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    ws.send(JSON.stringify({
                        type: 'keys',
                        keys: Array.from(pressedKeys)
                    }));
                }
            }

            function startGame() {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    ws.send(JSON.stringify({
                        type: 'start'
                    }));
                    gameStarted = true;
                    gameStatusEl.textContent = 'Starting AI...';
                    gameStatusEl.style.color = '#ffff00';
                    loadingEl.style.display = 'block';
                    console.log('Game started');
                }
            }

            function pauseGame() {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    ws.send(JSON.stringify({
                        type: 'pause'
                    }));
                    gameStarted = false;
                    gameStatusEl.textContent = 'Paused - Click to Resume';
                    gameStatusEl.style.color = '#ffff00';
                    console.log('Game paused');
                }
            }

            function updateKeyDisplay() {
                // Reset all key displays
                Object.values(keyDisplayMap).forEach(id => {
                    const el = document.getElementById(id);
                    if (el) el.classList.remove('key-pressed');
                });
                // Highlight pressed keys
                pressedKeys.forEach(key => {
                    const displayId = keyDisplayMap[key];
                    if (displayId) {
                        const el = document.getElementById(displayId);
                        if (el) el.classList.add('key-pressed');
                    }
                });
            }

            // Focus the canvas and handle keyboard events
            canvas.addEventListener('click', () => {
                canvas.focus();
                if (!gameStarted) {
                    startGame();
                }
            });

            canvas.addEventListener('keydown', (event) => {
                event.preventDefault();
                // Handle special keys
                if (event.code === 'Enter') {
                    if (ws && ws.readyState === WebSocket.OPEN) {
                        ws.send(JSON.stringify({
                            type: 'reset'
                        }));
                        console.log('Environment reset requested');
                    }
                    // Add to pressedKeys for visual feedback
                    pressedKeys.add(event.code);
                    updateKeyDisplay();
                    // Remove Enter from pressedKeys after a short delay
                    setTimeout(() => {
                        pressedKeys.delete(event.code);
                        updateKeyDisplay();
                    }, 200);
                } else if (event.code === 'Escape') {
                    pauseGame();
                    // Add to pressedKeys for visual feedback
                    pressedKeys.add(event.code);
                    updateKeyDisplay();
                    // Remove ESC from pressedKeys after a short delay
                    setTimeout(() => {
                        pressedKeys.delete(event.code);
                        updateKeyDisplay();
                    }, 200);
                } else {
                    // Only send game keys if the game is started
                    if (gameStarted) {
                        pressedKeys.add(event.code);
                        updateKeyDisplay();
                        sendKeyState();
                    }
                }
            });

            canvas.addEventListener('keyup', (event) => {
                event.preventDefault();
                // Special key releases are handled in keydown with a timeout
                if (event.code !== 'Enter' && event.code !== 'Escape') {
                    if (gameStarted) {
                        pressedKeys.delete(event.code);
                        updateKeyDisplay();
                        sendKeyState();
                    }
                }
            });

            // Handle mouse events for clicks
            canvas.addEventListener('mousedown', (event) => {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    ws.send(JSON.stringify({
                        type: 'mouse',
                        button: event.button,
                        action: 'down',
                        x: event.offsetX,
                        y: event.offsetY
                    }));
                }
            });

            canvas.addEventListener('mouseup', (event) => {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    ws.send(JSON.stringify({
                        type: 'mouse',
                        button: event.button,
                        action: 'up',
                        x: event.offsetX,
                        y: event.offsetY
                    }));
                }
            });

            // Initialize
            connectWebSocket();
            canvas.focus();
        </script>
    </body>
    </html>
    """
    return html_content

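
# NOTE (illustrative): StaticFiles is imported above but unused here; the same page
# could be served from disk instead of an inline string, e.g.:
#     app.mount("/static", StaticFiles(directory="static"), name="static")
# with the markup moved to static/index.html. The inline string above is what this
# app actually serves.
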
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Handle WebSocket connections for real-time game communication"""
    await websocket.accept()
    connected_clients.add(websocket)
    try:
        while True:
            # Receive messages from the client
            data = await websocket.receive_text()
            message = json.loads(data)

            if message['type'] == 'keys':
                # Update pressed keys
                game_engine.pressed_keys = set(message['keys'])
            elif message['type'] == 'reset':
                # Handle an environment reset request
                game_engine.request_reset()
            elif message['type'] == 'start':
                # Handle a game start request
                game_engine.start_game()
            elif message['type'] == 'pause':
                # Handle a game pause request
                game_engine.pause_game()
            elif message['type'] == 'mouse':
                # Handle mouse events
                if message['action'] == 'down':
                    if message['button'] == 0:    # Left click
                        game_engine.l_click = True
                    elif message['button'] == 2:  # Right click
                        game_engine.r_click = True
                elif message['action'] == 'up':
                    if message['button'] == 0:    # Left click
                        game_engine.l_click = False
                    elif message['button'] == 2:  # Right click
                        game_engine.r_click = False
                # Update the mouse position (relative to the canvas)
                game_engine.mouse_x = message.get('x', 0) - 300  # Center horizontally (canvas is 600px wide)
                game_engine.mouse_y = message.get('y', 0) - 150  # Offset by the canvas height (150px)
    except WebSocketDisconnect:
        connected_clients.discard(websocket)
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        connected_clients.discard(websocket)

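
# Illustrative smoke test for the /ws endpoint (assumes the `websockets` client
# library is installed; not part of the app itself):
#     import asyncio, json, websockets
#     async def probe():
#         async with websockets.connect("ws://localhost:7860/ws") as ws:
#             await ws.send(json.dumps({"type": "start"}))
#             print(json.loads(await ws.recv()))
#     asyncio.run(probe())
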
| if __name__ == "__main__": | |
| # For local development | |
| uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True) | |