musictimer committed on
Commit c64c726 · 1 Parent(s): 0f24197

Initial Diamond CSGO AI deployment

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. Dockerfile +38 -0
  2. README.md +76 -11
  3. app.py +969 -0
  4. config/agent/csgo.yaml +34 -0
  5. config/env/csgo.yaml +7 -0
  6. config/trainer.yaml +9 -0
  7. config/world_model_env/fast.yaml +17 -0
  8. config_web.py +208 -0
  9. csgo/spawn/0/act.npy +3 -0
  10. csgo/spawn/0/full_res.npy +3 -0
  11. csgo/spawn/0/info.json +1 -0
  12. csgo/spawn/0/low_res.npy +3 -0
  13. csgo/spawn/0/next_act.npy +3 -0
  14. packages.txt +3 -0
  15. requirements.txt +32 -0
  16. src/__init__.py +0 -0
  17. src/__pycache__/__init__.cpython-310.pyc +0 -0
  18. src/__pycache__/agent.cpython-310.pyc +0 -0
  19. src/__pycache__/trainer.cpython-310.pyc +0 -0
  20. src/__pycache__/utils.cpython-310.pyc +0 -0
  21. src/agent.py +74 -0
  22. src/coroutines/__init__.py +11 -0
  23. src/coroutines/__pycache__/__init__.cpython-310.pyc +0 -0
  24. src/coroutines/__pycache__/collector.cpython-310.pyc +0 -0
  25. src/coroutines/__pycache__/env_loop.cpython-310.pyc +0 -0
  26. src/coroutines/collector.py +126 -0
  27. src/coroutines/env_loop.py +74 -0
  28. src/csgo/__init__.py +0 -0
  29. src/csgo/__pycache__/__init__.cpython-310.pyc +0 -0
  30. src/csgo/__pycache__/action_processing.cpython-310.pyc +0 -0
  31. src/csgo/__pycache__/keymap.cpython-310.pyc +0 -0
  32. src/csgo/__pycache__/web_action_processing.cpython-310.pyc +0 -0
  33. src/csgo/action_processing.py +230 -0
  34. src/csgo/keymap.py +33 -0
  35. src/csgo/web_action_processing.py +167 -0
  36. src/data/__init__.py +6 -0
  37. src/data/__pycache__/__init__.cpython-310.pyc +0 -0
  38. src/data/__pycache__/batch.cpython-310.pyc +0 -0
  39. src/data/__pycache__/batch_sampler.cpython-310.pyc +0 -0
  40. src/data/__pycache__/dataset.cpython-310.pyc +0 -0
  41. src/data/__pycache__/episode.cpython-310.pyc +0 -0
  42. src/data/__pycache__/segment.cpython-310.pyc +0 -0
  43. src/data/__pycache__/utils.cpython-310.pyc +0 -0
  44. src/data/batch.py +25 -0
  45. src/data/batch_sampler.py +72 -0
  46. src/data/dataset.py +202 -0
  47. src/data/episode.py +64 -0
  48. src/data/segment.py +30 -0
  49. src/data/utils.py +89 -0
  50. src/envs/__init__.py +2 -0
Dockerfile ADDED
@@ -0,0 +1,38 @@
1
+ # Use Python 3.9 slim image
2
+ FROM python:3.9-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /app
6
+
7
+ # Install system dependencies
8
+ RUN apt-get update && apt-get install -y \
9
+ build-essential \
10
+ curl \
11
+ software-properties-common \
12
+ git \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Copy requirements first for better caching
16
+ COPY requirements.txt .
17
+
18
+ # Install Python dependencies
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Copy source code
22
+ COPY . .
23
+
24
+ # Create necessary directories
25
+ RUN mkdir -p csgo/spawn config checkpoints cache
26
+
27
+ # Set environment variables
28
+ ENV PYTHONPATH=/app/src:/app
29
+ ENV CUDA_VISIBLE_DEVICES=""
30
+
31
+ # Expose port
32
+ EXPOSE 7860
33
+
34
+ # Health check
35
+ HEALTHCHECK CMD curl --fail http://localhost:7860/ || exit 1
36
+
37
+ # Run the application
38
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
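The image exposes port 7860 and declares a curl-based liveness probe. As a reference, the sketch below does the same check from Python against a locally running container; the published port is an assumption (e.g. `docker run -p 7860:7860 ...`), not something stated in this commit.

```python
# Minimal sketch of the HEALTHCHECK probe, assuming the container is running
# locally with port 7860 published on the host.
import urllib.request

with urllib.request.urlopen("http://localhost:7860/", timeout=5) as resp:
    # The root route serves the game page, so any 200 response means the app is up.
    print("healthy" if resp.status == 200 else f"unexpected status {resp.status}")
```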
README.md CHANGED
@@ -1,11 +1,76 @@
1
- ---
2
- title: Diamond Ai Player
3
- emoji: 🏆
4
- colorFrom: yellow
5
- colorTo: indigo
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Diamond CSGO AI Player 🎮
2
+
3
+ A web-based demo of the Diamond AI agent playing Counter-Strike: Global Offensive using diffusion models and reinforcement learning.
4
+
5
+ ## Features
6
+
7
+ - **Real-time Keyboard Input**: Use standard WASD controls and other keys to interact
8
+ - **AI Agent**: Pre-trained agent using diffusion-based world models
9
+ - **Web Interface**: No installation required, play directly in your browser
10
+ - **Live Visualization**: See the AI's perspective and actions in real-time
11
+
12
+ ## Controls
13
+
14
+ ### Movement
15
+ - **W** - Move Forward
16
+ - **A** - Move Left
17
+ - **S** - Move Back
18
+ - **D** - Move Right
19
+ - **Space** - Jump
20
+ - **Ctrl** - Crouch
21
+ - **Shift** - Walk
22
+
23
+ ### Actions
24
+ - **1, 2, 3** - Switch Weapons
25
+ - **R** - Reload
26
+ - **Arrow Keys** - Camera Movement
27
+ - **Left/Right Click** - Primary/Secondary Fire
28
+
29
+ ### Game Controls
30
+ - **M** - Switch between Human/AI control
31
+ - **Enter** - Reset Environment
32
+
33
+ ## How to Play
34
+
35
+ 1. Click on the game canvas to focus it
36
+ 2. Use keyboard controls to play
37
+ 3. The AI agent will respond to your inputs in real-time
38
+ 4. Switch to AI mode to watch the agent play autonomously
39
+
40
+ ## Technical Details
41
+
42
+ This demo uses:
43
+ - **FastAPI + WebSocket** for real-time communication
44
+ - **PyTorch** for AI model inference
45
+ - **Diffusion Models** for next-frame prediction
46
+ - **World Model Environment** for simulation
47
+
48
+ The agent was trained using the Diamond framework, which combines:
49
+ - Diffusion-based world models
50
+ - Actor-critic reinforcement learning
51
+ - Multi-step planning and imagination
52
+
53
+ ## Model Information
54
+
55
+ The AI agent uses several neural networks:
56
+ - **Denoiser**: Diffusion model for generating next observations
57
+ - **Upsampler**: High-resolution image generation
58
+ - **Reward/End Model**: Predicting game outcomes
59
+ - **Actor-Critic**: Action selection and value estimation
60
+
61
+ ## Citation
62
+
63
+ This work is based on the Diamond framework. If you use this code, please cite:
64
+
65
+ ```bibtex
66
+ @article{diamond2024,
67
+ title={Diamond: Diffusion for World Modeling},
68
+ author={[Authors]},
69
+ journal={[Journal]},
70
+ year={2024}
71
+ }
72
+ ```
73
+
74
+ ## License
75
+
76
+ See LICENSE file for details.
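For readers who want to drive the demo without a browser, the WebSocket protocol described above is implemented by the `/ws` endpoint in `app.py` (added in this same commit). The sketch below is an assumed headless client: the URI is a placeholder, the `websockets` package comes from `requirements.txt`, and the message fields mirror the handlers in `app.py`.

```python
# Sketch of a headless client for the /ws endpoint defined in app.py.
import asyncio
import json

import websockets  # listed in requirements.txt


async def play(uri: str = "ws://localhost:7860/ws") -> None:
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"type": "start"}))                   # start the game loop
        await ws.send(json.dumps({"type": "keys", "keys": ["KeyW"]}))  # hold W

        for _ in range(10):                                            # read a few frames
            msg = json.loads(await ws.recv())
            if msg.get("type") == "frame":
                print(f"frame {msg['frame_count']}  reward {msg['reward']:.2f}")
            elif msg.get("type") == "loading":
                print(f"loading: {msg['status']} ({msg.get('progress', 0)}%)")

        await ws.send(json.dumps({"type": "keys", "keys": []}))        # release keys
        await ws.send(json.dumps({"type": "reset"}))                   # reset the environment


if __name__ == "__main__":
    asyncio.run(play())
```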
app.py ADDED
@@ -0,0 +1,969 @@
1
+ """
2
+ Web-based Diamond CSGO AI Player for Hugging Face Spaces
3
+ Uses FastAPI + WebSocket for real-time keyboard input and game streaming
4
+ """
5
+
6
+ import asyncio
7
+ import base64
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Dict, List, Optional, Set
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import torch
18
+ import uvicorn
19
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
20
+ from fastapi.responses import HTMLResponse
21
+ from fastapi.staticfiles import StaticFiles
22
+ from hydra import compose, initialize
23
+ from hydra.utils import instantiate
24
+ from omegaconf import DictConfig, OmegaConf
25
+ from PIL import Image
26
+
27
+ # Import your modules
28
+ from src.agent import Agent
29
+ from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
30
+ from src.envs import WorldModelEnv
31
+ from src.game.web_play_env import WebPlayEnv
32
+ from config_web import web_config
33
+
34
+ # Configure logging
35
+ logging.basicConfig(level=logging.INFO)
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Global variables
39
+ app = FastAPI(title="Diamond CSGO AI Player")
40
+ connected_clients: Set[WebSocket] = set()
41
+
42
+ class WebKeyMap:
43
+ """Map web key codes to pygame-like keys for CSGO actions"""
44
+ WEB_TO_CSGO = {
45
+ 'KeyW': 'w',
46
+ 'KeyA': 'a',
47
+ 'KeyS': 's',
48
+ 'KeyD': 'd',
49
+ 'Space': 'space',
50
+ 'ControlLeft': 'left ctrl',
51
+ 'ShiftLeft': 'left shift',
52
+ 'Digit1': '1',
53
+ 'Digit2': '2',
54
+ 'Digit3': '3',
55
+ 'KeyR': 'r',
56
+ 'ArrowUp': 'camera_up',
57
+ 'ArrowDown': 'camera_down',
58
+ 'ArrowLeft': 'camera_left',
59
+ 'ArrowRight': 'camera_right'
60
+ }
61
+
62
+ class WebGameEngine:
63
+ """Web-compatible game engine that replaces pygame functionality"""
64
+
65
+ def __init__(self):
66
+ self.play_env: Optional[WebPlayEnv] = None
67
+ self.obs = None
68
+ self.running = False
69
+ self.game_started = False
70
+ self.fps = 30 # Display FPS
71
+ self.ai_fps = 10 # AI inference FPS (slower than display for efficiency)
72
+ self.frame_count = 0
73
+ self.ai_frame_count = 0
74
+ self.last_ai_time = 0
75
+ self.start_time = 0 # Track when AI started for proper FPS calculation
76
+ self.pressed_keys: Set[str] = set()
77
+ self.mouse_x = 0
78
+ self.mouse_y = 0
79
+ self.l_click = False
80
+ self.r_click = False
81
+ self.should_reset = False
82
+ self.cached_obs = None # Cache last observation for frame skipping
83
+ self.first_inference_done = False # Track if first inference completed
84
+ self.models_ready = False # Track if models are loaded
85
+ self.download_progress = 0 # Track download progress (0-100)
86
+ self.loading_status = "Initializing..." # Loading status message
87
+ import time
88
+ self.time_module = time
89
+
90
+ async def _download_model_async(self, url, filepath):
91
+ """Download model asynchronously with progress tracking"""
92
+ import asyncio
93
+ import concurrent.futures
94
+ import urllib.request
95
+ import os
96
+
97
+ def download_with_progress():
98
+ """Download function that runs in thread pool"""
99
+ def progress_hook(block_num, block_size, total_size):
100
+ if total_size > 0:
101
+ progress = min(100, (block_num * block_size * 100) / total_size)
102
+ self.download_progress = int(progress)
103
+ if progress % 10 == 0: # Log every 10%
104
+ logger.info(f"Download progress: {self.download_progress}%")
105
+
106
+ urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
107
+ self.download_progress = 100
108
+
109
+ # Run download in thread pool to avoid blocking
110
+ loop = asyncio.get_event_loop()
111
+ with concurrent.futures.ThreadPoolExecutor() as executor:
112
+ await loop.run_in_executor(executor, download_with_progress)
113
+
114
+ logger.info("Model download completed!")
115
+
116
+ async def initialize_models(self):
117
+ """Initialize the AI models and environment"""
118
+ try:
119
+ import torch
120
+ logger.info("Initializing models...")
121
+
122
+ # Setup environment and paths
123
+ web_config.setup_environment_variables()
124
+ web_config.create_default_configs()
125
+
126
+ config_path = web_config.get_config_path()
127
+ logger.info(f"Using config path: {config_path}")
128
+
129
+ # Convert to relative path for Hydra
130
+ import os
131
+ relative_config_path = os.path.relpath(config_path)
132
+ logger.info(f"Relative config path: {relative_config_path}")
133
+
134
+ with initialize(version_base="1.3", config_path=relative_config_path):
135
+ cfg = compose(config_name="trainer")
136
+
137
+ # Override config for deployment
138
+ cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
139
+ cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
140
+
141
+ # Use CPU if no GPU available (for free HF spaces)
142
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
+ logger.info(f"Using device: {device}")
144
+
145
+ # Load model checkpoint
146
+ checkpoint_path = web_config.get_checkpoint_path()
147
+ if not checkpoint_path.exists():
148
+ logger.warning(f"No checkpoint found at {checkpoint_path} - using dummy mode")
149
+ self._init_dummy_mode()
150
+ return True
151
+
152
+ # Get spawn directory
153
+ spawn_dir = web_config.get_spawn_dir()
154
+
155
+ # Initialize agent
156
+ num_actions = cfg.env.num_actions
157
+ agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
158
+
159
+ # Try to load checkpoint (remote or local)
160
+ try:
161
+ # First try to download from Hugging Face Hub using direct URL
162
+ try:
163
+ import torch.hub
164
+ import os
165
+ logger.info("Downloading model from Hugging Face Hub...")
166
+
167
+ # Direct download URL (change 'blob' to 'resolve' for direct download)
168
+ model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
169
+
170
+ # Download to cache directory
171
+ cache_dir = "./cache"
172
+ os.makedirs(cache_dir, exist_ok=True)
173
+ model_cache_path = os.path.join(cache_dir, "agent_epoch_00003.pt")
174
+
175
+ # Download if not cached
176
+ if not os.path.exists(model_cache_path):
177
+ logger.info(f"Downloading 1.53GB model to {model_cache_path}...")
178
+ self.loading_status = "Downloading AI model from Hugging Face Hub..."
179
+
180
+ # Download with progress tracking in a separate thread
181
+ await self._download_model_async(model_url, model_cache_path)
182
+ else:
183
+ logger.info(f"Using cached model from {model_cache_path}")
184
+ self.loading_status = "Loading cached model..."
185
+
186
+ # Use the agent's load method which expects a file path
187
+ self.loading_status = "Loading model weights..."
188
+ agent.load(model_cache_path)
189
+ logger.info(f"Successfully loaded checkpoint from HF Hub")
190
+
191
+ except Exception as hub_error:
192
+ logger.warning(f"Failed to download from HF Hub: {hub_error}")
193
+
194
+ # Fallback to local checkpoint if available
195
+ if checkpoint_path.exists():
196
+ logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
197
+ agent.load(checkpoint_path)
198
+ logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
199
+ else:
200
+ raise FileNotFoundError("No model checkpoint available (local or remote)")
201
+
202
+ except Exception as e:
203
+ logger.error(f"Failed to load any checkpoint: {e}")
204
+ self._init_dummy_mode()
205
+ return True
206
+
207
+ # Initialize world model environment
208
+ try:
209
+ sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
210
+ if agent.upsampler is not None:
211
+ sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
212
+ wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
213
+ wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
214
+ spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
215
+
216
+ # Create play environment
217
+ self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
218
+
219
+ # Model compilation causes 10-30s delay on first inference, so make it optional
220
+ # You can enable it by setting ENABLE_TORCH_COMPILE=1 environment variable
221
+ import os
222
+ if device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "0") == "1":
223
+ logger.info("Compiling models for faster inference (will cause delay on first inference)...")
224
+ try:
225
+ wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
226
+ if wm_env.upsample_next_obs is not None:
227
+ wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
228
+ logger.info("Model compilation enabled successfully!")
229
+ except Exception as e:
230
+ logger.warning(f"Model compilation failed: {e}")
231
+ else:
232
+ logger.info("Model compilation disabled (faster startup). Set ENABLE_TORCH_COMPILE=1 to enable.")
233
+
234
+ # Reset environment
235
+ self.obs, _ = self.play_env.reset()
236
+ self.cached_obs = self.obs # Initialize cache
237
+
238
+ logger.info("Models initialized successfully!")
239
+ logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
240
+ self.models_ready = True
241
+ self.loading_status = "Ready!"
242
+ return True
243
+
244
+ except Exception as e:
245
+ logger.error(f"Failed to initialize world model environment: {e}")
246
+ self._init_dummy_mode()
247
+ self.models_ready = True
248
+ self.loading_status = "Using dummy mode"
249
+ return True
250
+
251
+ except Exception as e:
252
+ logger.error(f"Failed to initialize models: {e}")
253
+ import traceback
254
+ traceback.print_exc()
255
+ self._init_dummy_mode()
256
+ self.models_ready = True
257
+ self.loading_status = "Error - using dummy mode"
258
+ return True
259
+
260
+ def _init_dummy_mode(self):
261
+ """Initialize dummy mode for testing without models"""
262
+ logger.info("Initializing dummy mode...")
263
+
264
+ # Create a test observation
265
+ height, width = 150, 600
266
+ img_array = np.zeros((height, width, 3), dtype=np.uint8)
267
+
268
+ # Add test pattern
269
+ for y in range(height):
270
+ for x in range(width):
271
+ img_array[y, x, 0] = (x % 256) # Red gradient
272
+ img_array[y, x, 1] = (y % 256) # Green gradient
273
+ img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
274
+
275
+ # Convert to torch tensor in expected format [-1, 1]
276
+ tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
277
+ tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
278
+ tensor = tensor.unsqueeze(0) # Add batch dimension
279
+
280
+ self.obs = tensor
281
+ self.play_env = None # No real environment in dummy mode
282
+ logger.info("Dummy mode initialized with test pattern")
283
+
284
+
285
+ def step_environment(self):
286
+ """Step the environment with current input state (with intelligent frame skipping)"""
287
+ if self.play_env is None:
288
+ # Dummy mode - just return current observation
289
+ return self.obs, 0.0, False, False, {"mode": "dummy"}
290
+
291
+ try:
292
+ # Check if reset is requested
293
+ if self.should_reset:
294
+ self.reset_environment()
295
+ self.should_reset = False
296
+ self.last_ai_time = self.time_module.time() # Reset AI timer
297
+ return self.obs, 0.0, False, False, {"reset": True}
298
+
299
+ # Intelligent frame skipping: only run AI inference at target FPS
300
+ current_time = self.time_module.time()
301
+ time_since_last_ai = current_time - self.last_ai_time
302
+ should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
303
+
304
+ if should_run_ai:
305
+ # Show loading indicator for first inference (can be slow)
306
+ if not self.first_inference_done:
307
+ logger.info("Running first AI inference (may take 5-15 seconds)...")
308
+
309
+ # Run AI inference
310
+ inference_start = self.time_module.time()
311
+ next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
312
+ pressed_keys=self.pressed_keys,
313
+ mouse_x=self.mouse_x,
314
+ mouse_y=self.mouse_y,
315
+ l_click=self.l_click,
316
+ r_click=self.r_click
317
+ )
318
+ inference_time = self.time_module.time() - inference_start
319
+
320
+ # Log first inference completion
321
+ if not self.first_inference_done:
322
+ self.first_inference_done = True
323
+ logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
324
+
325
+ # Cache the new observation and update timing
326
+ self.cached_obs = next_obs
327
+ self.last_ai_time = current_time
328
+ self.ai_frame_count += 1
329
+
330
+ # Add AI performance info
331
+ info = info or {}
332
+ info["ai_inference"] = True
333
+
334
+ # Calculate proper AI FPS: frames / elapsed time since start
335
+ elapsed_time = current_time - self.start_time
336
+ if elapsed_time > 0 and self.ai_frame_count > 0:
337
+ ai_fps = self.ai_frame_count / elapsed_time
338
+ # Cap at reasonable maximum (shouldn't exceed 100 FPS for AI inference)
339
+ info["ai_fps"] = min(ai_fps, 100.0)
340
+ else:
341
+ info["ai_fps"] = 0
342
+
343
+ info["inference_time"] = inference_time
344
+
345
+ return next_obs, reward, done, truncated, info
346
+ else:
347
+ # Use cached observation for smoother display without AI overhead
348
+ obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
349
+
350
+ # Calculate AI FPS for cached frames too
351
+ elapsed_time = current_time - self.start_time
352
+ if elapsed_time > 0 and self.ai_frame_count > 0:
353
+ ai_fps = min(self.ai_frame_count / elapsed_time, 100.0) # Cap at 100 FPS
354
+ else:
355
+ ai_fps = 0
356
+
357
+ return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}
358
+
359
+ except Exception as e:
360
+ logger.error(f"Error stepping environment: {e}")
361
+ obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
362
+ return obs_to_return, 0.0, False, False, {"error": str(e)}
363
+
364
+ def reset_environment(self):
365
+ """Reset the environment"""
366
+ try:
367
+ if self.play_env is not None:
368
+ self.obs, _ = self.play_env.reset()
369
+ self.cached_obs = self.obs # Update cache
370
+ logger.info("Environment reset successfully")
371
+ else:
372
+ # Dummy mode - recreate test pattern
373
+ self._init_dummy_mode()
374
+ self.cached_obs = self.obs # Update cache
375
+ logger.info("Dummy environment reset")
376
+ except Exception as e:
377
+ logger.error(f"Error resetting environment: {e}")
378
+
379
+ def request_reset(self):
380
+ """Request environment reset on next step"""
381
+ self.should_reset = True
382
+ logger.info("Environment reset requested")
383
+
384
+ def start_game(self):
385
+ """Start the game"""
386
+ self.game_started = True
387
+ self.start_time = self.time_module.time() # Reset start time for FPS calculation
388
+ self.ai_frame_count = 0 # Reset AI frame count
389
+ logger.info("Game started")
390
+
391
+ def pause_game(self):
392
+ """Pause/stop the game"""
393
+ self.game_started = False
394
+ logger.info("Game paused")
395
+
396
+ def obs_to_base64(self, obs: torch.Tensor) -> str:
397
+ """Convert observation tensor to base64 image for web display"""
398
+ if obs is None:
399
+ return ""
400
+
401
+ try:
402
+ # Convert tensor to PIL Image
403
+ if obs.ndim == 4 and obs.size(0) == 1:
404
+ img_array = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
405
+ else:
406
+ img_array = obs.add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
407
+
408
+ img = Image.fromarray(img_array)
409
+
410
+ # Resize for web display to match canvas size (optimized)
411
+ img = img.resize((600, 150), Image.NEAREST) # NEAREST is faster than BICUBIC
412
+
413
+ # Optimized base64 conversion with JPEG for better compression/speed
414
+ buffer = io.BytesIO()
415
+ img.save(buffer, format='JPEG', quality=85, optimize=True) # JPEG is faster than PNG
416
+ img_str = base64.b64encode(buffer.getvalue()).decode()
417
+ return f"data:image/jpeg;base64,{img_str}"
418
+
419
+ except Exception as e:
420
+ logger.error(f"Error converting observation to base64: {e}")
421
+ return ""
422
+
423
+ async def game_loop(self):
424
+ """Main game loop that runs continuously"""
425
+ self.running = True
426
+
427
+ while self.running:
428
+ try:
429
+ # Check if models are ready
430
+ if not self.models_ready:
431
+ # Send loading status to clients
432
+ if connected_clients:
433
+ loading_data = {
434
+ 'type': 'loading',
435
+ 'status': self.loading_status,
436
+ 'progress': self.download_progress,
437
+ 'ready': False
438
+ }
439
+ disconnected = set()
440
+ for client in connected_clients.copy():
441
+ try:
442
+ await client.send_text(json.dumps(loading_data))
443
+ except:
444
+ disconnected.add(client)
445
+ connected_clients.difference_update(disconnected)
446
+
447
+ await asyncio.sleep(0.5) # Check every 500ms during loading
448
+ continue
449
+
450
+ # Always send frames, but only step environment if game is started
451
+ should_send_frame = True
452
+
453
+ if not self.game_started:
454
+ # Game not started - just send current observation without stepping
455
+ if self.obs is not None and connected_clients:
456
+ should_send_frame = True
457
+ else:
458
+ should_send_frame = False
459
+ await asyncio.sleep(0.1)
460
+ else:
461
+ # Game is started - step environment
462
+ if self.play_env is None:
463
+ await asyncio.sleep(0.1)
464
+ continue
465
+
466
+ # Step environment with current input state
467
+ next_obs, reward, done, truncated, info = self.step_environment()
468
+
469
+ if done or truncated:
470
+ # Auto-reset when episode ends
471
+ self.reset_environment()
472
+ else:
473
+ self.obs = next_obs
474
+
475
+ # Send frame to all connected clients (regardless of game state)
476
+ if should_send_frame and connected_clients and self.obs is not None:
477
+ # Set default values for when game isn't running
478
+ if not self.game_started:
479
+ reward = 0.0
480
+ info = {"waiting": True}
481
+ # If game is started, reward and info should be set above
482
+
483
+ # Convert observation to base64
484
+ image_data = self.obs_to_base64(self.obs)
485
+
486
+ # Debug logging for first few frames
487
+ if self.frame_count < 5:
488
+ logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
489
+ f"image_data_length={len(image_data) if image_data else 0}, "
490
+ f"game_started={self.game_started}")
491
+
492
+ frame_data = {
493
+ 'type': 'frame',
494
+ 'image': image_data,
495
+ 'frame_count': self.frame_count,
496
+ 'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
497
+ 'info': str(info) if info else "",
498
+ 'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
499
+ 'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
500
+ }
501
+
502
+ # Send to all connected clients
503
+ disconnected = set()
504
+ for client in connected_clients.copy():
505
+ try:
506
+ await client.send_text(json.dumps(frame_data))
507
+ except:
508
+ disconnected.add(client)
509
+
510
+ # Remove disconnected clients
511
+ connected_clients.difference_update(disconnected)
512
+
513
+ self.frame_count += 1
514
+ await asyncio.sleep(1.0 / self.fps) # Control FPS
515
+
516
+ except Exception as e:
517
+ logger.error(f"Error in game loop: {e}")
518
+ await asyncio.sleep(0.1)
519
+
520
+ # Global game engine instance
521
+ game_engine = WebGameEngine()
522
+
523
+ @app.on_event("startup")
524
+ async def startup_event():
525
+ """Initialize models when the app starts"""
526
+ # Start the game loop immediately (it will handle loading state)
527
+ asyncio.create_task(game_engine.game_loop())
528
+
529
+ # Initialize models in background (non-blocking)
530
+ asyncio.create_task(game_engine.initialize_models())
531
+
532
+ @app.get("/", response_class=HTMLResponse)
533
+ async def get_homepage():
534
+ """Serve the main game interface"""
535
+ html_content = """
536
+ <!DOCTYPE html>
537
+ <html>
538
+ <head>
539
+ <title>Diamond CSGO AI Player</title>
540
+ <style>
541
+ body {
542
+ margin: 0;
543
+ padding: 20px;
544
+ background: #1a1a1a;
545
+ color: white;
546
+ font-family: 'Courier New', monospace;
547
+ text-align: center;
548
+ }
549
+ #gameCanvas {
550
+ border: 2px solid #00ff00;
551
+ background: #000;
552
+ margin: 20px auto;
553
+ display: block;
554
+ }
555
+ #controls {
556
+ margin: 20px;
557
+ display: grid;
558
+ grid-template-columns: 1fr 1fr;
559
+ gap: 20px;
560
+ max-width: 800px;
561
+ margin: 20px auto;
562
+ }
563
+ .control-section {
564
+ background: #2a2a2a;
565
+ padding: 15px;
566
+ border-radius: 8px;
567
+ border: 1px solid #444;
568
+ }
569
+ .key-display {
570
+ background: #333;
571
+ border: 1px solid #555;
572
+ padding: 5px 10px;
573
+ margin: 2px;
574
+ border-radius: 4px;
575
+ display: inline-block;
576
+ min-width: 30px;
577
+ }
578
+ .key-pressed {
579
+ background: #00ff00;
580
+ color: #000;
581
+ }
582
+ #status {
583
+ margin: 10px;
584
+ padding: 10px;
585
+ background: #2a2a2a;
586
+ border-radius: 4px;
587
+ }
588
+ .info {
589
+ color: #00ff00;
590
+ margin: 5px 0;
591
+ }
592
+ </style>
593
+ </head>
594
+ <body>
595
+ <h1>🎮 Diamond CSGO AI Player</h1>
596
+ <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
597
+ <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
598
+
599
+ <!-- Model Download Progress -->
600
+ <div id="downloadSection" style="display: none; margin: 20px;">
601
+ <p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
602
+ <div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
603
+ <div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
604
+ </div>
605
+ <p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
606
+ </div>
607
+
608
+ <canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>
609
+
610
+ <div id="status">
611
+ <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
612
+ <div class="info">Game: <span id="gameStatus">Click to Start</span></div>
613
+ <div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
614
+ <div class="info">Reward: <span id="reward">0</span></div>
615
+ </div>
616
+
617
+ <div id="controls">
618
+ <div class="control-section">
619
+ <h3>Movement</h3>
620
+ <div>
621
+ <span class="key-display" id="key-w">W</span> Forward<br>
622
+ <span class="key-display" id="key-a">A</span> Left
623
+ <span class="key-display" id="key-s">S</span> Back
624
+ <span class="key-display" id="key-d">D</span> Right<br>
625
+ <span class="key-display" id="key-space">Space</span> Jump
626
+ <span class="key-display" id="key-ctrl">Ctrl</span> Crouch
627
+ <span class="key-display" id="key-shift">Shift</span> Walk
628
+ </div>
629
+ </div>
630
+
631
+ <div class="control-section">
632
+ <h3>Actions</h3>
633
+ <div>
634
+ <span class="key-display" id="key-1">1</span> Weapon 1<br>
635
+ <span class="key-display" id="key-2">2</span> Weapon 2
636
+ <span class="key-display" id="key-3">3</span> Weapon 3<br>
637
+ <span class="key-display" id="key-r">R</span> Reload<br>
638
+ <span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
639
+ <span class="key-display" id="key-enter">Enter</span> Reset Game<br>
640
+ <span class="key-display" id="key-esc">Esc</span> Pause/Quit
641
+ </div>
642
+ </div>
643
+ </div>
644
+
645
+ <script>
646
+ const canvas = document.getElementById('gameCanvas');
647
+ const ctx = canvas.getContext('2d');
648
+ const statusEl = document.getElementById('connectionStatus');
649
+ const gameStatusEl = document.getElementById('gameStatus');
650
+ const frameEl = document.getElementById('frameCount');
651
+ const aiFpsEl = document.getElementById('aiFps');
652
+ const rewardEl = document.getElementById('reward');
653
+ const loadingEl = document.getElementById('loadingIndicator');
654
+ const downloadSectionEl = document.getElementById('downloadSection');
655
+ const downloadStatusEl = document.getElementById('downloadStatus');
656
+ const progressBarEl = document.getElementById('progressBar');
657
+ const progressTextEl = document.getElementById('progressText');
658
+
659
+ let ws = null;
660
+ let pressedKeys = new Set();
661
+ let gameStarted = false;
662
+
663
+ // Key mapping
664
+ const keyDisplayMap = {
665
+ 'KeyW': 'key-w',
666
+ 'KeyA': 'key-a',
667
+ 'KeyS': 'key-s',
668
+ 'KeyD': 'key-d',
669
+ 'Space': 'key-space',
670
+ 'ControlLeft': 'key-ctrl',
671
+ 'ShiftLeft': 'key-shift',
672
+ 'Digit1': 'key-1',
673
+ 'Digit2': 'key-2',
674
+ 'Digit3': 'key-3',
675
+ 'KeyR': 'key-r',
676
+ 'ArrowUp': 'key-arrows',
677
+ 'ArrowDown': 'key-arrows',
678
+ 'ArrowLeft': 'key-arrows',
679
+ 'ArrowRight': 'key-arrows',
680
+ 'Enter': 'key-enter',
681
+ 'Escape': 'key-esc'
682
+ };
683
+
684
+ function connectWebSocket() {
685
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
686
+ const wsUrl = `${protocol}//${window.location.host}/ws`;
687
+
688
+ ws = new WebSocket(wsUrl);
689
+
690
+ ws.onopen = function(event) {
691
+ statusEl.textContent = 'Connected';
692
+ statusEl.style.color = '#00ff00';
693
+ };
694
+
695
+ ws.onmessage = function(event) {
696
+ const data = JSON.parse(event.data);
697
+
698
+ if (data.type === 'loading') {
699
+ // Handle loading status
700
+ downloadSectionEl.style.display = 'block';
701
+ downloadStatusEl.textContent = data.status;
702
+
703
+ if (data.progress !== undefined) {
704
+ progressBarEl.style.width = data.progress + '%';
705
+ progressTextEl.textContent = data.progress + '% - ' + data.status;
706
+ } else {
707
+ progressTextEl.textContent = data.status;
708
+ }
709
+
710
+ gameStatusEl.textContent = 'Loading Models...';
711
+ gameStatusEl.style.color = '#ffaa00';
712
+
713
+ } else if (data.type === 'frame') {
714
+ // Hide loading indicators once we get frames
715
+ downloadSectionEl.style.display = 'none';
716
+ // Update frame display
717
+ if (data.image) {
718
+ const img = new Image();
719
+ img.onload = function() {
720
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
721
+ ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
722
+ };
723
+ img.src = data.image;
724
+ }
725
+
726
+ frameEl.textContent = data.frame_count;
727
+ rewardEl.textContent = data.reward.toFixed(2);
728
+
729
+ // Update AI FPS display and hide loading indicator once AI starts
730
+ if (data.ai_fps !== undefined && data.ai_fps !== null) {
731
+ // Ensure FPS value is reasonable
732
+ const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
733
+ aiFpsEl.textContent = aiFps.toFixed(1);
734
+
735
+ // Color code AI FPS for performance indication
736
+ if (aiFps >= 8) {
737
+ aiFpsEl.style.color = '#00ff00'; // Green for good performance
738
+ } else if (aiFps >= 5) {
739
+ aiFpsEl.style.color = '#ffff00'; // Yellow for moderate performance
740
+ } else if (aiFps > 0) {
741
+ aiFpsEl.style.color = '#ff0000'; // Red for poor performance
742
+ } else {
743
+ aiFpsEl.style.color = '#888888'; // Gray for inactive
744
+ }
745
+
746
+ // Hide loading indicator once AI inference starts working
747
+ if (aiFps > 0 && gameStarted) {
748
+ loadingEl.style.display = 'none';
749
+ gameStatusEl.textContent = 'Playing';
750
+ gameStatusEl.style.color = '#00ff00';
751
+ }
752
+ }
753
+ }
754
+ };
755
+
756
+ ws.onclose = function(event) {
757
+ statusEl.textContent = 'Disconnected';
758
+ statusEl.style.color = '#ff0000';
759
+ setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
760
+ };
761
+
762
+ ws.onerror = function(event) {
763
+ statusEl.textContent = 'Error';
764
+ statusEl.style.color = '#ff0000';
765
+ };
766
+ }
767
+
768
+ function sendKeyState() {
769
+ if (ws && ws.readyState === WebSocket.OPEN) {
770
+ ws.send(JSON.stringify({
771
+ type: 'keys',
772
+ keys: Array.from(pressedKeys)
773
+ }));
774
+ }
775
+ }
776
+
777
+ function startGame() {
778
+ if (ws && ws.readyState === WebSocket.OPEN) {
779
+ ws.send(JSON.stringify({
780
+ type: 'start'
781
+ }));
782
+ gameStarted = true;
783
+ gameStatusEl.textContent = 'Starting AI...';
784
+ gameStatusEl.style.color = '#ffff00';
785
+ loadingEl.style.display = 'block';
786
+ console.log('Game started');
787
+ }
788
+ }
789
+
790
+ function pauseGame() {
791
+ if (ws && ws.readyState === WebSocket.OPEN) {
792
+ ws.send(JSON.stringify({
793
+ type: 'pause'
794
+ }));
795
+ gameStarted = false;
796
+ gameStatusEl.textContent = 'Paused - Click to Resume';
797
+ gameStatusEl.style.color = '#ffff00';
798
+ console.log('Game paused');
799
+ }
800
+ }
801
+
802
+ function updateKeyDisplay() {
803
+ // Reset all key displays
804
+ Object.values(keyDisplayMap).forEach(id => {
805
+ const el = document.getElementById(id);
806
+ if (el) el.classList.remove('key-pressed');
807
+ });
808
+
809
+ // Highlight pressed keys
810
+ pressedKeys.forEach(key => {
811
+ const displayId = keyDisplayMap[key];
812
+ if (displayId) {
813
+ const el = document.getElementById(displayId);
814
+ if (el) el.classList.add('key-pressed');
815
+ }
816
+ });
817
+ }
818
+
819
+ // Focus canvas and handle keyboard events
820
+ canvas.addEventListener('click', () => {
821
+ canvas.focus();
822
+ if (!gameStarted) {
823
+ startGame();
824
+ }
825
+ });
826
+
827
+ canvas.addEventListener('keydown', (event) => {
828
+ event.preventDefault();
829
+
830
+ // Handle special keys
831
+ if (event.code === 'Enter') {
832
+ if (ws && ws.readyState === WebSocket.OPEN) {
833
+ ws.send(JSON.stringify({
834
+ type: 'reset'
835
+ }));
836
+ console.log('Environment reset requested');
837
+ }
838
+ // Add to pressedKeys for visual feedback
839
+ pressedKeys.add(event.code);
840
+ updateKeyDisplay();
841
+
842
+ // Remove Enter from pressedKeys after a short delay for visual feedback
843
+ setTimeout(() => {
844
+ pressedKeys.delete(event.code);
845
+ updateKeyDisplay();
846
+ }, 200);
847
+ } else if (event.code === 'Escape') {
848
+ pauseGame();
849
+ // Add to pressedKeys for visual feedback
850
+ pressedKeys.add(event.code);
851
+ updateKeyDisplay();
852
+
853
+ // Remove ESC from pressedKeys after a short delay for visual feedback
854
+ setTimeout(() => {
855
+ pressedKeys.delete(event.code);
856
+ updateKeyDisplay();
857
+ }, 200);
858
+ } else {
859
+ // Only send game keys if game is started
860
+ if (gameStarted) {
861
+ pressedKeys.add(event.code);
862
+ updateKeyDisplay();
863
+ sendKeyState();
864
+ }
865
+ }
866
+ });
867
+
868
+ canvas.addEventListener('keyup', (event) => {
869
+ event.preventDefault();
870
+
871
+ // Don't handle special keys release (handled in keydown with timeout)
872
+ if (event.code !== 'Enter' && event.code !== 'Escape') {
873
+ if (gameStarted) {
874
+ pressedKeys.delete(event.code);
875
+ updateKeyDisplay();
876
+ sendKeyState();
877
+ }
878
+ }
879
+ });
880
+
881
+ // Handle mouse events for clicks
882
+ canvas.addEventListener('mousedown', (event) => {
883
+ if (ws && ws.readyState === WebSocket.OPEN) {
884
+ ws.send(JSON.stringify({
885
+ type: 'mouse',
886
+ button: event.button,
887
+ action: 'down',
888
+ x: event.offsetX,
889
+ y: event.offsetY
890
+ }));
891
+ }
892
+ });
893
+
894
+ canvas.addEventListener('mouseup', (event) => {
895
+ if (ws && ws.readyState === WebSocket.OPEN) {
896
+ ws.send(JSON.stringify({
897
+ type: 'mouse',
898
+ button: event.button,
899
+ action: 'up',
900
+ x: event.offsetX,
901
+ y: event.offsetY
902
+ }));
903
+ }
904
+ });
905
+
906
+ // Initialize
907
+ connectWebSocket();
908
+ canvas.focus();
909
+ </script>
910
+ </body>
911
+ </html>
912
+ """
913
+ return html_content
914
+
915
+ @app.websocket("/ws")
916
+ async def websocket_endpoint(websocket: WebSocket):
917
+ """Handle WebSocket connections for real-time game communication"""
918
+ await websocket.accept()
919
+ connected_clients.add(websocket)
920
+
921
+ try:
922
+ while True:
923
+ # Receive messages from client
924
+ data = await websocket.receive_text()
925
+ message = json.loads(data)
926
+
927
+ if message['type'] == 'keys':
928
+ # Update pressed keys
929
+ game_engine.pressed_keys = set(message['keys'])
930
+
931
+ elif message['type'] == 'reset':
932
+ # Handle environment reset request
933
+ game_engine.request_reset()
934
+
935
+ elif message['type'] == 'start':
936
+ # Handle game start request
937
+ game_engine.start_game()
938
+
939
+ elif message['type'] == 'pause':
940
+ # Handle game pause request
941
+ game_engine.pause_game()
942
+
943
+ elif message['type'] == 'mouse':
944
+ # Handle mouse events
945
+ if message['action'] == 'down':
946
+ if message['button'] == 0: # Left click
947
+ game_engine.l_click = True
948
+ elif message['button'] == 2: # Right click
949
+ game_engine.r_click = True
950
+ elif message['action'] == 'up':
951
+ if message['button'] == 0: # Left click
952
+ game_engine.l_click = False
953
+ elif message['button'] == 2: # Right click
954
+ game_engine.r_click = False
955
+
956
+ # Update mouse position (relative to canvas)
957
+ game_engine.mouse_x = message.get('x', 0) - 300 # Center at 300px
958
+ game_engine.mouse_y = message.get('y', 0) - 150 # Center at 150px
959
+
960
+ except WebSocketDisconnect:
961
+ connected_clients.discard(websocket)
962
+ except Exception as e:
963
+ logger.error(f"WebSocket error: {e}")
964
+ connected_clients.discard(websocket)
965
+
966
+ if __name__ == "__main__":
967
+ # For local development
968
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
969
+
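The keyboard path above funnels browser `KeyboardEvent.code` values through `WebKeyMap.WEB_TO_CSGO` before anything reaches the world model. The sketch below reproduces only that first translation step with a subset of the map; the subsequent conversion into the 51-dimensional action space is done by `web_keys_to_csgo_action_names` in `src/csgo/web_action_processing.py`, which is not among the 50 files shown in this view.

```python
# Sketch of the browser-key -> CSGO-key translation, mirroring WebKeyMap.WEB_TO_CSGO above.
WEB_TO_CSGO = {
    "KeyW": "w", "KeyA": "a", "KeyS": "s", "KeyD": "d",
    "Space": "space", "ShiftLeft": "left shift", "ArrowRight": "camera_right",
}  # subset of the full map in app.py

pressed = {"KeyW", "ShiftLeft", "ArrowRight"}   # as delivered by a 'keys' WebSocket message
csgo_keys = sorted(WEB_TO_CSGO[k] for k in pressed if k in WEB_TO_CSGO)
print(csgo_keys)  # ['camera_right', 'left shift', 'w']
```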
config/agent/csgo.yaml ADDED
@@ -0,0 +1,34 @@
1
+ _target_: agent.AgentConfig
2
+
3
+ denoiser:
4
+ _target_: models.diffusion.DenoiserConfig
5
+ sigma_data: 0.5
6
+ sigma_offset_noise: 0.1
7
+ noise_previous_obs: true
8
+ upsampling_factor: null
9
+ inner_model:
10
+ _target_: models.diffusion.InnerModelConfig
11
+ img_channels: 3
12
+ num_steps_conditioning: 4
13
+ cond_channels: 2048
14
+ depths: [2, 2, 2, 2]
15
+ channels: [128, 256, 512, 1024]
16
+ attn_depths: [0, 0, 1, 1]
17
+
18
+ upsampler:
19
+ _target_: models.diffusion.DenoiserConfig
20
+ sigma_data: 0.5
21
+ sigma_offset_noise: 0.1
22
+ noise_previous_obs: false
23
+ upsampling_factor: 5
24
+ inner_model:
25
+ _target_: models.diffusion.InnerModelConfig
26
+ img_channels: 3
27
+ num_steps_conditioning: 1
28
+ cond_channels: 2048
29
+ depths: [2, 2, 2, 2]
30
+ channels: [64, 64, 128, 256]
31
+ attn_depths: [0, 0, 0, 1]
32
+
33
+ rew_end_model: null
34
+ actor_critic: null
config/env/csgo.yaml ADDED
@@ -0,0 +1,7 @@
1
+ train:
2
+ id: csgo
3
+ size: [150, 600]
4
+ num_actions: 51
5
+ path_data_low_res: /tmp/dummy_data_low_res
6
+ path_data_full_res: /tmp/dummy_data_full_res
7
+ keymap: csgo
config/trainer.yaml ADDED
@@ -0,0 +1,9 @@
1
+ defaults:
2
+ - _self_
3
+ - env: csgo
4
+ - agent: csgo
5
+ - world_model_env: fast
6
+
7
+ static_dataset:
8
+ path: /tmp/dummy_data_low_res
9
+ ignore_sample_weights: True
config/world_model_env/fast.yaml ADDED
@@ -0,0 +1,17 @@
1
+ _target_: envs.WorldModelEnvConfig
2
+ horizon: 1000
3
+ num_batches_to_preload: 1
4
+ diffusion_sampler_next_obs:
5
+ _target_: models.diffusion.DiffusionSamplerConfig
6
+ num_steps_denoising: 6 # Balanced: better quality than 3, faster than 10
7
+ sigma_min: 0.002
8
+ sigma_max: 5.0
9
+ rho: 7
10
+ order: 1
11
+ diffusion_sampler_upsampling:
12
+ _target_: models.diffusion.DiffusionSamplerConfig
13
+ num_steps_denoising: 4 # Balanced: better quality than 2, faster than 5
14
+ sigma_min: 0.002
15
+ sigma_max: 5.0
16
+ rho: 7
17
+ order: 1
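These YAML files are composed with Hydra at startup. The sketch below follows the `initialize`/`compose` pattern used in `app.py`; it assumes the script sits next to the `config/` directory and that `src/` is on `PYTHONPATH` so the `_target_` classes resolve.

```python
# Sketch of composing the configs above, mirroring the Hydra calls in app.py.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base="1.3", config_path="config"):
    cfg = compose(config_name="trainer")  # pulls in env: csgo, agent: csgo, world_model_env: fast

# 6 denoising steps for next-obs and 4 for upsampling, per the committed fast.yaml
print(cfg.world_model_env.diffusion_sampler_next_obs.num_steps_denoising)
wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
```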
config_web.py ADDED
@@ -0,0 +1,208 @@
1
+ """
2
+ Configuration helper for web deployment
3
+ Handles path resolution and model loading for deployment
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Optional
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class WebConfig:
14
+ """Configuration manager for web deployment"""
15
+
16
+ def __init__(self, base_path: Optional[Path] = None):
17
+ if base_path is None:
18
+ base_path = Path.cwd()
19
+ self.base_path = Path(base_path)
20
+
21
+ def get_config_path(self) -> Path:
22
+ """Get configuration directory path"""
23
+ # Try multiple possible locations
24
+ possible_paths = [
25
+ self.base_path / "config",
26
+ self.base_path / "src" / ".." / "config",
27
+ Path(__file__).parent / "config"
28
+ ]
29
+
30
+ for path in possible_paths:
31
+ if path.exists():
32
+ return path.resolve()
33
+
34
+ # Create default config directory
35
+ config_path = self.base_path / "config"
36
+ config_path.mkdir(exist_ok=True)
37
+ return config_path
38
+
39
+ def get_checkpoint_path(self) -> Path:
40
+ """Find and return the best available checkpoint"""
41
+ # Try different possible locations and names
42
+ possible_checkpoints = [
43
+ self.base_path / "agent_epoch_00003.pt",
45
+ self.base_path / "checkpoints" / "agent_epoch_00003.pt",
47
+ self.base_path / "checkpoints" / "latest.pt",
48
+ ]
49
+
50
+ for ckpt_path in possible_checkpoints:
51
+ if ckpt_path.exists():
52
+ logger.info(f"Found checkpoint: {ckpt_path}")
53
+ return ckpt_path
54
+
55
+ # If no checkpoint found, create a dummy message
56
+ logger.warning("No checkpoint found - you may need to download models")
57
+ return self.base_path / "checkpoints" / "model_not_found.pt"
58
+
59
+ def get_spawn_dir(self) -> Path:
60
+ """Get spawn data directory"""
61
+ spawn_dir = self.base_path / "csgo" / "spawn"
62
+ spawn_dir.mkdir(parents=True, exist_ok=True)
63
+
64
+ # Create dummy spawn data if it doesn't exist
65
+ spawn_subdir = spawn_dir / "0"
66
+ spawn_subdir.mkdir(exist_ok=True)
67
+
68
+ # Create dummy files if they don't exist
69
+ dummy_files = ["act.npy", "full_res.npy", "info.json", "low_res.npy", "next_act.npy"]
70
+ for filename in dummy_files:
71
+ file_path = spawn_subdir / filename
72
+ if not file_path.exists():
73
+ if filename.endswith('.npy'):
74
+ import numpy as np
75
+ np.save(file_path, np.zeros((1, 10))) # Dummy array
76
+ elif filename.endswith('.json'):
77
+ import json
78
+ with open(file_path, 'w') as f:
79
+ json.dump({"dummy": True}, f)
80
+
81
+ return spawn_dir
82
+
83
+ def setup_environment_variables(self):
84
+ """Set up environment variables for deployment"""
85
+ # Disable CUDA if not available (for CPU-only deployment)
86
+ if not self.has_cuda():
87
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
88
+
89
+ # Set Python path
90
+ python_path = str(self.base_path / "src")
91
+ current_path = os.environ.get("PYTHONPATH", "")
92
+ if python_path not in current_path:
93
+ os.environ["PYTHONPATH"] = f"{python_path}:{current_path}" if current_path else python_path
94
+
95
+ def has_cuda(self) -> bool:
96
+ """Check if CUDA is available"""
97
+ try:
98
+ import torch
99
+ return torch.cuda.is_available()
100
+ except ImportError:
101
+ return False
102
+
103
+ def create_default_configs(self):
104
+ """Create default configuration files if they don't exist"""
105
+ config_dir = self.get_config_path()
106
+
107
+ # Create agent config
108
+ agent_dir = config_dir / "agent"
109
+ agent_dir.mkdir(exist_ok=True)
110
+
111
+ agent_config_path = agent_dir / "csgo.yaml"
112
+ if not agent_config_path.exists():
113
+ with open(agent_config_path, 'w') as f:
114
+ f.write("""_target_: agent.AgentConfig
115
+
116
+ denoiser:
117
+ _target_: models.diffusion.DenoiserConfig
118
+ sigma_data: 0.5
119
+ sigma_offset_noise: 0.1
120
+ noise_previous_obs: true
121
+ upsampling_factor: null
122
+ inner_model:
123
+ _target_: models.diffusion.InnerModelConfig
124
+ img_channels: 3
125
+ num_steps_conditioning: 4
126
+ cond_channels: 2048
127
+ depths: [2, 2, 2, 2]
128
+ channels: [128, 256, 512, 1024]
129
+ attn_depths: [0, 0, 1, 1]
130
+
131
+ upsampler:
132
+ _target_: models.diffusion.DenoiserConfig
133
+ sigma_data: 0.5
134
+ sigma_offset_noise: 0.1
135
+ noise_previous_obs: false
136
+ upsampling_factor: 5
137
+ inner_model:
138
+ _target_: models.diffusion.InnerModelConfig
139
+ img_channels: 3
140
+ num_steps_conditioning: 1
141
+ cond_channels: 2048
142
+ depths: [2, 2, 2, 2]
143
+ channels: [64, 64, 128, 256]
144
+ attn_depths: [0, 0, 0, 1]
145
+
146
+ rew_end_model: null
147
+ actor_critic: null
148
+ """)
149
+
150
+ # Create env config
151
+ env_dir = config_dir / "env"
152
+ env_dir.mkdir(exist_ok=True)
153
+
154
+ env_config_path = env_dir / "csgo.yaml"
155
+ if not env_config_path.exists():
156
+ with open(env_config_path, 'w') as f:
157
+ f.write("""train:
158
+ id: csgo
159
+ size: [150, 600]
160
+ num_actions: 51
161
+ path_data_low_res: /tmp/dummy_data_low_res
162
+ path_data_full_res: /tmp/dummy_data_full_res
163
+ keymap: csgo
164
+ """)
165
+
166
+ # Create world model env config
167
+ wm_env_dir = config_dir / "world_model_env"
168
+ wm_env_dir.mkdir(exist_ok=True)
169
+
170
+ wm_config_path = wm_env_dir / "fast.yaml"
171
+ if not wm_config_path.exists():
172
+ with open(wm_config_path, 'w') as f:
173
+ f.write("""_target_: envs.WorldModelEnvConfig
174
+ horizon: 1000
175
+ num_batches_to_preload: 1
176
+ diffusion_sampler_next_obs:
177
+ _target_: models.diffusion.DiffusionSamplerConfig
178
+ num_steps_denoising: 10
179
+ sigma_min: 0.002
180
+ sigma_max: 5.0
181
+ rho: 7
182
+ order: 1
183
+ diffusion_sampler_upsampling:
184
+ _target_: models.diffusion.DiffusionSamplerConfig
185
+ num_steps_denoising: 5
186
+ sigma_min: 0.002
187
+ sigma_max: 5.0
188
+ rho: 7
189
+ order: 1
190
+ """)
191
+
192
+ # Create trainer config
193
+ trainer_config_path = config_dir / "trainer.yaml"
194
+ if not trainer_config_path.exists():
195
+ with open(trainer_config_path, 'w') as f:
196
+ f.write("""defaults:
197
+ - _self_
198
+ - env: csgo
199
+ - agent: csgo
200
+ - world_model_env: fast
201
+
202
+ static_dataset:
203
+ path: /tmp/dummy_data_low_res
204
+ ignore_sample_weights: True
205
+ """)
206
+
207
+ # Global config instance
208
+ web_config = WebConfig()
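`app.py` drives this helper during startup; a minimal usage sketch of the same sequence:

```python
# Sketch of the startup sequence app.py performs with the global web_config instance.
from config_web import web_config

web_config.setup_environment_variables()   # put src/ on PYTHONPATH, hide CUDA when unavailable
web_config.create_default_configs()        # write config/{agent,env,world_model_env,trainer} if missing

print(web_config.get_config_path())        # resolved config directory
print(web_config.get_checkpoint_path())    # first existing candidate, else checkpoints/model_not_found.pt
print(web_config.get_spawn_dir())          # csgo/spawn, dummy spawn files created on demand
```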
csgo/spawn/0/act.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11830620c54f47d0ee6a9f904e68516980f8cd5af488572bd6e9e4815e8be52d
3
+ size 332
csgo/spawn/0/full_res.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cff6e9d7871c6f3c622f964fabbc181befeca9ef5b5a8a3f4e6cce1af79e6a8f
3
+ size 1080128
csgo/spawn/0/info.json ADDED
@@ -0,0 +1 @@
1
+ {"original_file_id": "4001-4200/hdf5_dm_july2021_4143.hdf5", "timestep_start": 540}
csgo/spawn/0/low_res.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d775f579f104caf9e195fa12bdf302e54c3c0f8938483ace0fa2cb75b694be1
3
+ size 43328
csgo/spawn/0/next_act.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:762d9db84444e12912a8e535d10b29783db59f6a7f97579c933496354ffb4bb6
3
+ size 10328
packages.txt ADDED
@@ -0,0 +1,3 @@
1
+ build-essential
2
+ curl
3
+ git
requirements.txt ADDED
@@ -0,0 +1,32 @@
1
+ # Core ML dependencies
2
+ torch>=1.13.0
3
+ torchvision>=0.14.0
4
+ torchaudio>=0.13.0
5
+ numpy>=1.21.0
6
+
7
+ # Configuration management
8
+ hydra-core>=1.2.0
9
+ omegaconf>=2.2.0
10
+
11
+ # Web framework for deployment
12
+ fastapi>=0.68.0
13
+ uvicorn>=0.15.0
14
+ websockets>=10.0
15
+
16
+ # Image processing
17
+ opencv-python-headless>=4.5.0
18
+ Pillow>=8.0.0
19
+
20
+ # Hugging Face integration
21
+ huggingface-hub>=0.10.0
22
+
23
+ # Data handling
24
+ h5py>=3.7.0
25
+
26
+ # Optional: for better performance
27
+ # torch-audio # if needed for audio processing
28
+
29
+ # Development dependencies (uncomment for local development)
30
+ # pytest>=6.0.0
31
+ # black>=21.0.0
32
+ # isort>=5.0.0
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes).
 
src/__pycache__/agent.cpython-310.pyc ADDED
Binary file (2.94 kB).
 
src/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (14.6 kB).
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (13.2 kB).
 
src/agent.py ADDED
@@ -0,0 +1,74 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from envs import TorchEnv, WorldModelEnv
9
+ from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
10
+ from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
11
+ from models.rew_end_model import RewEndModel, RewEndModelConfig
12
+ from utils import extract_state_dict
13
+
14
+
15
+ @dataclass
16
+ class AgentConfig:
17
+ denoiser: DenoiserConfig
18
+ upsampler: Optional[DenoiserConfig]
19
+ rew_end_model: Optional[RewEndModelConfig]
20
+ actor_critic: Optional[ActorCriticConfig]
21
+ num_actions: int
22
+
23
+ def __post_init__(self) -> None:
24
+ self.denoiser.inner_model.num_actions = self.num_actions
25
+ if self.upsampler is not None:
26
+ self.upsampler.inner_model.num_actions = self.num_actions
27
+ if self.rew_end_model is not None:
28
+ self.rew_end_model.num_actions = self.num_actions
29
+ if self.actor_critic is not None:
30
+ self.actor_critic.num_actions = self.num_actions
31
+
32
+
33
+ class Agent(nn.Module):
34
+ def __init__(self, cfg: AgentConfig) -> None:
35
+ super().__init__()
36
+ self.denoiser = Denoiser(cfg.denoiser)
37
+ self.upsampler = Denoiser(cfg.upsampler) if cfg.upsampler is not None else None
38
+ self.rew_end_model = RewEndModel(cfg.rew_end_model) if cfg.rew_end_model is not None else None
39
+ self.actor_critic = ActorCritic(cfg.actor_critic) if cfg.actor_critic is not None else None
40
+
41
+ @property
42
+ def device(self):
43
+ return self.denoiser.device
44
+
45
+ def setup_training(
46
+ self,
47
+ sigma_distribution_cfg: SigmaDistributionConfig,
48
+ sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
49
+ actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
50
+ rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
51
+ ) -> None:
52
+ self.denoiser.setup_training(sigma_distribution_cfg)
53
+ if self.upsampler is not None:
54
+ self.upsampler.setup_training(sigma_distribution_cfg_upsampler)
55
+ if self.actor_critic is not None:
56
+ self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg)
57
+
58
+ def load(
59
+ self,
60
+ path_to_ckpt: Path,
61
+ load_denoiser: bool = True,
62
+ load_upsampler: bool = True,
63
+ load_rew_end_model: bool = True,
64
+ load_actor_critic: bool = True,
65
+ ) -> None:
66
+ sd = torch.load(Path(path_to_ckpt), map_location=self.device)
67
+ if load_denoiser:
68
+ self.denoiser.load_state_dict(extract_state_dict(sd, "denoiser"))
69
+ if load_upsampler and self.upsampler is not None:
70
+ self.upsampler.load_state_dict(extract_state_dict(sd, "upsampler"))
71
+ if load_rew_end_model and self.rew_end_model is not None:
72
+ self.rew_end_model.load_state_dict(extract_state_dict(sd, "rew_end_model"))
73
+ if load_actor_critic and self.actor_critic is not None:
74
+ self.actor_critic.load_state_dict(extract_state_dict(sd, "actor_critic"))
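A hedged sketch of how `app.py` builds and restores this Agent from the csgo config in this commit; `src/` is assumed to be on `PYTHONPATH`, and the checkpoint path is the cache location used in `app.py`.

```python
# Sketch: build an Agent from config/agent/csgo.yaml and load a checkpoint,
# mirroring the calls in app.py. rew_end_model and actor_critic are null in
# that config, so only denoiser/upsampler weights are restored.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

from agent import Agent  # assumes src/ is on PYTHONPATH

agent_cfg = instantiate(OmegaConf.load("config/agent/csgo.yaml"), num_actions=51)  # 51 per config/env/csgo.yaml
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = Agent(agent_cfg).to(device).eval()
agent.load("cache/agent_epoch_00003.pt")
```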
src/coroutines/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from functools import wraps
2
+
3
+
4
+ def coroutine(func):
5
+ @wraps(func)
6
+ def primer(*args, **kwargs):
7
+ gen = func(*args, **kwargs)
8
+ next(gen)
9
+ return gen
10
+
11
+ return primer
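The decorator primes a generator so callers can `send()` into it immediately, which is what `make_collector` below relies on. A tiny usage sketch:

```python
# Sketch: @coroutine advances the generator to its first yield, so send() works right away.
from coroutines import coroutine  # assumes src/ is on PYTHONPATH


@coroutine
def running_total():
    total = 0.0
    while True:
        total += yield total   # receives a value from send(), yields the running sum


acc = running_total()    # already primed; no explicit next() needed
print(acc.send(2.0))     # 2.0
print(acc.send(3.5))     # 5.5
```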
src/coroutines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (484 Bytes).
 
src/coroutines/__pycache__/collector.cpython-310.pyc ADDED
Binary file (4.22 kB).
 
src/coroutines/__pycache__/env_loop.cpython-310.pyc ADDED
Binary file (2.31 kB).
 
src/coroutines/collector.py ADDED
@@ -0,0 +1,126 @@
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Generator, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ from . import coroutine
10
+ from data import Episode, Dataset
11
+ from envs import TorchEnv
12
+ from .env_loop import make_env_loop
13
+ from utils import Logs
14
+
15
+
16
+ @coroutine
17
+ def make_collector(
18
+ env: TorchEnv,
19
+ model: nn.Module,
20
+ dataset: Dataset,
21
+ epsilon: float = 0.0,
22
+ reset_every_collect: bool = False,
23
+ verbose: bool = True,
24
+ ) -> Generator[Logs, int, None]:
25
+ num_envs = env.num_envs
26
+
27
+ env_loop, buffer, episode_ids, dead = (None,) * 4
28
+ num_steps, num_episodes, to_log, pbar = (None,) * 4
29
+
30
+ def setup_new_collect():
31
+ nonlocal num_steps, num_episodes, buffer, to_log, pbar
32
+ num_steps = 0
33
+ num_episodes = 0
34
+ buffer = defaultdict(list)
35
+ to_log = []
36
+ pbar = tqdm(
37
+ total=num_to_collect.total,
38
+ unit=num_to_collect.unit,
39
+ desc=f"Collect {dataset.name}",
40
+ disable=not verbose,
41
+ )
42
+
43
+ def reset():
44
+ nonlocal env_loop, episode_ids, dead
45
+ env_loop = make_env_loop(env, model, epsilon)
46
+ episode_ids = defaultdict(lambda: None)
47
+ dead = [None] * num_envs
48
+
49
+ num_to_collect = yield
50
+ setup_new_collect()
51
+ reset()
52
+
53
+ while True:
54
+ with torch.no_grad():
55
+ all_obs, act, rew, end, trunc, *_, [infos] = env_loop.send(1)
56
+
57
+ num_steps += num_envs
58
+ pbar.update(num_envs if num_to_collect.steps is not None else 0)
59
+
60
+ for i, (o, a, r, e, t) in enumerate(zip(all_obs, act, rew, end, trunc)):
61
+ buffer[i].append((o, a, r, e, t))
62
+ dead[i] = (e + t).clip(max=1).item()
63
+
64
+ num_episodes += sum(dead)
65
+
66
+ can_stop = num_to_collect.can_stop(num_steps, num_episodes)
67
+
68
+ count_dead = 0
69
+ for i in range(num_envs):
70
+ # Store incomplete episodes only when reset_every_collect is set to False (train)
71
+ add_to_dataset = dead[i] or (can_stop and not reset_every_collect)
72
+ if add_to_dataset:
73
+ info = {"final_observation": infos["final_observation"][count_dead]} if dead[i] else {}
74
+ ep = Episode(*(torch.cat(x, dim=0) for x in zip(*buffer[i])), info).to("cpu")
75
+ if episode_ids[i] is not None:
76
+ ep = dataset.load_episode(episode_ids[i]) + ep
77
+ episode_ids[i] = dataset.add_episode(ep, episode_id=episode_ids[i])
78
+
79
+ if dead[i]:
80
+ to_log.append(
81
+ {
82
+ f"{dataset.name}/episode_id": episode_ids[i],
83
+ **ep.compute_metrics(),
84
+ }
85
+ )
86
+ buffer[i] = []
87
+ episode_ids[i] = None
88
+ pbar.update(1 if num_to_collect.episodes is not None else 0)
89
+
90
+ count_dead += dead[i]
91
+
92
+ if can_stop:
93
+ pbar.close()
94
+ metrics = {
95
+ "num_steps": dataset.num_steps,
96
+ "counts/rew_-1": dataset.counts_rew[0],
97
+ "counts/rew__0": dataset.counts_rew[1],
98
+ "counts/rew_+1": dataset.counts_rew[2],
99
+ "counts/end_0": dataset.counts_end[0],
100
+ "counts/end_1": dataset.counts_end[1],
101
+ }
102
+ to_log.append({f"{dataset.name}/{k}": v for k, v in metrics.items()})
103
+ num_to_collect = yield to_log
104
+ setup_new_collect()
105
+ if reset_every_collect:
106
+ reset()
107
+
108
+
109
+ @dataclass
110
+ class NumToCollect:
111
+ steps: Optional[int] = None
112
+ episodes: Optional[int] = None
113
+
114
+ def __post_init__(self) -> None:
115
+ assert (self.steps is None) != (self.episodes is None)
116
+
117
+ def can_stop(self, num_steps: int, num_episodes: int) -> bool:
118
+ return num_steps >= self.steps if self.steps is not None else num_episodes >= self.episodes
119
+
120
+ @property
121
+ def unit(self) -> str:
122
+ return "steps" if self.steps is not None else "eps"
123
+
124
+ @property
125
+ def total(self) -> int:
126
+ return self.steps if self.steps is not None else self.episodes
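Driving the collector, sketched under the assumption that env, model and dataset are already-built TorchEnv, policy and Dataset objects: each .send() of a NumToCollect budget runs one collection round and returns the metrics to log.

# Hedged sketch: `env`, `model` and `dataset` are assumed to exist already.
collector = make_collector(env, model, dataset, epsilon=0.01)
to_log = collector.send(NumToCollect(steps=1_000))   # collect roughly 1000 env steps
to_log += collector.send(NumToCollect(episodes=5))   # then 5 complete episodes
for entry in to_log:
    print(entry)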
src/coroutines/env_loop.py ADDED
@@ -0,0 +1,74 @@
1
+ import random
2
+ from typing import Generator, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.distributions.categorical import Categorical
7
+
8
+ from . import coroutine
9
+ from envs import TorchEnv, WorldModelEnv
10
+
11
+
12
+ @coroutine
13
+ def make_env_loop(
14
+ env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0
15
+ ) -> Generator[Tuple[torch.Tensor, ...], int, None]:
16
+ num_steps = yield
17
+
18
+ hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
19
+ cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
20
+
21
+ seed = random.randint(0, 2**31 - 1)
22
+ obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)])
23
+
24
+ while True:
25
+ hx, cx = hx.detach(), cx.detach()
26
+ all_ = []
27
+ infos = []
28
+ n = 0
29
+
30
+ while n < num_steps:
31
+ logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx))
32
+ act = Categorical(logits=logits_act).sample()
33
+
34
+ if random.random() < epsilon:
35
+ act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)
36
+
37
+ next_obs, rew, end, trunc, info = env.step(act)
38
+
39
+ if n > 0:
40
+ val_bootstrap = val.detach().clone()
41
+ if dead.any():
42
+ val_bootstrap[dead] = val_final_obs
43
+ all_[-1][-1] = val_bootstrap
44
+
45
+ dead = torch.logical_or(end, trunc)
46
+
47
+ if dead.any():
48
+ with torch.no_grad():
49
+ _, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead]))
50
+ reset_gate = 1 - dead.float().unsqueeze(1)
51
+ hx = hx * reset_gate
52
+ cx = cx * reset_gate
53
+ if "burnin_obs" in info:
54
+ burnin_obs = info["burnin_obs"]
55
+ for i in range(burnin_obs.size(1)):
56
+ _, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead]))
57
+
58
+ all_.append([obs, act, rew, end, trunc, logits_act, val, None])
59
+ infos.append(info)
60
+
61
+ obs = next_obs
62
+ n += 1
63
+
64
+ with torch.no_grad():
65
+ _, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx)) # do not update hx/cx
66
+
67
+ if dead.any():
68
+ val_bootstrap[dead] = val_final_obs
69
+
70
+ all_[-1][-1] = val_bootstrap
71
+
72
+ all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_))
73
+
74
+ num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos
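Sketch of consuming the loop: `model` is assumed to expose .device, .lstm_dim and .predict_act_value(obs, (hx, cx)) as the actor-critic does; each .send(n) rolls every environment forward n steps and returns tensors stacked along the time dimension.

# Hedged sketch: each send(n) returns (num_envs, n, ...) tensors plus per-step infos.
loop = make_env_loop(env, model, epsilon=0.05)
obs, act, rew, end, trunc, logits, val, val_bootstrap, infos = loop.send(32)
print(obs.shape[:2])  # (env.num_envs, 32)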
src/csgo/__init__.py ADDED
File without changes
src/csgo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes). View file
 
src/csgo/__pycache__/action_processing.cpython-310.pyc ADDED
Binary file (5.76 kB). View file
 
src/csgo/__pycache__/keymap.cpython-310.pyc ADDED
Binary file (622 Bytes). View file
 
src/csgo/__pycache__/web_action_processing.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
src/csgo/action_processing.py ADDED
@@ -0,0 +1,230 @@
1
+ """
2
+ Credits: some parts are taken and modified from the file `config.py` from https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Dict, List, Set, Tuple
7
+
8
+ import numpy as np
9
+ import pygame
10
+ import torch
11
+
12
+ from .keymap import CSGO_FORBIDDEN_COMBINATIONS, CSGO_KEYMAP
13
+
14
+
15
+ @dataclass
16
+ class CSGOAction:
17
+ keys: List[int]
18
+ mouse_x: float
19
+ mouse_y: float
20
+ l_click: bool
21
+ r_click: bool
22
+
23
+ def __post_init__(self) -> None:
24
+ self.keys = filter_keys_pressed_forbidden(self.keys)
25
+ self.process_mouse()
26
+
27
+ @property
28
+ def key_names(self) -> List[str]:
29
+ return [pygame.key.name(key) for key in self.keys]
30
+
31
+ def process_mouse(self) -> None:
32
+ # Clip and match mouse to closest in list of possibles
33
+ x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
34
+ y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
35
+ self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda x_: abs(x_ - x))
36
+ self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda x_: abs(x_ - y))
37
+
38
+ # Use arrows to override mouse movements
39
+ for key in self.key_names:
40
+ if key == "left":
41
+ self.mouse_x = -60
42
+ elif key == "right":
43
+ self.mouse_x = +60
44
+ elif key == "up":
45
+ self.mouse_y = -50
46
+ elif key == "down":
47
+ self.mouse_y = +50
48
+
49
+
50
+ def print_csgo_action(action: CSGOAction) -> Tuple[str, str, str]:
51
+ action_names = [CSGO_KEYMAP[k] for k in action.keys] if len(action.keys) > 0 else []
52
+ action_names = [x for x in action_names if not x.startswith("camera_")]
53
+ keys = " + ".join(action_names)
54
+ mouse = str((action.mouse_x, action.mouse_y)) * (action.mouse_x != 0 or action.mouse_y != 0)
55
+ clicks = "L" * action.l_click + " + " * (action.l_click and action.r_click) + "R" * action.r_click
56
+ return keys, mouse, clicks
57
+
58
+
59
+ MOUSE_X_POSSIBLES = [
60
+ -1000,
61
+ -500,
62
+ -300,
63
+ -200,
64
+ -100,
65
+ -60,
66
+ -30,
67
+ -20,
68
+ -10,
69
+ -4,
70
+ -2,
71
+ 0,
72
+ 2,
73
+ 4,
74
+ 10,
75
+ 20,
76
+ 30,
77
+ 60,
78
+ 100,
79
+ 200,
80
+ 300,
81
+ 500,
82
+ 1000,
83
+ ]
84
+
85
+ MOUSE_Y_POSSIBLES = [
86
+ -200,
87
+ -100,
88
+ -50,
89
+ -20,
90
+ -10,
91
+ -4,
92
+ -2,
93
+ 0,
94
+ 2,
95
+ 4,
96
+ 10,
97
+ 20,
98
+ 50,
99
+ 100,
100
+ 200,
101
+ ]
102
+
103
+ MOUSE_X_LIM = (MOUSE_X_POSSIBLES[0], MOUSE_X_POSSIBLES[-1])
104
+ MOUSE_Y_LIM = (MOUSE_Y_POSSIBLES[0], MOUSE_Y_POSSIBLES[-1])
105
+ N_KEYS = 11 # number of keyboard outputs, w,s,a,d,space,ctrl,shift,1,2,3,r
106
+ N_CLICKS = 2 # number of mouse buttons, left, right
107
+ N_MOUSE_X = len(MOUSE_X_POSSIBLES) # number of outputs on mouse x axis
108
+ N_MOUSE_Y = len(MOUSE_Y_POSSIBLES) # number of outputs on mouse y axis
109
+
110
+
111
+ def encode_csgo_action(csgo_action: CSGOAction, device: torch.device) -> torch.Tensor:
112
+
113
+ # mouse_x = csgo_action.mouse_x
114
+ # mouse_y = csgo_action.mouse_y
115
+
116
+ keys_pressed_onehot = np.zeros(N_KEYS)
117
+ mouse_x_onehot = np.zeros(N_MOUSE_X)
118
+ mouse_y_onehot = np.zeros(N_MOUSE_Y)
119
+ l_click_onehot = np.zeros(1)
120
+ r_click_onehot = np.zeros(1)
121
+
122
+ for key in csgo_action.key_names:
123
+ if key == "w":
124
+ keys_pressed_onehot[0] = 1
125
+ elif key == "a":
126
+ keys_pressed_onehot[1] = 1
127
+ elif key == "s":
128
+ keys_pressed_onehot[2] = 1
129
+ elif key == "d":
130
+ keys_pressed_onehot[3] = 1
131
+ elif key == "space":
132
+ keys_pressed_onehot[4] = 1
133
+ elif key == "left ctrl":
134
+ keys_pressed_onehot[5] = 1
135
+ elif key == "left shift":
136
+ keys_pressed_onehot[6] = 1
137
+ elif key == "1":
138
+ keys_pressed_onehot[7] = 1
139
+ elif key == "2":
140
+ keys_pressed_onehot[8] = 1
141
+ elif key == "3":
142
+ keys_pressed_onehot[9] = 1
143
+ elif key == "r":
144
+ keys_pressed_onehot[10] = 1
145
+
146
+ l_click_onehot[0] = int(csgo_action.l_click)
147
+ r_click_onehot[0] = int(csgo_action.r_click)
148
+
149
+ mouse_x_onehot[MOUSE_X_POSSIBLES.index(csgo_action.mouse_x)] = 1
150
+ mouse_y_onehot[MOUSE_Y_POSSIBLES.index(csgo_action.mouse_y)] = 1
151
+
152
+ assert mouse_x_onehot.sum() == 1
153
+ assert mouse_y_onehot.sum() == 1
154
+
155
+ return torch.tensor(
156
+ np.concatenate((
157
+ keys_pressed_onehot,
158
+ l_click_onehot,
159
+ r_click_onehot,
160
+ mouse_x_onehot,
161
+ mouse_y_onehot,
162
+ )),
163
+ device=device,
164
+ dtype=torch.float32,
165
+ )
166
+
167
+
168
+ def decode_csgo_action(y_preds: torch.Tensor) -> CSGOAction:
169
+ y_preds = y_preds.squeeze()
170
+ keys_pred = y_preds[0:N_KEYS]
171
+ l_click_pred = y_preds[N_KEYS : N_KEYS + 1]
172
+ r_click_pred = y_preds[N_KEYS + 1 : N_KEYS + N_CLICKS]
173
+ mouse_x_pred = y_preds[N_KEYS + N_CLICKS : N_KEYS + N_CLICKS + N_MOUSE_X]
174
+ mouse_y_pred = y_preds[
175
+ N_KEYS + N_CLICKS + N_MOUSE_X : N_KEYS + N_CLICKS + N_MOUSE_X + N_MOUSE_Y
176
+ ]
177
+
178
+ keys_pressed = []
179
+ keys_pressed_onehot = np.round(keys_pred)
180
+ if keys_pressed_onehot[0] == 1:
181
+ keys_pressed.append("w")
182
+ if keys_pressed_onehot[1] == 1:
183
+ keys_pressed.append("a")
184
+ if keys_pressed_onehot[2] == 1:
185
+ keys_pressed.append("s")
186
+ if keys_pressed_onehot[3] == 1:
187
+ keys_pressed.append("d")
188
+ if keys_pressed_onehot[4] == 1:
189
+ keys_pressed.append("space")
190
+ if keys_pressed_onehot[5] == 1:
191
+ keys_pressed.append("left ctrl")
192
+ if keys_pressed_onehot[6] == 1:
193
+ keys_pressed.append("left shift")
194
+ if keys_pressed_onehot[7] == 1:
195
+ keys_pressed.append("1")
196
+ if keys_pressed_onehot[8] == 1:
197
+ keys_pressed.append("2")
198
+ if keys_pressed_onehot[9] == 1:
199
+ keys_pressed.append("3")
200
+ if keys_pressed_onehot[10] == 1:
201
+ keys_pressed.append("r")
202
+
203
+ l_click = int(np.round(l_click_pred))
204
+ r_click = int(np.round(r_click_pred))
205
+
206
+ id = np.argmax(mouse_x_pred)
207
+ mouse_x = MOUSE_X_POSSIBLES[id]
208
+ id = np.argmax(mouse_y_pred)
209
+ mouse_y = MOUSE_Y_POSSIBLES[id]
210
+
211
+ keys_pressed = [pygame.key.key_code(x) for x in keys_pressed]
212
+
213
+ return CSGOAction(keys_pressed, mouse_x, mouse_y, bool(l_click), bool(r_click))
214
+
215
+
216
+ def filter_keys_pressed_forbidden(keys_pressed: List[int], keymap: Dict[int, str] = CSGO_KEYMAP, forbidden_combinations: List[Set[str]] = CSGO_FORBIDDEN_COMBINATIONS) -> List[int]:
217
+ keys = set()
218
+ names = set()
219
+ for key in keys_pressed:
220
+ if key not in keymap:
221
+ continue
222
+ name = keymap[key]
223
+ keys.add(key)
224
+ names.add(name)
225
+ for forbidden in forbidden_combinations:
226
+ if forbidden.issubset(names):
227
+ keys.remove(key)
228
+ names.remove(name)
229
+ break
230
+ return list(filter(lambda key: key in keys, keys_pressed))
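Round-trip sketch of the encoding (assumes pygame.init() is enough for key-name lookups; on a headless machine a dummy SDL video driver may be needed): an action becomes a 51-dim multi-hot vector (11 keys + 2 clicks + 23 mouse-x bins + 15 mouse-y bins) and decodes back to the nearest discretised action.

# Hedged sketch of encode/decode; expected values are indicative.
import pygame
import torch

pygame.init()
action = CSGOAction(keys=[pygame.K_w, pygame.K_SPACE], mouse_x=37.0, mouse_y=-8.0,
                    l_click=True, r_click=False)
vec = encode_csgo_action(action, device=torch.device("cpu"))
print(vec.shape)                      # torch.Size([51])
decoded = decode_csgo_action(vec)
print(print_csgo_action(decoded))     # e.g. ('up + jump', '(30, -10)', 'L')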
src/csgo/keymap.py ADDED
@@ -0,0 +1,33 @@
1
+ import pygame
2
+
3
+
4
+ CSGO_KEYMAP = {
5
+ pygame.K_w: "up",
6
+ pygame.K_d: "right",
7
+ pygame.K_a: "left",
8
+ pygame.K_s: "down",
9
+ pygame.K_SPACE: "jump",
10
+ pygame.K_LCTRL: "crouch",
11
+ pygame.K_LSHIFT: "walk",
12
+ pygame.K_1: "weapon1",
13
+ pygame.K_2: "weapon2",
14
+ pygame.K_3: "weapon3",
15
+ pygame.K_r: "reload",
16
+
17
+ # Override mouse movement with arrows
18
+ pygame.K_UP: "camera_up",
19
+ pygame.K_RIGHT: "camera_right",
20
+ pygame.K_LEFT: "camera_left",
21
+ pygame.K_DOWN: "camera_down",
22
+ }
23
+
24
+
25
+ CSGO_FORBIDDEN_COMBINATIONS = [
26
+ {"up", "down"},
27
+ {"left", "right"},
28
+ {"weapon1", "weapon2"},
29
+ {"weapon1", "weapon3"},
30
+ {"weapon2", "weapon3"},
31
+ {"camera_up", "camera_down"},
32
+ {"camera_left", "camera_right"},
33
+ ]
src/csgo/web_action_processing.py ADDED
@@ -0,0 +1,167 @@
1
+ """
2
+ Web-compatible action processing for CSGO actions
3
+ Converts web keyboard inputs to CSGO actions without pygame dependency
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List, Set, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+ # Web key code to CSGO action mapping
13
+ WEB_KEYMAP = {
14
+ 'KeyW': "up",
15
+ 'KeyD': "right",
16
+ 'KeyA': "left",
17
+ 'KeyS': "down",
18
+ 'Space': "jump",
19
+ 'ControlLeft': "crouch",
20
+ 'ShiftLeft': "walk",
21
+ 'Digit1': "weapon1",
22
+ 'Digit2': "weapon2",
23
+ 'Digit3': "weapon3",
24
+ 'KeyR': "reload",
25
+ 'ArrowUp': "camera_up",
26
+ 'ArrowRight': "camera_right",
27
+ 'ArrowLeft': "camera_left",
28
+ 'ArrowDown': "camera_down",
29
+ }
30
+
31
+ # Forbidden key combinations (same logic as original)
32
+ WEB_FORBIDDEN_COMBINATIONS = [
33
+ {"up", "down"},
34
+ {"left", "right"},
35
+ {"weapon1", "weapon2"},
36
+ {"weapon1", "weapon3"},
37
+ {"weapon2", "weapon3"},
38
+ {"camera_up", "camera_down"},
39
+ {"camera_left", "camera_right"},
40
+ ]
41
+
42
+ @dataclass
43
+ class WebCSGOAction:
44
+ """Web-compatible CSGO action without pygame dependencies"""
45
+ key_names: List[str] # Use string names instead of pygame key codes
46
+ mouse_x: float
47
+ mouse_y: float
48
+ l_click: bool
49
+ r_click: bool
50
+
51
+ def __post_init__(self) -> None:
52
+ self.key_names = filter_web_keys_forbidden(self.key_names)
53
+ self.process_mouse()
54
+
55
+ def process_mouse(self) -> None:
56
+ """Process mouse movement with discretization"""
57
+ # Import mouse constants
58
+ from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, MOUSE_X_LIM, MOUSE_Y_LIM
59
+
60
+ # Clip and match mouse to closest in list of possibles
61
+ x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
62
+ y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
63
+ self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda x_: abs(x_ - x))
64
+ self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda x_: abs(x_ - y))
65
+
66
+ # Use arrow keys to override mouse movements
67
+ for key_name in self.key_names:
68
+ if key_name == "camera_left":
69
+ self.mouse_x = -60
70
+ elif key_name == "camera_right":
71
+ self.mouse_x = +60
72
+ elif key_name == "camera_up":
73
+ self.mouse_y = -50
74
+ elif key_name == "camera_down":
75
+ self.mouse_y = +50
76
+
77
+ def filter_web_keys_forbidden(key_names: List[str]) -> List[str]:
78
+ """Filter out forbidden key combinations"""
79
+ names = set(key_names)
80
+ filtered_names = []
81
+
82
+ for key_name in key_names:
83
+ # Check if adding this key would create a forbidden combination
84
+ test_names = set(filtered_names + [key_name])
85
+ is_forbidden = False
86
+
87
+ for forbidden in WEB_FORBIDDEN_COMBINATIONS:
88
+ if forbidden.issubset(test_names):
89
+ is_forbidden = True
90
+ break
91
+
92
+ if not is_forbidden:
93
+ filtered_names.append(key_name)
94
+
95
+ return filtered_names
96
+
97
+ def web_keys_to_csgo_action_names(pressed_web_keys: Set[str]) -> List[str]:
98
+ """Convert set of pressed web keys to CSGO action names"""
99
+ action_names = []
100
+ for web_key in pressed_web_keys:
101
+ if web_key in WEB_KEYMAP:
102
+ action_names.append(WEB_KEYMAP[web_key])
103
+ return action_names
104
+
105
+ def encode_web_csgo_action(web_action: WebCSGOAction, device: torch.device) -> torch.Tensor:
106
+ """Encode web CSGO action to tensor format (compatible with original encoding)"""
107
+ from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, N_KEYS, N_CLICKS, N_MOUSE_X, N_MOUSE_Y
108
+
109
+ keys_pressed_onehot = np.zeros(N_KEYS)
110
+ mouse_x_onehot = np.zeros(N_MOUSE_X)
111
+ mouse_y_onehot = np.zeros(N_MOUSE_Y)
112
+ l_click_onehot = np.zeros(1)
113
+ r_click_onehot = np.zeros(1)
114
+
115
+ # Map action names to one-hot encoding
116
+ for action_name in web_action.key_names:
117
+ if action_name == "up": # w key
118
+ keys_pressed_onehot[0] = 1
119
+ elif action_name == "left": # a key
120
+ keys_pressed_onehot[1] = 1
121
+ elif action_name == "down": # s key
122
+ keys_pressed_onehot[2] = 1
123
+ elif action_name == "right": # d key
124
+ keys_pressed_onehot[3] = 1
125
+ elif action_name == "jump": # space
126
+ keys_pressed_onehot[4] = 1
127
+ elif action_name == "crouch": # ctrl
128
+ keys_pressed_onehot[5] = 1
129
+ elif action_name == "walk": # shift
130
+ keys_pressed_onehot[6] = 1
131
+ elif action_name == "weapon1": # 1
132
+ keys_pressed_onehot[7] = 1
133
+ elif action_name == "weapon2": # 2
134
+ keys_pressed_onehot[8] = 1
135
+ elif action_name == "weapon3": # 3
136
+ keys_pressed_onehot[9] = 1
137
+ elif action_name == "reload": # r
138
+ keys_pressed_onehot[10] = 1
139
+
140
+ l_click_onehot[0] = int(web_action.l_click)
141
+ r_click_onehot[0] = int(web_action.r_click)
142
+
143
+ mouse_x_onehot[MOUSE_X_POSSIBLES.index(web_action.mouse_x)] = 1
144
+ mouse_y_onehot[MOUSE_Y_POSSIBLES.index(web_action.mouse_y)] = 1
145
+
146
+ assert mouse_x_onehot.sum() == 1
147
+ assert mouse_y_onehot.sum() == 1
148
+
149
+ return torch.tensor(
150
+ np.concatenate((
151
+ keys_pressed_onehot,
152
+ l_click_onehot,
153
+ r_click_onehot,
154
+ mouse_x_onehot,
155
+ mouse_y_onehot,
156
+ )),
157
+ device=device,
158
+ dtype=torch.float32,
159
+ )
160
+
161
+ def print_web_csgo_action(action: WebCSGOAction) -> Tuple[str, str, str]:
162
+ """Print web CSGO action in readable format"""
163
+ action_names = [name for name in action.key_names if not name.startswith("camera_")]
164
+ keys = " + ".join(action_names)
165
+ mouse = str((action.mouse_x, action.mouse_y)) * (action.mouse_x != 0 or action.mouse_y != 0)
166
+ clicks = "L" * action.l_click + " + " * (action.l_click and action.r_click) + "R" * action.r_click
167
+ return keys, mouse, clicks
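Browser-side sketch: key codes arrive as JavaScript KeyboardEvent.code strings, get mapped to the same action names, and are encoded with the layout shared with action_processing.py. The pressed-key set below is illustrative.

# Hedged sketch of the web input path.
import torch

pressed = {"KeyW", "ShiftLeft", "ArrowRight"}                     # hypothetical browser state
names = web_keys_to_csgo_action_names(pressed)                    # e.g. ["up", "walk", "camera_right"]
action = WebCSGOAction(names, mouse_x=0.0, mouse_y=0.0, l_click=False, r_click=False)
vec = encode_web_csgo_action(action, device=torch.device("cpu"))  # same 51-dim layout as the pygame path
print(print_web_csgo_action(action))                              # e.g. ('up + walk', '(60, 0)', '')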
src/data/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .batch import Batch
2
+ from .batch_sampler import BatchSampler
3
+ from .dataset import Dataset, CSGOHdf5Dataset
4
+ from .episode import Episode
5
+ from .segment import Segment, SegmentId
6
+ from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
src/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (501 Bytes). View file
 
src/data/__pycache__/batch.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
src/data/__pycache__/batch_sampler.cpython-310.pyc ADDED
Binary file (2.81 kB). View file
 
src/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (9.52 kB). View file
 
src/data/__pycache__/episode.cpython-310.pyc ADDED
Binary file (3.76 kB). View file
 
src/data/__pycache__/segment.cpython-310.pyc ADDED
Binary file (1.16 kB). View file
 
src/data/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
src/data/batch.py ADDED
@@ -0,0 +1,25 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List
4
+
5
+ import torch
6
+
7
+ from .segment import SegmentId
8
+
9
+
10
+ @dataclass
11
+ class Batch:
12
+ obs: torch.ByteTensor
13
+ act: torch.LongTensor
14
+ rew: torch.FloatTensor
15
+ end: torch.LongTensor
16
+ trunc: torch.LongTensor
17
+ mask_padding: torch.BoolTensor
18
+ info: List[Dict[str, Any]]
19
+ segment_ids: List[SegmentId]
20
+
21
+ def pin_memory(self) -> Batch:
22
+ return Batch(**{k: v if k in ("segment_ids", "info") else v.pin_memory() for k, v in self.__dict__.items()})
23
+
24
+ def to(self, device: torch.device) -> Batch:
25
+ return Batch(**{k: v if k in ("segment_ids", "info") else v.to(device) for k, v in self.__dict__.items()})
src/data/batch_sampler.py ADDED
@@ -0,0 +1,72 @@
1
+ from typing import Generator, List, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .dataset import CSGOHdf5Dataset, Dataset
7
+ from .segment import SegmentId
8
+
9
+
10
+ class BatchSampler(torch.utils.data.Sampler):
11
+ def __init__(
12
+ self,
13
+ dataset: Dataset,
14
+ rank: int,
15
+ world_size: int,
16
+ batch_size: int,
17
+ seq_length: int,
18
+ sample_weights: Optional[List[float]] = None,
19
+ can_sample_beyond_end: bool = False,
20
+ ) -> None:
21
+ super().__init__(dataset)
22
+ assert isinstance(dataset, (Dataset, CSGOHdf5Dataset))
23
+ self.dataset = dataset
24
+ self.rank = rank
25
+ self.world_size = world_size
26
+ self.sample_weights = sample_weights
27
+ self.batch_size = batch_size
28
+ self.seq_length = seq_length
29
+ self.can_sample_beyond_end = can_sample_beyond_end
30
+
31
+ def __len__(self):
32
+ raise NotImplementedError
33
+
34
+ def __iter__(self) -> Generator[List[SegmentId], None, None]:
35
+ while True:
36
+ yield self.sample()
37
+
38
+ def sample(self) -> List[SegmentId]:
39
+ num_episodes = self.dataset.num_episodes
40
+
41
+ if (self.sample_weights is None) or num_episodes < len(self.sample_weights):
42
+ weights = self.dataset.lengths / self.dataset.num_steps
43
+ else:
44
+ weights = self.sample_weights
45
+ num_weights = len(self.sample_weights)
46
+ assert all([0 <= x <= 1 for x in weights]) and sum(weights) == 1
47
+ sizes = [
48
+ num_episodes // num_weights + (num_episodes % num_weights) * (i == num_weights - 1)
49
+ for i in range(num_weights)
50
+ ]
51
+ weights = [w / s for (w, s) in zip(weights, sizes) for _ in range(s)]
52
+
53
+ episodes_partition = np.arange(self.rank, num_episodes, self.world_size)
54
+ weights = np.array(weights[self.rank::self.world_size])
55
+ max_eps = self.batch_size
56
+ episode_ids = np.random.choice(episodes_partition, size=max_eps, replace=True, p=weights / weights.sum())
57
+ episode_ids = episode_ids.repeat(self.batch_size // max_eps)
58
+ timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])
59
+
60
+ # padding allowed, both before start and after end
61
+ if self.can_sample_beyond_end:
62
+ starts = timesteps - np.random.randint(0, self.seq_length, len(timesteps))
63
+ stops = starts + self.seq_length
64
+
65
+ # padding allowed only before start
66
+ else:
67
+ stops = np.minimum(
68
+ self.dataset.lengths[episode_ids], timesteps + 1 + np.random.randint(0, self.seq_length, len(timesteps))
69
+ )
70
+ starts = stops - self.seq_length
71
+
72
+ return [SegmentId(*x) for x in zip(episode_ids, starts, stops)]
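Sketch of the intended wiring (dataset construction elided): the sampler yields lists of SegmentId, the Dataset maps each id to a Segment, and collate_segments_to_batch stacks them into a Batch.

# Hedged sketch: `dataset` is assumed to be a populated Dataset / CSGOHdf5Dataset.
from torch.utils.data import DataLoader

from data import collate_segments_to_batch

sampler = BatchSampler(dataset, rank=0, world_size=1, batch_size=32, seq_length=64)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_segments_to_batch, num_workers=2)
batch = next(iter(loader))
print(batch.obs.shape)   # (32, 64, C, H, W) for image observations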
src/data/dataset.py ADDED
@@ -0,0 +1,202 @@
1
+ from collections import Counter
2
+ import multiprocessing as mp
3
+ from pathlib import Path
4
+ import shutil
5
+ from typing import Any, Dict, List, Optional
6
+
7
+ import h5py
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch.utils.data import Dataset as TorchDataset
12
+
13
+ from .episode import Episode
14
+ from .segment import Segment, SegmentId
15
+ from .utils import make_segment
16
+ from utils import StateDictMixin
17
+
18
+
19
+ class Dataset(StateDictMixin, TorchDataset):
20
+ def __init__(
21
+ self,
22
+ directory: Path,
23
+ dataset_full_res: Optional[TorchDataset],
24
+ name: Optional[str] = None,
25
+ cache_in_ram: bool = False,
26
+ use_manager: bool = False,
27
+ save_on_disk: bool = True,
28
+ ) -> None:
29
+ super().__init__()
30
+
31
+ # State
32
+ self.is_static = False
33
+ self.num_episodes = None
34
+ self.num_steps = None
35
+ self.start_idx = None
36
+ self.lengths = None
37
+ self.counter_rew = None
38
+ self.counter_end = None
39
+
40
+ self._directory = Path(directory).expanduser()
41
+ self._name = name if name is not None else self._directory.stem
42
+ self._cache_in_ram = cache_in_ram
43
+ self._save_on_disk = save_on_disk
44
+ self._default_path = self._directory / "info.pt"
45
+ self._cache = mp.Manager().dict() if use_manager else {}
46
+ self._reset()
47
+
48
+ self._dataset_full_res = dataset_full_res
49
+
50
+ def __len__(self) -> int:
51
+ return self.num_steps
52
+
53
+ def __getitem__(self, segment_id: SegmentId) -> Segment:
54
+ episode = self.load_episode(segment_id.episode_id)
55
+ segment = make_segment(episode, segment_id, should_pad=True)
56
+ if self._dataset_full_res is not None:
57
+ segment_id_full_res = SegmentId(episode.info["original_file_id"], segment_id.start, segment_id.stop)
58
+ segment.info["full_res"] = self._dataset_full_res[segment_id_full_res].obs
59
+ elif "full_res" in segment.info:
60
+ segment.info["full_res"] = segment.info["full_res"][segment_id.start:segment_id.stop]
61
+ return segment
62
+
63
+ def __str__(self) -> str:
64
+ return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps."
65
+
66
+ @property
67
+ def name(self) -> str:
68
+ return self._name
69
+
70
+ @property
71
+ def counts_rew(self) -> List[int]:
72
+ return [self.counter_rew[r] for r in [-1, 0, 1]]
73
+
74
+ @property
75
+ def counts_end(self) -> List[int]:
76
+ return [self.counter_end[e] for e in [0, 1]]
77
+
78
+ def _reset(self) -> None:
79
+ self.num_episodes = 0
80
+ self.num_steps = 0
81
+ self.start_idx = np.array([], dtype=np.int64)
82
+ self.lengths = np.array([], dtype=np.int64)
83
+ self.counter_rew = Counter()
84
+ self.counter_end = Counter()
85
+ self._cache.clear()
86
+
87
+ def clear(self) -> None:
88
+ self.assert_not_static()
89
+ if self._directory.is_dir():
90
+ shutil.rmtree(self._directory)
91
+ self._reset()
92
+
93
+ def load_episode(self, episode_id: int) -> Episode:
94
+ if self._cache_in_ram and episode_id in self._cache:
95
+ episode = self._cache[episode_id]
96
+ else:
97
+ episode = Episode.load(self._get_episode_path(episode_id))
98
+ if self._cache_in_ram:
99
+ self._cache[episode_id] = episode
100
+ return episode
101
+
102
+ def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
103
+ self.assert_not_static()
104
+ episode = episode.to("cpu")
105
+
106
+ if episode_id is None:
107
+ episode_id = self.num_episodes
108
+ self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
109
+ self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
110
+ self.num_steps += len(episode)
111
+ self.num_episodes += 1
112
+
113
+ else:
114
+ assert episode_id < self.num_episodes
115
+ old_episode = self.load_episode(episode_id)
116
+ incr_num_steps = len(episode) - len(old_episode)
117
+ self.lengths[episode_id] = len(episode)
118
+ self.start_idx[episode_id + 1 :] += incr_num_steps
119
+ self.num_steps += incr_num_steps
120
+ self.counter_rew.subtract(old_episode.rew.sign().tolist())
121
+ self.counter_end.subtract(old_episode.end.tolist())
122
+
123
+ self.counter_rew.update(episode.rew.sign().tolist())
124
+ self.counter_end.update(episode.end.tolist())
125
+
126
+ if self._save_on_disk:
127
+ episode.save(self._get_episode_path(episode_id))
128
+
129
+ if self._cache_in_ram:
130
+ self._cache[episode_id] = episode
131
+
132
+ return episode_id
133
+
134
+ def _get_episode_path(self, episode_id: int) -> Path:
135
+ n = 3 # number of hierarchies
136
+ powers = np.arange(n)
137
+ subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers
138
+ subfolders = [int(x) for x in subfolders[::-1]]
139
+ subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)])
140
+ return self._directory / subfolders / f"{episode_id}.pt"
141
+
142
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
143
+ super().load_state_dict(state_dict)
144
+ self._cache.clear()
145
+
146
+ def assert_not_static(self) -> None:
147
+ assert not self.is_static, "Trying to modify a static dataset."
148
+
149
+ def save_to_default_path(self) -> None:
150
+ self._default_path.parent.mkdir(exist_ok=True, parents=True)
151
+ torch.save(self.state_dict(), self._default_path)
152
+
153
+ def load_from_default_path(self) -> None:
154
+ if self._default_path.is_file():
155
+ self.load_state_dict(torch.load(self._default_path))
156
+
157
+
158
+ class CSGOHdf5Dataset(StateDictMixin, TorchDataset):
159
+ def __init__(self, directory: Path) -> None:
160
+ super().__init__()
161
+ filenames = sorted(Path(directory).rglob("*.hdf5"), key=lambda x: int(x.stem.split("_")[-1]))
162
+ self._filenames = {f"{x.parent.name}/{x.name}": x for x in filenames}
163
+ self._length_one_episode = 1000
164
+ self.num_episodes = len(self._filenames)
165
+ self.num_steps = self._length_one_episode * self.num_episodes
166
+ self.lengths = np.array([self._length_one_episode] * self.num_episodes, dtype=np.int64)
167
+
168
+ def __len__(self) -> int:
169
+ return self.num_steps
170
+
171
+ def save_to_default_path(self) -> None:
172
+ pass
173
+
174
+ def __getitem__(self, segment_id: SegmentId) -> Segment:
175
+ assert segment_id.start < self._length_one_episode and segment_id.stop > 0 and segment_id.start < segment_id.stop
176
+ pad_len_right = max(0, segment_id.stop - self._length_one_episode)
177
+ pad_len_left = max(0, -segment_id.start)
178
+
179
+ start = max(0, segment_id.start)
180
+ stop = min(self._length_one_episode, segment_id.stop)
181
+ mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()
182
+
183
+ with h5py.File(self._filenames[segment_id.episode_id], "r") as f:
184
+ obs = torch.stack([torch.tensor(f[f"frame_{i}_x"][:]).flip(2).permute(2, 0, 1).div(255).mul(2).sub(1) for i in range(start, stop)])
185
+ act = torch.tensor(np.array([f[f"frame_{i}_y"][:] for i in range(start, stop)]))
186
+ states = torch.stack([torch.tensor(f[f"frame_{i}_observation"][:]) for i in range(start, stop)])
187
+ ego_state = torch.stack([torch.tensor(f[f"frame_{i}_ego_state"][:]) for i in range(start, stop)])
188
+
189
+ def pad(x):
190
+ right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
191
+ return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right
192
+
193
+ obs = pad(obs)
194
+ act = pad(act)
195
+ rew = torch.zeros(obs.size(0))
196
+ end = torch.zeros(obs.size(0), dtype=torch.uint8)
197
+ trunc = torch.zeros(obs.size(0), dtype=torch.uint8)
198
+ return Segment(obs, act, rew, end, trunc, mask_padding, states=states, ego_state=ego_state, info={}, id=SegmentId(segment_id.episode_id, start, stop))
199
+
200
+ def load_episode(self, episode_id: int) -> Episode: # used by DatasetTraverser
201
+ s = self[SegmentId(episode_id, 0, self._length_one_episode)]
202
+ return Episode(s.obs, s.act, s.rew, s.end, s.trunc, s.info, s.states, s.ego_state)
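Sketch of pulling one training window out of the HDF5 recordings; the directory below is a placeholder, and negative start indices produce left padding recorded in mask_padding.

# Hedged sketch: directory layout and episode key are placeholders.
from pathlib import Path

ds = CSGOHdf5Dataset(Path("csgo/dataset"))
first_episode = next(iter(ds._filenames))           # keys look like "<parent_dir>/<file>.hdf5"
seg = ds[SegmentId(first_episode, start=-8, stop=24)]
print(seg.obs.shape[0], seg.effective_size)         # 32 frames, 24 of them valid (8 left-padded)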
src/data/episode.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Any, Dict, Optional
5
+
6
+ import torch
7
+
8
+
9
+ @dataclass
10
+ class Episode:
11
+ obs: torch.FloatTensor
12
+ act: torch.LongTensor
13
+ rew: torch.FloatTensor
14
+ end: torch.ByteTensor
15
+ trunc: torch.ByteTensor
16
+ info: Dict[str, Any]
17
+ states: torch.FloatTensor
18
+ ego_state: torch.FloatTensor
19
+
20
+ def __len__(self) -> int:
21
+ return self.obs.size(0)
22
+
23
+ def __add__(self, other: Episode) -> Episode:
24
+ assert self.dead.sum() == 0
25
+ d = {k: torch.cat((v, other.__dict__[k]), dim=0) for k, v in self.__dict__.items() if k != "info"}
26
+ return Episode(**d, info=merge_info(self.info, other.info))
27
+
28
+ def to(self, device) -> Episode:
29
+ return Episode(**{k: v.to(device) if k != "info" else v for k, v in self.__dict__.items()})
30
+
31
+ @property
32
+ def dead(self) -> torch.ByteTensor:
33
+ return (self.end + self.trunc).clip(max=1)
34
+
35
+ def compute_metrics(self) -> Dict[str, Any]:
36
+ return {"length": len(self), "return": self.rew.sum().item()}
37
+
38
+ @classmethod
39
+ def load(cls, path: Path, map_location: Optional[torch.device] = None) -> Episode:
40
+ return cls(
41
+ **{
42
+ k: v.div(255).mul(2).sub(1) if k == "obs" else v
43
+ for k, v in torch.load(Path(path), map_location=map_location).items()
44
+ }
45
+ )
46
+
47
+ def save(self, path: Path) -> None:
48
+ path = Path(path)
49
+ path.parent.mkdir(parents=True, exist_ok=True)
50
+ d = {k: v.add(1).div(2).mul(255).byte() if k == "obs" else v for k, v in self.__dict__.items()}
51
+ torch.save(d, path.with_suffix(".tmp"))
52
+ path.with_suffix(".tmp").rename(path)
53
+
54
+
55
+ def merge_info(info_a, info_b):
56
+ keys_a = set(info_a)
57
+ keys_b = set(info_b)
58
+ intersection = keys_a & keys_b
59
+ info = {
60
+ **{k: info_a[k] for k in keys_a if k not in intersection},
61
+ **{k: info_b[k] for k in keys_b if k not in intersection},
62
+ **{k: torch.cat((info_a[k], info_b[k]), dim=0) for k in intersection},
63
+ }
64
+ return info
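A small round-trip sketch (shapes are illustrative): save() stores observations as uint8 in [0, 255] and writes atomically via a .tmp rename, while load() maps them back to [-1, 1].

# Hedged sketch with illustrative shapes; the 51-dim act matches the CSGO action encoding.
import torch

ep = Episode(
    obs=torch.zeros(10, 3, 64, 64),               # float frames in [-1, 1]
    act=torch.zeros(10, 51),
    rew=torch.zeros(10),
    end=torch.zeros(10, dtype=torch.uint8),
    trunc=torch.zeros(10, dtype=torch.uint8),
    info={},
    states=torch.zeros(10, 4),
    ego_state=torch.zeros(10, 3),
)
ep.save("episode_000.pt")                          # hypothetical path
restored = Episode.load("episode_000.pt")
assert len(restored) == 10 and restored.obs.dtype == torch.float32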
src/data/segment.py ADDED
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Union
4
+
5
+ import torch
6
+
7
+
8
+ @dataclass
9
+ class SegmentId:
10
+ episode_id: Union[int, str]
11
+ start: int
12
+ stop: int
13
+
14
+
15
+ @dataclass
16
+ class Segment:
17
+ obs: torch.FloatTensor
18
+ act: torch.LongTensor
19
+ rew: torch.FloatTensor
20
+ end: torch.ByteTensor
21
+ trunc: torch.ByteTensor
22
+ mask_padding: torch.BoolTensor
23
+ states: torch.FloatTensor
24
+ ego_state: torch.FloatTensor
25
+ info: Dict[str, Any]
26
+ id: SegmentId
27
+
28
+ @property
29
+ def effective_size(self):
30
+ return self.mask_padding.sum().item()
src/data/utils.py ADDED
@@ -0,0 +1,89 @@
1
+ import math
2
+ from typing import Generator, List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from .batch import Batch
8
+ from .episode import Episode
9
+ from .segment import Segment, SegmentId
10
+
11
+
12
+ def collate_segments_to_batch(segments: List[Segment]) -> Batch:
13
+ attrs = ("obs", "act", "rew", "end", "trunc", "mask_padding")
14
+ stack = (torch.stack([getattr(s, x) for s in segments]) for x in attrs)
15
+ return Batch(*stack, [s.info for s in segments], [s.id for s in segments])
16
+
17
+
18
+ def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
19
+ assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
20
+ pad_len_right = max(0, segment_id.stop - len(episode))
21
+ pad_len_left = max(0, -segment_id.start)
22
+ assert pad_len_right == pad_len_left == 0 or should_pad
23
+
24
+ def pad(x):
25
+ right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
26
+ return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right
27
+
28
+ start = max(0, segment_id.start)
29
+ stop = min(len(episode), segment_id.stop)
30
+ mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()
31
+
32
+ return Segment(
33
+ pad(episode.obs[start:stop]),
34
+ pad(episode.act[start:stop]),
35
+ pad(episode.rew[start:stop]),
36
+ pad(episode.end[start:stop]),
37
+ pad(episode.trunc[start:stop]),
38
+ mask_padding,
39
+ pad(episode.states[start:stop]),
40
+ pad(episode.ego_state[start:stop]),
41
+ info=episode.info,
42
+ id=SegmentId(segment_id.episode_id, start, stop),
43
+ )
44
+
45
+
46
+ class DatasetTraverser:
47
+ def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
48
+ self.dataset = dataset
49
+ self.batch_num_samples = batch_num_samples
50
+ self.chunk_size = chunk_size
51
+
52
+ def __len__(self):
53
+ return math.ceil(
54
+ sum(
55
+ [
56
+ math.ceil(self.dataset.lengths[episode_id] / self.chunk_size)
57
+ - int(self.dataset.lengths[episode_id] % self.chunk_size == 1)
58
+ for episode_id in range(self.dataset.num_episodes)
59
+ ]
60
+ )
61
+ / self.batch_num_samples
62
+ )
63
+
64
+ def __iter__(self) -> Generator[Batch, None, None]:
65
+ chunks = []
66
+ for episode_id in range(self.dataset.num_episodes):
67
+ episode = self.dataset.load_episode(episode_id)
68
+ segments = []
69
+ for i in range(math.ceil(len(episode) / self.chunk_size)):
70
+ start = i * self.chunk_size
71
+ stop = (i + 1) * self.chunk_size
72
+ segment = make_segment(
73
+ episode,
74
+ SegmentId(episode_id, start, stop),
75
+ should_pad=True,
76
+ )
77
+ segment_id_full_res = SegmentId(episode.info["original_file_id"], start, stop)
78
+ segment.info["full_res"] = self.dataset._dataset_full_res[segment_id_full_res].obs
79
+ chunks.append(segment)
80
+ if chunks[-1].effective_size < 2:
81
+ chunks.pop()
82
+
83
+ while len(chunks) >= self.batch_num_samples:
84
+ yield collate_segments_to_batch(chunks[: self.batch_num_samples])
85
+ chunks = chunks[self.batch_num_samples :]
86
+
87
+ if len(chunks) > 0:
88
+ yield collate_segments_to_batch(chunks)
89
+
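Padding semantics in one sketch, assuming `episode` is an already-loaded Episode with at least 20 steps: a window that starts before step 0 is left-padded and the padding is flagged in mask_padding.

# Hedged sketch: `episode` is assumed to be a loaded Episode with >= 20 steps.
seg = make_segment(episode, SegmentId(episode_id=0, start=-5, stop=20), should_pad=True)
print(len(seg.obs))            # 25 steps in the window
print(seg.effective_size)      # 20 valid steps
print(seg.mask_padding[:5])    # tensor([False, False, False, False, False]) -- left padding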
src/envs/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .env import make_atari_env, TorchEnv
2
+ from .world_model_env import WorldModelEnv, WorldModelEnvConfig