musictimer committed on
Commit
02c6351
·
1 Parent(s): d294854

Fix initial bugs

=0.16.0 ADDED
@@ -0,0 +1,117 @@
1
+ Collecting gradio
2
+ Downloading gradio-5.44.1-py3-none-any.whl.metadata (16 kB)
3
+ Requirement already satisfied: huggingface-hub in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (0.26.1)
4
+ Collecting aiofiles<25.0,>=22.0 (from gradio)
5
+ Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
6
+ Requirement already satisfied: anyio<5.0,>=3.0 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (4.9.0)
7
+ Collecting brotli>=1.1.0 (from gradio)
8
+ Downloading Brotli-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.5 kB)
9
+ Collecting fastapi<1.0,>=0.115.2 (from gradio)
10
+ Downloading fastapi-0.116.1-py3-none-any.whl.metadata (28 kB)
11
+ Collecting ffmpy (from gradio)
12
+ Downloading ffmpy-0.6.1-py3-none-any.whl.metadata (2.9 kB)
13
+ Collecting gradio-client==1.12.1 (from gradio)
14
+ Downloading gradio_client-1.12.1-py3-none-any.whl.metadata (7.1 kB)
15
+ Collecting groovy~=0.1 (from gradio)
16
+ Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
17
+ Requirement already satisfied: httpx<1.0,>=0.24.1 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (0.28.1)
18
+ Collecting huggingface-hub
19
+ Using cached huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
20
+ Requirement already satisfied: jinja2<4.0 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (3.1.6)
21
+ Requirement already satisfied: markupsafe<4.0,>=2.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (3.0.1)
22
+ Requirement already satisfied: numpy<3.0,>=1.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (1.26.0)
23
+ Collecting orjson~=3.0 (from gradio)
24
+ Downloading orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (41 kB)
25
+ Requirement already satisfied: packaging in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (25.0)
26
+ Requirement already satisfied: pandas<3.0,>=1.0 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (2.3.0)
27
+ Requirement already satisfied: pillow<12.0,>=8.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (10.3.0)
28
+ Collecting pydantic<2.12,>=2.0 (from gradio)
29
+ Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)
30
+ Collecting pydub (from gradio)
31
+ Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
32
+ Collecting python-multipart>=0.0.18 (from gradio)
33
+ Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
34
+ Requirement already satisfied: pyyaml<7.0,>=5.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (6.0.2)
35
+ Collecting ruff>=0.9.3 (from gradio)
36
+ Downloading ruff-0.12.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
37
+ Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
38
+ Downloading safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)
39
+ Collecting semantic-version~=2.0 (from gradio)
40
+ Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
41
+ Collecting starlette<1.0,>=0.40.0 (from gradio)
42
+ Downloading starlette-0.47.3-py3-none-any.whl.metadata (6.2 kB)
43
+ Collecting tomlkit<0.14.0,>=0.12.0 (from gradio)
44
+ Downloading tomlkit-0.13.3-py3-none-any.whl.metadata (2.8 kB)
45
+ Collecting typer<1.0,>=0.12 (from gradio)
46
+ Downloading typer-0.17.4-py3-none-any.whl.metadata (15 kB)
47
+ Requirement already satisfied: typing-extensions~=4.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (4.12.2)
48
+ Collecting uvicorn>=0.14.0 (from gradio)
49
+ Downloading uvicorn-0.35.0-py3-none-any.whl.metadata (6.5 kB)
50
+ Requirement already satisfied: fsspec in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio-client==1.12.1->gradio) (2024.9.0)
51
+ Collecting websockets<16.0,>=10.0 (from gradio-client==1.12.1->gradio)
52
+ Downloading websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
53
+ Requirement already satisfied: filelock in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from huggingface-hub) (3.16.1)
54
+ Requirement already satisfied: requests in /home/alienware3/.local/lib/python3.10/site-packages (from huggingface-hub) (2.32.4)
55
+ Requirement already satisfied: tqdm>=4.42.1 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from huggingface-hub) (4.66.4)
56
+ Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub)
57
+ Using cached hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.7 kB)
58
+ Requirement already satisfied: exceptiongroup>=1.0.2 in /home/alienware3/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (1.3.0)
59
+ Requirement already satisfied: idna>=2.8 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (3.10)
60
+ Requirement already satisfied: sniffio>=1.1 in /home/alienware3/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)
61
+ Requirement already satisfied: certifi in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from httpx<1.0,>=0.24.1->gradio) (2024.8.30)
62
+ Requirement already satisfied: httpcore==1.* in /home/alienware3/.local/lib/python3.10/site-packages (from httpx<1.0,>=0.24.1->gradio) (1.0.9)
63
+ Requirement already satisfied: h11>=0.16 in /home/alienware3/.local/lib/python3.10/site-packages (from httpcore==1.*->httpx<1.0,>=0.24.1->gradio) (0.16.0)
64
+ Requirement already satisfied: python-dateutil>=2.8.2 in /home/alienware3/.local/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2.9.0.post0)
65
+ Requirement already satisfied: pytz>=2020.1 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2024.2)
66
+ Requirement already satisfied: tzdata>=2022.7 in /home/alienware3/.local/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
67
+ Collecting annotated-types>=0.6.0 (from pydantic<2.12,>=2.0->gradio)
68
+ Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)
69
+ Collecting pydantic-core==2.33.2 (from pydantic<2.12,>=2.0->gradio)
70
+ Downloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
71
+ Collecting typing-inspection>=0.4.0 (from pydantic<2.12,>=2.0->gradio)
72
+ Downloading typing_inspection-0.4.1-py3-none-any.whl.metadata (2.6 kB)
73
+ Requirement already satisfied: click>=8.0.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio) (8.1.7)
74
+ Collecting shellingham>=1.3.0 (from typer<1.0,>=0.12->gradio)
75
+ Downloading shellingham-1.5.4-py2.py3-none-any.whl.metadata (3.5 kB)
76
+ Requirement already satisfied: rich>=10.11.0 in /home/alienware3/.local/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio) (14.0.0)
77
+ Requirement already satisfied: charset_normalizer<4,>=2 in /home/alienware3/.local/lib/python3.10/site-packages (from requests->huggingface-hub) (3.3.2)
78
+ Requirement already satisfied: urllib3<3,>=1.21.1 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from requests->huggingface-hub) (2.2.3)
79
+ Requirement already satisfied: six>=1.5 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.16.0)
80
+ Requirement already satisfied: markdown-it-py>=2.2.0 in /home/alienware3/.local/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)
81
+ Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/alienware3/.local/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.2)
82
+ Requirement already satisfied: mdurl~=0.1 in /home/alienware3/.local/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)
83
+ Downloading gradio-5.44.1-py3-none-any.whl (60.2 MB)
84
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 60.2/60.2 MB 98.5 MB/s eta 0:00:00
85
+ Downloading gradio_client-1.12.1-py3-none-any.whl (324 kB)
86
+ Using cached huggingface_hub-0.34.4-py3-none-any.whl (561 kB)
87
+ Downloading aiofiles-24.1.0-py3-none-any.whl (15 kB)
88
+ Downloading Brotli-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.0 MB)
89
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.0/3.0 MB 100.4 MB/s eta 0:00:00
90
+ Downloading fastapi-0.116.1-py3-none-any.whl (95 kB)
91
+ Downloading groovy-0.1.2-py3-none-any.whl (14 kB)
92
+ Using cached hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
93
+ Downloading orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (132 kB)
94
+ Downloading pydantic-2.11.7-py3-none-any.whl (444 kB)
95
+ Downloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
96
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 94.6 MB/s eta 0:00:00
97
+ Downloading python_multipart-0.0.20-py3-none-any.whl (24 kB)
98
+ Downloading ruff-0.12.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.3 MB)
99
+ ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.3/13.3 MB 104.4 MB/s eta 0:00:00
100
+ Downloading safehttpx-0.1.6-py3-none-any.whl (8.7 kB)
101
+ Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)
102
+ Downloading starlette-0.47.3-py3-none-any.whl (72 kB)
103
+ Downloading tomlkit-0.13.3-py3-none-any.whl (38 kB)
104
+ Downloading typer-0.17.4-py3-none-any.whl (46 kB)
105
+ Downloading uvicorn-0.35.0-py3-none-any.whl (66 kB)
106
+ Downloading ffmpy-0.6.1-py3-none-any.whl (5.5 kB)
107
+ Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
108
+ Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)
109
+ Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)
110
+ Downloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)
111
+ Downloading websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (181 kB)
112
+ Installing collected packages: pydub, brotli, websockets, uvicorn, typing-inspection, tomlkit, shellingham, semantic-version, ruff, python-multipart, pydantic-core, orjson, hf-xet, groovy, ffmpy, annotated-types, aiofiles, pydantic, huggingface-hub, typer, starlette, safehttpx, gradio-client, fastapi, gradio
113
+ Attempting uninstall: huggingface-hub
114
+ Found existing installation: huggingface-hub 0.26.1
115
+ Uninstalling huggingface-hub-0.26.1:
116
+ Successfully uninstalled huggingface-hub-0.26.1
117
+ Successfully installed aiofiles-24.1.0 annotated-types-0.7.0 brotli-1.1.0 fastapi-0.116.1 ffmpy-0.6.1 gradio-5.44.1 gradio-client-1.12.1 groovy-0.1.2 hf-xet-1.1.9 huggingface-hub-0.34.4 orjson-3.11.3 pydantic-2.11.7 pydantic-core-2.33.2 pydub-0.25.1 python-multipart-0.0.20 ruff-0.12.12 safehttpx-0.1.6 semantic-version-2.10.0 shellingham-1.5.4 starlette-0.47.3 tomlkit-0.13.3 typer-0.17.4 typing-inspection-0.4.1 uvicorn-0.35.0 websockets-15.0.1
=4.0.0 ADDED
File without changes
DEPLOYMENT.md ADDED
@@ -0,0 +1,193 @@
1
+ # Diamond CSGO AI Player - Deployment Guide
2
+
3
+ ## 🚀 Deploying to Hugging Face Spaces
4
+
5
+ ### Prerequisites
6
+ 1. Hugging Face account
7
+ 2. Model checkpoint files (`agent_epoch_00206.pt` or similar)
8
+ 3. Git and Git LFS installed
9
+
10
+ ### Step 1: Prepare Repository
11
+
12
+ 1. **Clone/Fork this repository**
13
+ 2. **Install Git LFS** (for large model files):
14
+ ```bash
15
+ git lfs install
16
+ git lfs track "*.pt"
17
+ git add .gitattributes
18
+ ```
19
+
20
+ 3. **Add your model checkpoint**:
21
+ ```bash
22
+ # Copy your trained model to the project root
23
+ cp /path/to/your/agent_epoch_00206.pt .
24
+ git add agent_epoch_00206.pt
25
+ git commit -m "Add trained model checkpoint"
26
+ ```
27
+
28
+ ### Step 2: Create Hugging Face Space
29
+
30
+ 1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
31
+ 2. Click "Create new Space"
32
+ 3. Configure:
33
+ - **Space name**: `diamond-csgo-ai` (or your choice)
34
+ - **License**: Your preferred license
35
+ - **Space SDK**: `Docker`
36
+ - **Space hardware**:
37
+ - `CPU basic` (free) - for demo/testing
38
+ - `GPU T4 small` (paid) - for better performance
39
+
40
+ ### Step 3: Upload Code
41
+
42
+ ```bash
43
+ # Clone your new space
44
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
45
+ cd YOUR_SPACE_NAME
46
+
47
+ # Copy all project files
48
+ cp -r /path/to/diamond/* .
49
+
50
+ # Commit and push
51
+ git add .
52
+ git commit -m "Initial Diamond CSGO AI deployment"
53
+ git push
54
+ ```
55
+
56
+ ### Step 4: Configuration Files
57
+
58
+ Ensure these files are in your space root:
59
+
60
+ - `app.py` - Main FastAPI application
61
+ - `requirements.txt` - Python dependencies
62
+ - `Dockerfile` - Container configuration
63
+ - `README.md` - Space description
64
+ - `packages.txt` - System packages (if needed)
65
+
66
+ ### Step 5: Model Setup
67
+
68
+ If your model is too large for Git (>100MB), use Git LFS or download from Hub:
69
+
70
+ ```python
71
+ # In your app.py, add model downloading:
72
+ from huggingface_hub import hf_hub_download
73
+
74
+ def download_model():
75
+ return hf_hub_download(
76
+ repo_id="YOUR_USERNAME/YOUR_MODEL_REPO",
77
+ filename="agent_epoch_00206.pt"
78
+ )
79
+ ```
80
+
81
+ ## 🔧 Local Testing
82
+
83
+ Before deploying, test locally:
84
+
85
+ ```bash
86
+ # Install dependencies
87
+ pip install -r requirements.txt
88
+
89
+ # Run tests
90
+ python test_web_app.py
91
+
92
+ # Start local server
93
+ python run_web_demo.py
94
+ ```
95
+
96
+ Visit `http://localhost:7860` to test the interface.
97
+
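+ If you prefer a scripted check over opening a browser, here is a minimal Python smoke test (assuming the server started by `run_web_demo.py` is already listening on port 7860):
+
+ ```python
+ # Minimal smoke test for the local server (standard library only)
+ import urllib.request
+
+ with urllib.request.urlopen("http://localhost:7860", timeout=10) as resp:
+     print("Server responded with HTTP", resp.status)  # expect 200
+ ```
+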
98
+ ## ⚙️ Configuration Options
99
+
100
+ ### Hardware Requirements
101
+
102
+ | Tier | CPU | RAM | GPU | Performance |
103
+ |------|-----|-----|-----|-------------|
104
+ | Free | 2 vCPU | 16GB | None | Basic demo |
105
+ | Basic GPU | 4 vCPU | 16GB | T4 | Good performance |
106
+ | Premium | 8 vCPU | 32GB | A10G | Best experience |
107
+
108
+ ### Environment Variables
109
+
110
+ Add these in your Space settings (a code-level fallback is sketched after this list):
111
+ - `CUDA_VISIBLE_DEVICES=""` (for CPU-only)
112
+ - `PYTHONPATH="/app/src:/app"`
113
+
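+ If you cannot change the Space settings, the same defaults can also be applied defensively at the top of `app.py` before any heavy imports. This is only a sketch of that fallback, not part of the current code:
+
+ ```python
+ # Hedged fallback: mirror the Space settings in code, before torch/numpy are imported
+ import os
+ import sys
+
+ os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")      # CPU-only tiers
+ os.environ.setdefault("PYTHONPATH", "/app/src:/app")   # container import path
+
+ # PYTHONPATH changed after interpreter start is not re-read, so extend sys.path as well
+ for p in ("/app/src", "/app"):
+     if p not in sys.path:
+         sys.path.insert(0, p)
+ ```
+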
114
+ ## 🎮 Usage Instructions
115
+
116
+ Once deployed, users can:
117
+
118
+ 1. Visit your Space URL
119
+ 2. Click on the game canvas
120
+ 3. Use keyboard controls:
121
+ - **WASD** - Movement
122
+ - **Space** - Jump
123
+ - **Arrow keys** - Camera
124
+ - **1,2,3** - Weapons
125
+ - **R** - Reload
126
+ - **M** - Switch Human/AI mode
127
+
128
+ ## 🐛 Troubleshooting
129
+
130
+ ### Common Issues
131
+
132
+ 1. **Model not loading**:
133
+ - Check checkpoint file exists
134
+ - Verify file size (<5GB for Spaces)
135
+ - Use Git LFS for large files
136
+
137
+ 2. **Import errors**:
138
+ - Check `requirements.txt` is complete
139
+ - Verify Python path in `Dockerfile`
140
+
141
+ 3. **Performance issues**:
142
+ - Use GPU hardware tier
143
+ - Reduce model complexity
144
+ - Lower frame rate
145
+
146
+ 4. **WebSocket connection failed**:
147
+ - Check firewall settings
148
+ - Verify port 7860 is accessible
149
+ - Try different browser
150
+
151
+ ### Debug Mode
152
+
153
+ Enable debug logging:
154
+ ```python
155
+ import logging
156
+ logging.basicConfig(level=logging.DEBUG)
157
+ ```
158
+
159
+ ## 📊 Monitoring
160
+
161
+ Monitor your Space:
162
+ - View logs in HF Spaces interface
163
+ - Check GPU utilization
164
+ - Monitor user sessions
165
+
166
+ ## 🔄 Updates
167
+
168
+ To update your deployed Space:
169
+ ```bash
170
+ git pull # Get latest changes
171
+ git add .
172
+ git commit -m "Update to latest version"
173
+ git push # Automatically redeploys
174
+ ```
175
+
176
+ ## 💡 Tips for Success
177
+
178
+ 1. **Start with CPU tier** to test basic functionality
179
+ 2. **Use smaller models** for faster loading
180
+ 3. **Test thoroughly locally** before deploying
181
+ 4. **Monitor resource usage** to optimize costs
182
+ 5. **Add usage instructions** in your Space README
183
+
184
+ ## 🎯 Next Steps
185
+
186
+ After successful deployment:
187
+ - Share your Space with the community
188
+ - Collect user feedback
189
+ - Iterate on the interface
190
+ - Add new features like replay saving
191
+ - Consider multi-user support
192
+
193
+ Happy deploying! 🚀
HF_SPACES_CACHE_FIX.md ADDED
@@ -0,0 +1,84 @@
1
+ # 🔧 HF Spaces Cache Permission Fix
2
+
3
+ ## ❌ **Problem:**
4
+ ```
5
+ ERROR:app:Failed to load model: [Errno 13] Permission denied: '/.cache'
6
+ ```
7
+
8
+ HF Spaces containers can't write to the root `/.cache` directory, causing model downloads to fail.
9
+
10
+ ## ✅ **Solution Applied:**
11
+
12
+ ### 1. **Fixed Cache Directory in app.py**
13
+ - ✅ Set custom cache directory: `/tmp/torch_cache`
14
+ - ✅ Added proper permissions handling
15
+ - ✅ Fixed OMP_NUM_THREADS environment variable issue
16
+
17
+ ### 2. **Updated Dockerfile**
18
+ - ✅ Set environment variables to use `/tmp` for caches
19
+ - ✅ Pre-create cache directories
20
+ - ✅ Fixed OMP_NUM_THREADS value
21
+
22
+ ### 3. **Key Changes Made:**
23
+
24
+ #### **app.py Changes:**
25
+ ```python
26
+ # Fixed cache directory for torch.hub
27
+ state_dict = torch.hub.load_state_dict_from_url(
28
+ model_url,
29
+ map_location=device,
30
+ model_dir=cache_dir, # Custom cache dir
31
+ check_hash=False # Skip hash check for speed
32
+ )
33
+
34
+ # Fixed environment variables
35
+ os.environ["OMP_NUM_THREADS"] = "2" # Valid integer
36
+ os.environ["TORCH_HOME"] = "/tmp/torch"
37
+ os.environ["HF_HOME"] = "/tmp/huggingface"
38
+ ```
39
+
40
+ #### **Dockerfile Changes:**
41
+ ```dockerfile
42
+ ENV OMP_NUM_THREADS=2
43
+ ENV TORCH_HOME=/tmp/torch
44
+ ENV HF_HOME=/tmp/huggingface
45
+ ENV TRANSFORMERS_CACHE=/tmp/transformers
46
+
47
+ RUN mkdir -p /tmp/torch /tmp/huggingface /tmp/transformers
48
+ ```
49
+
50
+ ## 🚀 **Expected Results:**
51
+ - ✅ No more "Permission denied: /.cache" errors
52
+ - ✅ No more "Invalid value for environment variable OMP_NUM_THREADS" warnings
53
+ - ✅ Model downloads work properly on HF Spaces
54
+ - ✅ App starts correctly and clicking works
55
+
56
+ ## 📋 **To Deploy:**
57
+ 1. **Commit the changes**: `git add . && git commit -m "Fix HF Spaces cache permissions"`
58
+ 2. **Push to HF Spaces**: `git push`
59
+ 3. **Monitor logs**: Check that download succeeds without permission errors
60
+ 4. **Test**: Click the game area - should work now!
61
+
62
+ ## 🔍 **Log Messages to Look For:**
63
+ ### ✅ **Success:**
64
+ ```
65
+ INFO:app:Loading state dict from https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt
66
+ INFO:app:State dict loaded, applying to agent...
67
+ INFO:app:Model has actor_critic weights: False
68
+ INFO:app:Actor-critic model exists but has no trained weights - using dummy mode!
69
+ INFO:app:WebPlayEnv set to human control mode (no trained weights)
70
+ INFO:app:Models initialized successfully!
71
+ ```
72
+
73
+ ### ❌ **If Still Failing:**
74
+ ```
75
+ ERROR:app:Failed to load model: [Errno 13] Permission denied
76
+ ```
77
+
78
+ ## 🎯 **What This Fixes:**
79
+ 1. ✅ **Model downloading** - now uses writable `/tmp` directory
80
+ 2. ✅ **Environment variables** - OMP_NUM_THREADS is valid
81
+ 3. ✅ **Game clicking** - works after model loads (even without actor_critic)
82
+ 4. ✅ **HF Spaces compatibility** - follows container best practices
83
+
84
+ The app should now work perfectly on HF Spaces! 🎉
HF_SPACES_DEPLOYMENT_GUIDE.md ADDED
@@ -0,0 +1,201 @@
1
+ # 🚀 Hugging Face Spaces Deployment - Troubleshooting Guide
2
+
3
+ ## ✅ **Your Local Fix Applied**
4
+ Great news! The core issue has been resolved locally. The problem was that the downloaded model doesn't contain `actor_critic` weights, but the code assumed it did. This caused a `NoneType` error when clicking to start the game.
5
+
6
+ **Fixed**: The app now properly detects when `actor_critic` weights are missing and falls back to human control mode instead of crashing.
7
+
8
+ ## 🔍 **Potential HF Spaces Issues & Solutions**
9
+
10
+ ### **Issue 1: Model Download Timeouts** ⏰
11
+
12
+ **Symptoms:**
13
+ - "Model loading timed out" message
14
+ - App shows loading forever
15
+ - Click doesn't start the game
16
+
17
+ **Root Cause:** The HF Spaces network can be slower, so the 5-minute timeout may not be enough.
18
+
19
+ **Solution:**
20
+ ```python
21
+ # In app.py, update the timeout in _load_model_from_url_async():
22
+ success = await asyncio.wait_for(future, timeout=900.0) # 15 minutes instead of 5
23
+ ```
24
+
25
+ ### **Issue 2: Memory Limitations** 💾
26
+
27
+ **Symptoms:**
28
+ - App crashes during model loading
29
+ - "Out of memory" errors in logs
30
+ - Models load but inference fails
31
+
32
+ **Root Cause:** HF Spaces free tier has only 16GB RAM.
33
+
34
+ **Quick Fix:** Force CPU-only mode
35
+ ```python
36
+ # Add at the top of app.py
37
+ import os
38
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # Force CPU mode for HF Spaces
39
+ ```
40
+
41
+ **Better Solution:** Add memory management
42
+ ```python
43
+ # Add memory cleanup after model loading
44
+ import gc
45
+ gc.collect()
46
+ if torch.cuda.is_available():
47
+ torch.cuda.empty_cache()
48
+ ```
49
+
50
+ ### **Issue 3: WebSocket Connection Failures** 🔌
51
+
52
+ **Symptoms:**
53
+ - "Connection Error" or "Disconnected" status
54
+ - Click works but no response
55
+ - Frequent reconnections
56
+
57
+ **Root Cause:** HF Spaces proxy/domain restrictions.
58
+
59
+ **Solution:** Update the WebSocket connection code in the HTML template:
60
+ ```javascript
61
+ // Replace the connectWebSocket function in app.py HTML
62
+ function connectWebSocket() {
63
+ const isHFSpaces = window.location.hostname.includes('huggingface.co');
64
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
65
+ const wsUrl = `${protocol}//${window.location.host}/ws`;
66
+
67
+ ws = new WebSocket(wsUrl);
68
+
69
+ // Longer timeout for HF Spaces
70
+ const timeout = isHFSpaces ? 30000 : 10000;
71
+
72
+ const connectTimer = setTimeout(() => {
73
+ if (ws.readyState !== WebSocket.OPEN) {
74
+ ws.close();
75
+ setTimeout(connectWebSocket, 5000); // Retry after 5s
76
+ }
77
+ }, timeout);
78
+
79
+ ws.onopen = function(event) {
80
+ clearTimeout(connectTimer);
81
+ statusEl.textContent = 'Connected';
82
+ statusEl.style.color = '#00ff00';
83
+
84
+ // Re-send start if user already clicked
85
+ if (gameStarted && !gamePlaying) {
86
+ ws.send(JSON.stringify({ type: 'start' }));
87
+ }
88
+ };
89
+ }
90
+ ```
91
+
92
+ ### **Issue 4: Actor-Critic Model Missing** 🧠
93
+
94
+ **Already Fixed!** ✅ The app now handles this gracefully:
95
+ - Detects missing `actor_critic` weights
96
+ - Falls back to human control mode
97
+ - Shows proper warning messages
98
+ - Game still works (user can control manually)
99
+
100
+ ### **Issue 5: Dockerfile Optimization** 🐳
101
+
102
+ **Update your Dockerfile for HF Spaces:**
103
+ ```dockerfile
104
+ # Add these optimizations
105
+ ENV SHM_SIZE=2g
106
+ ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
107
+ ENV OMP_NUM_THREADS=4
108
+
109
+ # Add health check
110
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=60s \
111
+ CMD curl --fail http://localhost:7860/health || exit 1
112
+ ```
113
+
114
+ ## 🚀 **Quick Deployment Checklist**
115
+
116
+ ### **Before Deploying:**
117
+ 1. ✅ **Test locally with conda**: `conda activate diamond && python run_web_demo.py`
118
+ 2. ✅ **Verify the fix works**: Click should now work (even without actor_critic weights)
119
+ 3. ✅ **Check model download**: Test internet connectivity for HF model URL
120
+
121
+ ### **For HF Spaces Deployment:**
122
+
123
+ 1. **Update timeout values:**
124
+ ```python
125
+ # In app.py line ~153
126
+ success = await asyncio.wait_for(future, timeout=900.0) # 15 min
127
+ ```
128
+
129
+ 2. **Add health check endpoint:**
130
+ ```python
131
+ @app.get("/health")
132
+ async def health_check():
133
+ return {
134
+ "status": "healthy",
135
+ "models_ready": game_engine.models_ready,
136
+ "actor_critic_loaded": game_engine.actor_critic_loaded
137
+ }
138
+ ```
139
+
140
+ 3. **Force CPU mode for free tier:**
141
+ ```python
142
+ # Add at app.py startup
143
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
144
+ ```
145
+
146
+ 4. **Update Dockerfile** with the optimizations above
147
+
148
+ 5. **Test WebSocket connection** - add the improved connection handling
149
+
150
+ ## 🔧 **Debugging on HF Spaces**
151
+
152
+ ### **Check Logs:**
153
+ 1. Go to your Space page on HuggingFace
154
+ 2. Click "Logs" tab
155
+ 3. Look for these messages:
156
+ - ✅ `"Actor-critic model exists but has no trained weights - using dummy mode!"`
157
+ - ✅ `"WebPlayEnv set to human control mode"`
158
+ - ❌ `"Model loading timed out"`
159
+ - ❌ `"WebSocket error"`
160
+
161
+ ### **Test Health Endpoint:**
162
+ - Visit: `https://your-space.hf.space/health`
163
+ - Should return JSON with status info
164
+
165
+ ### **Browser Console:**
166
+ - Open Developer Tools (F12)
167
+ - Check for WebSocket connection errors
168
+ - Look for JavaScript errors during click
169
+
170
+ ## 🎯 **Expected Behavior After Fixes**
171
+
172
+ 1. **App loads** → Shows loading progress bar
173
+ 2. **Models initialize** → Either loads actor_critic OR shows "no trained weights"
174
+ 3. **User clicks game area** → Game starts immediately (no hanging)
175
+ 4. **If actor_critic missing** → User gets manual control (still playable!)
176
+ 5. **If actor_critic loaded** → AI takes control automatically
177
+
178
+ ## 🆘 **If Issues Persist**
179
+
180
+ **Quick Diagnostic:**
181
+ ```python
182
+ # Add this test endpoint to app.py
183
+ @app.get("/debug")
184
+ async def debug_info():
185
+ return {
186
+ "models_ready": game_engine.models_ready,
187
+ "actor_critic_loaded": game_engine.actor_critic_loaded,
188
+ "loading_status": game_engine.loading_status,
189
+ "game_started": game_engine.game_started,
190
+ "obs_shape": str(game_engine.obs.shape) if game_engine.obs is not None else "None",
191
+ "connected_clients": len(connected_clients),
192
+ "cuda_available": torch.cuda.is_available(),
193
+ "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
194
+ }
195
+ ```
196
+
197
+ Visit `/debug` endpoint to see the current state.
198
+
199
+ **Most Common Issue:** If clicking still doesn't work on HF Spaces, it's usually the WebSocket connection. Update the connection handling as described above.
200
+
201
+ The core model/clicking issue is now fixed - the remaining items are deployment optimizations for HF Spaces' specific environment! 🎉
HF_SPACES_GPU_FIX.md ADDED
@@ -0,0 +1,155 @@
1
+ # 🚀 HF Spaces GPU Acceleration Fix
2
+
3
+ ## ❌ **Problem Identified:**
4
+ Your T4 GPU wasn't being used because:
5
+
6
+ 1. **Dockerfile disabled CUDA**: `ENV CUDA_VISIBLE_DEVICES=""`
7
+ 2. **Environment variable issues**: OMP_NUM_THREADS causing warnings
8
+ 3. **App running on CPU**: Despite having T4 GPU hardware
9
+
10
+ ## ✅ **Complete Fix Applied:**
11
+
12
+ ### **1. Dockerfile Changes**
13
+ ```dockerfile
14
+ # REMOVED this line that was disabling GPU:
15
+ # ENV CUDA_VISIBLE_DEVICES=""
16
+
17
+ # Fixed environment variables:
18
+ ENV OMP_NUM_THREADS=2
19
+ ENV MKL_NUM_THREADS=2
20
+ ```
21
+
22
+ ### **2. App.py Improvements**
23
+ - ✅ **Fixed OMP_NUM_THREADS early**: Set before any imports
24
+ - ✅ **Improved GPU detection**: Better logging and detection
25
+ - ✅ **Cache directories**: Moved setup to very beginning
26
+
27
+ ### **3. Environment Variable Priority**
28
+ Environment variables are now set in this order:
29
+ 1. **Dockerfile** - Base container settings
30
+ 2. **app.py top** - Python-level fixes (before imports)
31
+ 3. **HF Spaces** - Runtime overrides
32
+
33
+ ## 🎯 **Expected Results After Fix:**
34
+
35
+ ### **Before (CPU mode):**
36
+ ```
37
+ INFO:app:Using device: cpu
38
+ INFO:app:CUDA not available, using CPU - this is normal for HF Spaces free tier
39
+ CPU 56%
40
+ GPU 0%
41
+ GPU VRAM 0/16 GB
42
+ ```
43
+
44
+ ### **After (GPU mode):**
45
+ ```
46
+ INFO:app:Using device: cuda
47
+ INFO:app:CUDA available: True
48
+ INFO:app:GPU device count: 1
49
+ INFO:app:Current GPU: Tesla T4
50
+ INFO:app:GPU memory: 15.1 GB
51
+ INFO:app:🚀 GPU acceleration enabled!
52
+ ```
53
+
54
+ ### **Performance Improvement:**
55
+ - **CPU usage**: Should drop to ~20-30%
56
+ - **GPU usage**: Should show 10-50% during AI inference
57
+ - **GPU VRAM**: Should show 2-4GB usage
58
+ - **AI FPS**: Should increase from ~2 FPS to 10+ FPS
59
+
60
+ ## 📋 **Deployment Steps:**
61
+
62
+ 1. **Commit and push changes:**
63
+ ```bash
64
+ git add .
65
+ git commit -m "Enable GPU acceleration for HF Spaces T4"
66
+ git push
67
+ ```
68
+
69
+ 2. **Wait for rebuild** (HF Spaces will restart automatically)
70
+
71
+ 3. **Check new logs** for GPU detection:
72
+ ```
73
+ INFO:app:🚀 GPU acceleration enabled!
74
+ ```
75
+
76
+ 4. **Monitor system stats:**
77
+ - GPU usage should now show activity
78
+ - GPU VRAM should show memory allocation
79
+ - Overall performance should be much faster
80
+
81
+ ## 🔍 **Debugging Commands:**
82
+
83
+ ### **Check CUDA in container:**
84
+ ```python
85
+ import torch
86
+ print(f"CUDA available: {torch.cuda.is_available()}")
87
+ print(f"GPU count: {torch.cuda.device_count()}")
88
+ if torch.cuda.is_available():
89
+ print(f"GPU name: {torch.cuda.get_device_name(0)}")
90
+ ```
91
+
92
+ ### **Check environment variables:**
93
+ ```python
94
+ import os
95
+ print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
96
+ print(f"OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS')}")
97
+ ```
98
+
99
+ ## 🚨 **If GPU Still Not Working:**
100
+
101
+ ### **1. Verify HF Spaces Hardware:**
102
+ - Check your Space settings
103
+ - Ensure "T4 small" or "T4 medium" is selected
104
+ - Free tier doesn't have GPU access
105
+
106
+ ### **2. Check Container Logs:**
107
+ Look for these messages:
108
+ - ✅ `"🚀 GPU acceleration enabled!"`
109
+ - ❌ `"CUDA not available"`
110
+
111
+ ### **3. Alternative: Force GPU Detection**
112
+ If needed, add this debug code to app.py:
113
+ ```python
114
+ # Debug GPU detection
115
+ logger.info(f"Environment CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
116
+ logger.info(f"PyTorch CUDA compiled: {torch.version.cuda}")
117
+ logger.info(f"PyTorch version: {torch.__version__}")
118
+ ```
119
+
120
+ ## ⚡ **Performance Optimization Tips:**
121
+
122
+ ### **For T4 GPU:**
123
+ 1. **Enable model compilation** (optional):
124
+ ```bash
125
+ # Set environment variable in HF Spaces settings:
126
+ ENABLE_TORCH_COMPILE=1
127
+ ```
128
+
129
+ 2. **Increase AI FPS** (if needed):
130
+ ```python
131
+ # In app.py, line ~86:
132
+ self.ai_fps = 15 # Increase from 10 to 15
133
+ ```
134
+
135
+ 3. **Monitor GPU memory**:
136
+ - T4 has 16GB VRAM
137
+ - App should use 2-4GB
138
+ - Leave headroom for other processes
139
+
140
+ ## 🎮 **Expected User Experience:**
141
+
142
+ 1. **Faster loading**: Models load to GPU memory
143
+ 2. **Responsive gameplay**: AI inference runs at 10+ FPS
144
+ 3. **Smoother visuals**: Display updates without lag
145
+ 4. **Better AI performance**: GPU acceleration improves model inference
146
+
147
+ Your HF Spaces deployment should now fully utilize the T4 GPU! 🚀
148
+
149
+ ## 📊 **Monitor These Metrics:**
150
+ - **GPU Utilization**: 10-50% during gameplay
151
+ - **GPU Memory**: 2-4GB allocated
152
+ - **AI FPS**: 10-15 FPS (displayed in web interface)
153
+ - **CPU Usage**: Should decrease to 20-30%
154
+
155
+ The game should feel much more responsive now! 🎉
__pycache__/app.cpython-310.pyc ADDED
Binary file (37.2 kB).
 
__pycache__/config_web.cpython-310.pyc ADDED
Binary file (5.48 kB).
 
__pycache__/debug_app.cpython-310.pyc ADDED
Binary file (10.6 kB).
 
__pycache__/hf_space_config.cpython-310.pyc ADDED
Binary file (4.92 kB).
 
app.py.backup ADDED
@@ -0,0 +1,1114 @@
1
+ """
2
+ Web-based Diamond CSGO AI Player for Hugging Face Spaces
3
+ Uses FastAPI + WebSocket for real-time keyboard input and game streaming
4
+ """
5
+
6
+ # Fix environment variables FIRST, before any other imports
7
+ import os
8
+ import tempfile
9
+
10
+ # Fix OMP_NUM_THREADS immediately (before PyTorch/NumPy imports)
11
+ if "OMP_NUM_THREADS" not in os.environ or not os.environ.get("OMP_NUM_THREADS", "").isdigit():
12
+ os.environ["OMP_NUM_THREADS"] = "2"
13
+
14
+ # Set up cache directories immediately
15
+ temp_dir = tempfile.gettempdir()
16
+ os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
17
+ os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
18
+ os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
19
+
20
+ # Create cache directories
21
+ for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
22
+ cache_path = os.environ[cache_var]
23
+ os.makedirs(cache_path, exist_ok=True)
24
+
25
+ import asyncio
26
+ import base64
27
+ import io
28
+ import json
29
+ import logging
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Set
32
+
33
+ import cv2
34
+ import numpy as np
35
+ import torch
36
+ import uvicorn
37
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
38
+ from fastapi.responses import HTMLResponse
39
+ from fastapi.staticfiles import StaticFiles
40
+ from hydra import compose, initialize
41
+ from hydra.utils import instantiate
42
+ from omegaconf import DictConfig, OmegaConf
43
+ from PIL import Image
44
+
45
+ # Import your modules
46
+ import sys
47
+ from pathlib import Path
48
+
49
+ # Add project root to path for src package imports
50
+ project_root = Path(__file__).parent
51
+ if str(project_root) not in sys.path:
52
+ sys.path.insert(0, str(project_root))
53
+
54
+ from src.agent import Agent
55
+ from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
56
+ from src.envs import WorldModelEnv
57
+ from src.game.web_play_env import WebPlayEnv
58
+ from src.utils import extract_state_dict
59
+ from config_web import web_config
60
+
61
+ # Configure logging
62
+ logging.basicConfig(level=logging.INFO)
63
+ logger = logging.getLogger(__name__)
64
+
65
+ # Global variables
66
+ app = FastAPI(title="Diamond CSGO AI Player")
67
+
68
+ # Set safe defaults for headless CI/Spaces environments
69
+ os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
70
+ os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
71
+ os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
72
+
73
+ # Environment variables already set at top of file
74
+ connected_clients: Set[WebSocket] = set()
75
+
76
+ class WebKeyMap:
77
+ """Map web key codes to pygame-like keys for CSGO actions"""
78
+ WEB_TO_CSGO = {
79
+ 'KeyW': 'w',
80
+ 'KeyA': 'a',
81
+ 'KeyS': 's',
82
+ 'KeyD': 'd',
83
+ 'Space': 'space',
84
+ 'ControlLeft': 'left ctrl',
85
+ 'ShiftLeft': 'left shift',
86
+ 'Digit1': '1',
87
+ 'Digit2': '2',
88
+ 'Digit3': '3',
89
+ 'KeyR': 'r',
90
+ 'ArrowUp': 'camera_up',
91
+ 'ArrowDown': 'camera_down',
92
+ 'ArrowLeft': 'camera_left',
93
+ 'ArrowRight': 'camera_right'
94
+ }
95
+
96
+ class WebGameEngine:
97
+ """Web-compatible game engine that replaces pygame functionality"""
98
+
99
+ def __init__(self):
100
+ self.play_env: Optional[WebPlayEnv] = None
101
+ self.obs = None
102
+ self.running = False
103
+ self.game_started = False
104
+ self.fps = 30 # Display FPS
105
+ self.ai_fps = 15 # AI inference FPS (matching standalone play.py performance)
106
+ self.frame_count = 0
107
+ self.ai_frame_count = 0
108
+ self.last_ai_time = 0
109
+ self.start_time = 0 # Track when AI started for proper FPS calculation
110
+ self.last_frame_send_time = 0 # Track frame sending for optimization
111
+ self.web_fps = 20 # Web display FPS (lower than AI FPS to reduce network overhead)
112
+ self.pressed_keys: Set[str] = set()
113
+ self.mouse_x = 0
114
+ self.mouse_y = 0
115
+ self.l_click = False
116
+ self.r_click = False
117
+ self.should_reset = False
118
+ self.cached_obs = None # Cache last observation for frame skipping
119
+ self.first_inference_done = False # Track if first inference completed
120
+ self.models_ready = False # Track if models are loaded
121
+ self.download_progress = 0 # Track download progress (0-100)
122
+ self.loading_status = "Initializing..." # Loading status message
123
+ self.actor_critic_loaded = False # Track if actor_critic was loaded with trained weights
124
+ import time
125
+ self.time_module = time
126
+
127
+ async def _load_model_from_url_async(self, agent, device):
128
+ """Load model from URL using torch.hub (HF Spaces compatible)"""
129
+ import asyncio
130
+ import concurrent.futures
131
+
132
+ def load_model_weights():
133
+ """Load model weights in thread pool to avoid blocking"""
134
+ try:
135
+ logger.info("Loading model using torch.hub.load_state_dict_from_url...")
136
+ self.loading_status = "Downloading model..."
137
+ self.download_progress = 10
138
+
139
+ model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
140
+
141
+ # Use torch.hub to download and load state dict with custom cache dir
142
+ logger.info(f"Loading state dict from {model_url}")
143
+
144
+ # Set custom cache directory that we have write permissions for
145
+ cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
146
+ os.makedirs(cache_dir, exist_ok=True)
147
+
148
+ # Use torch.hub with custom cache directory
149
+ state_dict = torch.hub.load_state_dict_from_url(
150
+ model_url,
151
+ map_location=device,
152
+ model_dir=cache_dir,
153
+ check_hash=False # Skip hash check for faster loading
154
+ )
155
+
156
+ self.download_progress = 60
157
+ self.loading_status = "Loading model weights into agent..."
158
+ logger.info("State dict loaded, applying to agent...")
159
+
160
+ # Load state dict into agent, but skip actor_critic if not present
161
+ has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
162
+ logger.info(f"Model has actor_critic weights: {has_actor_critic}")
163
+ agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
164
+
165
+ # Track if actor_critic was actually loaded with trained weights
166
+ self.actor_critic_loaded = has_actor_critic
167
+
168
+ self.download_progress = 100
169
+ self.loading_status = "Model loaded successfully!"
170
+ logger.info("All model weights loaded successfully!")
171
+ return True
172
+
173
+ except Exception as e:
174
+ logger.error(f"Failed to load model: {e}")
175
+ import traceback
176
+ traceback.print_exc()
177
+ return False
178
+
179
+ # Run in thread pool to avoid blocking with timeout
180
+ loop = asyncio.get_event_loop()
181
+ try:
182
+ with concurrent.futures.ThreadPoolExecutor() as executor:
183
+ # Add timeout for model loading (5 minutes max)
184
+ future = loop.run_in_executor(executor, load_model_weights)
185
+ success = await asyncio.wait_for(future, timeout=300.0) # 5 minute timeout
186
+ return success
187
+ except asyncio.TimeoutError:
188
+ logger.error("Model loading timed out after 5 minutes")
189
+ self.loading_status = "Model loading timed out - using dummy mode"
190
+ return False
191
+ except Exception as e:
192
+ logger.error(f"Error in model loading executor: {e}")
193
+ self.loading_status = f"Model loading error: {str(e)[:50]}..."
194
+ return False
195
+
196
+ async def initialize_models(self):
197
+ """Initialize the AI models and environment"""
198
+ try:
199
+ import torch
200
+ logger.info("Initializing models...")
201
+
202
+ # Setup environment and paths
203
+ web_config.setup_environment_variables()
204
+ web_config.create_default_configs()
205
+
206
+ config_path = web_config.get_config_path()
207
+ logger.info(f"Using config path: {config_path}")
208
+
209
+ # For Hydra, use relative path from app.py location
210
+ # Since app.py is in project root, config is simply "config"
211
+ relative_config_path = "config"
212
+ logger.info(f"Relative config path: {relative_config_path}")
213
+
214
+ with initialize(version_base="1.3", config_path=relative_config_path):
215
+ cfg = compose(config_name="trainer")
216
+
217
+ # Override config for deployment
218
+ cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
219
+ cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
220
+
221
+ # Use GPU if available, otherwise fall back to CPU
222
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223
+ logger.info(f"Using device: {device}")
224
+
225
+ # Log GPU availability and CUDA info for debugging
226
+ if torch.cuda.is_available():
227
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
228
+ logger.info(f"GPU device count: {torch.cuda.device_count()}")
229
+ logger.info(f"Current GPU: {torch.cuda.get_device_name(0)}")
230
+ logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
231
+ logger.info("🚀 GPU acceleration enabled!")
232
+ else:
233
+ logger.info("CUDA not available, using CPU mode")
234
+
235
+ # Initialize agent first
236
+ num_actions = cfg.env.num_actions
237
+ agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
238
+
239
+ # Get spawn directory
240
+ spawn_dir = web_config.get_spawn_dir()
241
+
242
+ # Try to load checkpoint (remote first, then local, then dummy mode)
243
+ try:
244
+ # First try to load from Hugging Face Hub using torch.hub
245
+ logger.info("Loading model from Hugging Face Hub with torch.hub...")
246
+
247
+ success = await self._load_model_from_url_async(agent, device)
248
+
249
+ if success:
250
+ logger.info("Successfully loaded checkpoint from HF Hub")
251
+ else:
252
+ # Fallback to local checkpoint if available
253
+ logger.error("Failed to load from HF Hub! Check the detailed error above.")
254
+ checkpoint_path = web_config.get_checkpoint_path()
255
+ if checkpoint_path.exists():
256
+ logger.info(f"Loading local checkpoint: {checkpoint_path}")
257
+ self.loading_status = "Loading local checkpoint..."
258
+ agent.load(checkpoint_path)
259
+ logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
260
+ # Assume local checkpoint has actor_critic weights (may need verification)
261
+ self.actor_critic_loaded = True
262
+ else:
263
+ logger.error(f"No local checkpoint found at: {checkpoint_path}")
264
+ raise FileNotFoundError("No model checkpoint available (local or remote)")
265
+
266
+ except Exception as e:
267
+ logger.error(f"Failed to load any checkpoint: {e}")
268
+ self._init_dummy_mode()
269
+ self.actor_critic_loaded = False # No actor_critic in dummy mode
270
+ return True
271
+
272
+ # Initialize world model environment
273
+ try:
274
+ sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
275
+ if agent.upsampler is not None:
276
+ sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
277
+ wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
278
+ wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
279
+ spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
280
+
281
+ # Create play environment
282
+ self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
283
+
284
+ # Verify actor-critic is loaded and ready for inference
285
+ if agent.actor_critic is not None and self.actor_critic_loaded:
286
+ logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
287
+ logger.info(f"Actor-critic device: {agent.actor_critic.device}")
288
+ # Force AI control for web demo
289
+ self.play_env.is_human_player = False
290
+ logger.info("WebPlayEnv set to AI control mode")
291
+ elif agent.actor_critic is not None and not self.actor_critic_loaded:
292
+ logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
293
+ self.play_env.is_human_player = True
294
+ logger.info("WebPlayEnv set to human control mode (no trained weights)")
295
+ else:
296
+ logger.warning("No actor-critic model found - AI inference will not work!")
297
+ self.play_env.is_human_player = True
298
+ logger.info("WebPlayEnv set to human control mode (fallback)")
299
+
300
+ # Enable model compilation for better performance (like standalone play.py)
301
+ # This gives 20-50% speedup but causes 10-30s delay on first inference
302
+ import os
303
+ enable_compile = device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1"
304
+
305
+ if enable_compile:
306
+ logger.info("🚀 Compiling models for faster inference (like standalone play.py)...")
307
+ logger.info("⏱️ First inference will take 10-30s, but subsequent inferences will be much faster!")
308
+ try:
309
+ wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
310
+ if wm_env.upsample_next_obs is not None:
311
+ wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
312
+ logger.info("✅ Model compilation enabled - expect 20-50% speedup!")
313
+ except Exception as e:
314
+ logger.warning(f"Model compilation failed: {e}")
315
+ enable_compile = False
316
+
317
+ if not enable_compile:
318
+ logger.info("Model compilation disabled. Set ENABLE_TORCH_COMPILE=1 for better performance.")
319
+
320
+ # Reset environment
321
+ self.obs, _ = self.play_env.reset()
322
+ self.cached_obs = self.obs # Initialize cache
323
+
324
+ logger.info("Models initialized successfully!")
325
+ logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
326
+ self.models_ready = True
327
+ self.loading_status = "Ready!"
328
+ return True
329
+
330
+ except Exception as e:
331
+ logger.error(f"Failed to initialize world model environment: {e}")
332
+ self._init_dummy_mode()
333
+ self.actor_critic_loaded = False # No actor_critic in dummy mode
334
+ self.models_ready = True
335
+ self.loading_status = "Using dummy mode"
336
+ return True
337
+
338
+ except Exception as e:
339
+ logger.error(f"Failed to initialize models: {e}")
340
+ import traceback
341
+ traceback.print_exc()
342
+ self._init_dummy_mode()
343
+ self.actor_critic_loaded = False # No actor_critic in dummy mode
344
+ self.models_ready = True
345
+ self.loading_status = "Error - using dummy mode"
346
+ return True
347
+
348
+ def _init_dummy_mode(self):
349
+ """Initialize dummy mode for testing without models"""
350
+ logger.info("Initializing dummy mode...")
351
+
352
+ # Create a test observation
353
+ height, width = 150, 600
354
+ img_array = np.zeros((height, width, 3), dtype=np.uint8)
355
+
356
+ # Add test pattern
357
+ for y in range(height):
358
+ for x in range(width):
359
+ img_array[y, x, 0] = (x % 256) # Red gradient
360
+ img_array[y, x, 1] = (y % 256) # Green gradient
361
+ img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
362
+
363
+ # Convert to torch tensor in expected format [-1, 1]
364
+ tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
365
+ tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
366
+ tensor = tensor.unsqueeze(0) # Add batch dimension
367
+
368
+ self.obs = tensor
369
+ self.play_env = None # No real environment in dummy mode
370
+ logger.info("Dummy mode initialized with test pattern")
371
+
372
+
373
+ def step_environment(self):
374
+ """Step the environment with current input state (with intelligent frame skipping)"""
375
+ if self.play_env is None:
376
+ # Dummy mode - just return current observation
377
+ return self.obs, 0.0, False, False, {"mode": "dummy"}
378
+
379
+ try:
380
+ # Check if reset is requested
381
+ if self.should_reset:
382
+ self.reset_environment()
383
+ self.should_reset = False
384
+ self.last_ai_time = self.time_module.time() # Reset AI timer
385
+ return self.obs, 0.0, False, False, {"reset": True}
386
+
387
+ # Intelligent frame skipping: only run AI inference at target FPS
388
+ current_time = self.time_module.time()
389
+ time_since_last_ai = current_time - self.last_ai_time
390
+ should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
391
+
392
+ if should_run_ai:
393
+ # Show loading indicator for first inference (can be slow)
394
+ if not self.first_inference_done:
395
+ logger.info("Running first AI inference (may take 5-15 seconds)...")
396
+
397
+ # Run AI inference
398
+ inference_start = self.time_module.time()
399
+ next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
400
+ pressed_keys=self.pressed_keys,
401
+ mouse_x=self.mouse_x,
402
+ mouse_y=self.mouse_y,
403
+ l_click=self.l_click,
404
+ r_click=self.r_click
405
+ )
406
+ inference_time = self.time_module.time() - inference_start
407
+
408
+ # Log first inference completion
409
+ if not self.first_inference_done:
410
+ self.first_inference_done = True
411
+ logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
412
+
413
+ # Cache the new observation and update timing
414
+ self.cached_obs = next_obs
415
+ self.last_ai_time = current_time
416
+ self.ai_frame_count += 1
417
+
418
+ # Add AI performance info
419
+ info = info or {}
420
+ info["ai_inference"] = True
421
+
422
+ # Calculate proper AI FPS: frames / elapsed time since start
423
+ elapsed_time = current_time - self.start_time
424
+ if elapsed_time > 0 and self.ai_frame_count > 0:
425
+ ai_fps = self.ai_frame_count / elapsed_time
426
+ # Cap at reasonable maximum (shouldn't exceed 100 FPS for AI inference)
427
+ info["ai_fps"] = min(ai_fps, 100.0)
428
+ else:
429
+ info["ai_fps"] = 0
430
+
431
+ info["inference_time"] = inference_time
432
+
433
+ return next_obs, reward, done, truncated, info
434
+ else:
435
+ # Use cached observation for smoother display without AI overhead
436
+ obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
437
+
438
+ # Calculate AI FPS for cached frames too
439
+ elapsed_time = current_time - self.start_time
440
+ if elapsed_time > 0 and self.ai_frame_count > 0:
441
+ ai_fps = min(self.ai_frame_count / elapsed_time, 100.0) # Cap at 100 FPS
442
+ else:
443
+ ai_fps = 0
444
+
445
+ return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}
446
+
447
+ except Exception as e:
448
+ logger.error(f"Error stepping environment: {e}")
449
+ obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
450
+ return obs_to_return, 0.0, False, False, {"error": str(e)}
451
+
452
+ def reset_environment(self):
453
+ """Reset the environment"""
454
+ try:
455
+ if self.play_env is not None:
456
+ self.obs, _ = self.play_env.reset()
457
+ self.cached_obs = self.obs # Update cache
458
+ logger.info("Environment reset successfully")
459
+ else:
460
+ # Dummy mode - recreate test pattern
461
+ self._init_dummy_mode()
462
+ self.cached_obs = self.obs # Update cache
463
+ logger.info("Dummy environment reset")
464
+ except Exception as e:
465
+ logger.error(f"Error resetting environment: {e}")
466
+
467
+ def request_reset(self):
468
+ """Request environment reset on next step"""
469
+ self.should_reset = True
470
+ logger.info("Environment reset requested")
471
+
472
+ def start_game(self):
473
+ """Start the game"""
474
+ self.game_started = True
475
+ self.start_time = self.time_module.time() # Reset start time for FPS calculation
476
+ self.ai_frame_count = 0 # Reset AI frame count
477
+ logger.info("Game started")
478
+
479
+ def pause_game(self):
480
+ """Pause/stop the game"""
481
+ self.game_started = False
482
+ logger.info("Game paused")
483
+
484
+ def obs_to_base64(self, obs: torch.Tensor) -> str:
485
+ """Convert observation tensor to base64 image for web display (optimized for speed)"""
486
+ if obs is None:
487
+ return ""
488
+
489
+ try:
490
+ # Handle observation tensor conversion based on dimensions
491
+ if obs.ndim == 4 and obs.size(0) == 1:
492
+ # 4D tensor with batch dimension [1, C, H, W] -> [C, H, W]
493
+ img_tensor = obs[0]
494
+ elif obs.ndim == 3:
495
+ # 3D tensor [C, H, W]
496
+ img_tensor = obs
497
+ elif obs.ndim == 2:
498
+ # 2D tensor - likely an error, return empty string
499
+ logger.warning(f"Unexpected 2D observation tensor: {obs.shape}")
500
+ return ""
501
+ else:
502
+ logger.warning(f"Unexpected observation dimensions: {obs.shape}")
503
+ return ""
504
+
505
+ # Convert to numpy with proper range conversion
506
+ img_array = img_tensor.mul(127.5).add_(127.5).clamp_(0, 255).byte()
507
+ img_array = img_array.permute(1, 2, 0).cpu().numpy()
508
+
509
+ # Direct resize with OpenCV (much faster than PIL)
510
+ img_array = cv2.resize(img_array, (600, 150), interpolation=cv2.INTER_NEAREST)
511
+
512
+ # Convert BGR to RGB if needed (OpenCV uses BGR)
513
+ if img_array.shape[2] == 3:
514
+ img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
515
+
516
+ # Optimized JPEG encoding with OpenCV (faster than PIL)
517
+ success, buffer = cv2.imencode('.jpg', img_array, [cv2.IMWRITE_JPEG_QUALITY, 80])
518
+ if success:
519
+ img_str = base64.b64encode(buffer).decode()
520
+ return f"data:image/jpeg;base64,{img_str}"
521
+ else:
522
+ logger.warning("Frame encoding failed, returning empty frame")
523
+ return ""
524
+
525
+ except Exception as e:
526
+ logger.error(f"Error converting observation to base64: {e}")
527
+ return ""
528
+
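As a quick check of the frame format produced by obs_to_base64 above, the data URL can be decoded back into an image array; the following is a minimal sketch using only base64, NumPy, and OpenCV (all already imported by this app), with the helper name data_url_to_array being an illustrative choice rather than anything defined in the codebase.

import base64

import cv2
import numpy as np

def data_url_to_array(data_url: str) -> np.ndarray:
    # obs_to_base64 returns "data:image/jpeg;base64,<payload>"; strip the header,
    # decode the JPEG bytes, and recover a (150, 600, 3) BGR uint8 array.
    payload = data_url.split(",", 1)[1]
    buf = np.frombuffer(base64.b64decode(payload), dtype=np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)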
529
+ async def game_loop(self):
530
+ """Main game loop that runs continuously"""
531
+ self.running = True
532
+
533
+ while self.running:
534
+ try:
535
+ # Check if models are ready
536
+ if not self.models_ready:
537
+ # Send loading status to clients
538
+ if connected_clients:
539
+ loading_data = {
540
+ 'type': 'loading',
541
+ 'status': self.loading_status,
542
+ 'progress': self.download_progress,
543
+ 'ready': False
544
+ }
545
+ disconnected = set()
546
+ for client in connected_clients.copy():
547
+ try:
548
+ await client.send_text(json.dumps(loading_data))
549
+ except Exception:
550
+ disconnected.add(client)
551
+ connected_clients.difference_update(disconnected)
552
+
553
+ await asyncio.sleep(0.5) # Check every 500ms during loading
554
+ continue
555
+
556
+ # Only step environment if game is started
557
+ if not self.game_started:
558
+ # Game not started - just send current observation without stepping
559
+ should_send_frame = bool(self.obs is not None and connected_clients)
560
+ # Don't modify self.obs when game isn't started!
561
+ await asyncio.sleep(0.1)
562
+ else:
563
+ # Game is started - step environment
564
+ should_send_frame = True
565
+ if self.play_env is None:
566
+ await asyncio.sleep(0.1)
567
+ continue
568
+
569
+ # Step environment with current input state
570
+ next_obs, reward, done, truncated, info = self.step_environment()
571
+
572
+ if done or truncated:
573
+ # Auto-reset when episode ends
574
+ self.reset_environment()
575
+ else:
576
+ self.obs = next_obs
577
+
578
+ # Send frame to all connected clients with smart throttling for performance
579
+ current_time = self.time_module.time()
580
+ time_since_last_frame_send = current_time - self.last_frame_send_time
581
+ should_send_web_frame = time_since_last_frame_send >= (1.0 / self.web_fps)
582
+
583
+ if should_send_frame and should_send_web_frame and connected_clients and self.obs is not None:
584
+ # Set default values for when game isn't running
585
+ if not self.game_started:
586
+ reward = 0.0
587
+ info = {"waiting": True}
588
+ # If game is started, reward and info should be set above
589
+
590
+ # Convert observation to base64
591
+ image_data = self.obs_to_base64(self.obs)
592
+
593
+ # Debug logging for first few frames
594
+ if self.frame_count < 5:
595
+ logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
596
+ f"image_data_length={len(image_data) if image_data else 0}, "
597
+ f"game_started={self.game_started}")
598
+
599
+ frame_data = {
600
+ 'type': 'frame',
601
+ 'image': image_data,
602
+ 'frame_count': self.frame_count,
603
+ 'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
604
+ 'info': str(info) if info else "",
605
+ 'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
606
+ 'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False,
607
+ 'web_fps': self.web_fps, # Add web FPS for monitoring
608
+ 'ai_target_fps': self.ai_fps # Add target AI FPS for monitoring
609
+ }
610
+
611
+ # Send to all connected clients
612
+ disconnected = set()
613
+ for client in connected_clients.copy():
614
+ try:
615
+ await client.send_text(json.dumps(frame_data))
616
+ except Exception:
617
+ disconnected.add(client)
618
+
619
+ # Remove disconnected clients
620
+ connected_clients.difference_update(disconnected)
621
+
622
+ # Update frame send timing
623
+ self.last_frame_send_time = current_time
624
+
625
+ self.frame_count += 1
626
+ await asyncio.sleep(1.0 / self.fps) # Control FPS
627
+
628
+ except Exception as e:
629
+ logger.error(f"Error in game loop: {e}")
630
+ await asyncio.sleep(0.1)
631
+
632
+ # Global game engine instance
633
+ game_engine = WebGameEngine()
634
+
635
+ @app.on_event("startup")
636
+ async def startup_event():
637
+ """Initialize models when the app starts"""
638
+ # Start the game loop immediately (it will handle loading state)
639
+ asyncio.create_task(game_engine.game_loop())
640
+
641
+ # Initialize models in background (non-blocking)
642
+ asyncio.create_task(game_engine.initialize_models())
643
+
644
+ @app.get("/performance")
645
+ async def get_performance_stats():
646
+ """Get current performance statistics"""
647
+ current_time = game_engine.time_module.time()
648
+ elapsed_time = current_time - game_engine.start_time if game_engine.start_time > 0 else 0
649
+
650
+ return {
651
+ "ai_fps_current": game_engine.ai_frame_count / elapsed_time if elapsed_time > 0 else 0,
652
+ "ai_fps_target": game_engine.ai_fps,
653
+ "web_fps_target": game_engine.web_fps,
654
+ "display_fps_target": game_engine.fps,
655
+ "models_ready": game_engine.models_ready,
656
+ "actor_critic_loaded": game_engine.actor_critic_loaded,
657
+ "game_started": game_engine.game_started,
658
+ "connected_clients": len(connected_clients),
659
+ "total_ai_frames": game_engine.ai_frame_count,
660
+ "total_display_frames": game_engine.frame_count,
661
+ "elapsed_time": elapsed_time,
662
+ "torch_compile_enabled": os.environ.get("ENABLE_TORCH_COMPILE", "1") == "1",
663
+ "device": "cuda" if torch.cuda.is_available() else "cpu"
664
+ }
665
+
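While the Space is running, the /performance endpoint defined above can be polled from any HTTP client. The following is a minimal sketch, assuming a local run on the default port 7860 used by the uvicorn.run call at the end of this file and the third-party requests library; only keys that the endpoint actually returns are read.

import requests

# Poll the performance stats endpoint (base URL is an assumption for a local run)
stats = requests.get("http://localhost:7860/performance", timeout=5).json()
print(f"AI FPS {stats['ai_fps_current']:.1f} / target {stats['ai_fps_target']}")
print(f"models_ready={stats['models_ready']}, clients={stats['connected_clients']}")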
666
+ @app.get("/", response_class=HTMLResponse)
667
+ async def get_homepage():
668
+ """Serve the main game interface"""
669
+ html_content = """
670
+ <!DOCTYPE html>
671
+ <html>
672
+ <head>
673
+ <title>Physics-informed BEV World Model</title>
674
+ <style>
675
+ body {
676
+ margin: 0;
677
+ padding: 20px;
678
+ background: #1a1a1a;
679
+ color: white;
680
+ font-family: 'Courier New', monospace;
681
+ text-align: center;
682
+ }
683
+ #gameCanvas {
684
+ border: 2px solid #00ff00;
685
+ background: #000;
686
+ margin: 20px auto;
687
+ display: block;
688
+ }
689
+ #controls {
690
+ margin: 20px;
691
+ display: grid;
692
+ grid-template-columns: 1fr 1fr;
693
+ gap: 20px;
694
+ max-width: 800px;
695
+ margin: 20px auto;
696
+ }
697
+ .control-section {
698
+ background: #2a2a2a;
699
+ padding: 15px;
700
+ border-radius: 8px;
701
+ border: 1px solid #444;
702
+ }
703
+ .key-display {
704
+ background: #333;
705
+ border: 1px solid #555;
706
+ padding: 5px 10px;
707
+ margin: 2px;
708
+ border-radius: 4px;
709
+ display: inline-block;
710
+ min-width: 30px;
711
+ }
712
+ .key-pressed {
713
+ background: #00ff00;
714
+ color: #000;
715
+ }
716
+ #status {
717
+ margin: 10px;
718
+ padding: 10px;
719
+ background: #2a2a2a;
720
+ border-radius: 4px;
721
+ }
722
+ .info {
723
+ color: #00ff00;
724
+ margin: 5px 0;
725
+ }
726
+ </style>
727
+ </head>
728
+ <body>
729
+ <h1>🎮 Physics-informed BEV World Model</h1>
730
+ <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
731
+ <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
732
+
733
+ <!-- Model Download Progress -->
734
+ <div id="downloadSection" style="display: none; margin: 20px;">
735
+ <p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
736
+ <div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
737
+ <div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
738
+ </div>
739
+ <p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
740
+ </div>
741
+
742
+ <canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>
743
+
744
+ <div id="status">
745
+ <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
746
+ <div class="info">Game: <span id="gameStatus">Click to Start</span></div>
747
+ <div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
748
+ <div class="info">Reward: <span id="reward">0</span></div>
749
+ </div>
750
+
751
+ <div id="controls">
752
+ <div class="control-section">
753
+ <h3>Movement</h3>
754
+ <div>
755
+ <span class="key-display" id="key-w">W</span> Forward<br>
756
+ <span class="key-display" id="key-a">A</span> Left
757
+ <span class="key-display" id="key-s">S</span> Back
758
+ <span class="key-display" id="key-d">D</span> Right<br>
759
+ <span class="key-display" id="key-space">Space</span> Jump
760
+ <span class="key-display" id="key-ctrl">Ctrl</span> Crouch
761
+ <span class="key-display" id="key-shift">Shift</span> Walk
762
+ </div>
763
+ </div>
764
+
765
+ <div class="control-section">
766
+ <h3>Actions</h3>
767
+ <div>
768
+ <span class="key-display" id="key-1">1</span> Weapon 1<br>
769
+ <span class="key-display" id="key-2">2</span> Weapon 2
770
+ <span class="key-display" id="key-3">3</span> Weapon 3<br>
771
+ <span class="key-display" id="key-r">R</span> Reload<br>
772
+ <span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
773
+ <span class="key-display" id="key-enter">Enter</span> Reset Game<br>
774
+ <span class="key-display" id="key-esc">Esc</span> Pause/Quit
775
+ </div>
776
+ </div>
777
+ </div>
778
+
779
+ <script>
780
+ const canvas = document.getElementById('gameCanvas');
781
+ const ctx = canvas.getContext('2d');
782
+ const statusEl = document.getElementById('connectionStatus');
783
+ const gameStatusEl = document.getElementById('gameStatus');
784
+ const frameEl = document.getElementById('frameCount');
785
+ const aiFpsEl = document.getElementById('aiFps');
786
+ const rewardEl = document.getElementById('reward');
787
+ const loadingEl = document.getElementById('loadingIndicator');
788
+ const downloadSectionEl = document.getElementById('downloadSection');
789
+ const downloadStatusEl = document.getElementById('downloadStatus');
790
+ const progressBarEl = document.getElementById('progressBar');
791
+ const progressTextEl = document.getElementById('progressText');
792
+
793
+ let ws = null;
794
+ let pressedKeys = new Set();
795
+ let gameStarted = false;
796
+
797
+ // Key mapping
798
+ const keyDisplayMap = {
799
+ 'KeyW': 'key-w',
800
+ 'KeyA': 'key-a',
801
+ 'KeyS': 'key-s',
802
+ 'KeyD': 'key-d',
803
+ 'Space': 'key-space',
804
+ 'ControlLeft': 'key-ctrl',
805
+ 'ShiftLeft': 'key-shift',
806
+ 'Digit1': 'key-1',
807
+ 'Digit2': 'key-2',
808
+ 'Digit3': 'key-3',
809
+ 'KeyR': 'key-r',
810
+ 'ArrowUp': 'key-arrows',
811
+ 'ArrowDown': 'key-arrows',
812
+ 'ArrowLeft': 'key-arrows',
813
+ 'ArrowRight': 'key-arrows',
814
+ 'Enter': 'key-enter',
815
+ 'Escape': 'key-esc'
816
+ };
817
+
818
+ function connectWebSocket() {
819
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
820
+ const wsUrl = `${protocol}//${window.location.host}/ws`;
821
+
822
+ ws = new WebSocket(wsUrl);
823
+
824
+ ws.onopen = function(event) {
825
+ statusEl.textContent = 'Connected';
826
+ statusEl.style.color = '#00ff00';
827
+ // If user already clicked to start before WS was ready, send start now
828
+ if (gameStarted) {
829
+ ws.send(JSON.stringify({ type: 'start' }));
830
+ }
831
+ };
832
+
833
+ ws.onmessage = function(event) {
834
+ const data = JSON.parse(event.data);
835
+
836
+ if (data.type === 'loading') {
837
+ // Handle loading status
838
+ downloadSectionEl.style.display = 'block';
839
+ downloadStatusEl.textContent = data.status;
840
+
841
+ if (data.progress !== undefined) {
842
+ progressBarEl.style.width = data.progress + '%';
843
+ progressTextEl.textContent = data.progress + '% - ' + data.status;
844
+ } else {
845
+ progressTextEl.textContent = data.status;
846
+ }
847
+
848
+ gameStatusEl.textContent = 'Loading Models...';
849
+ gameStatusEl.style.color = '#ffaa00';
850
+
851
+ } else if (data.type === 'frame') {
852
+ // Hide loading indicators once we get frames
853
+ downloadSectionEl.style.display = 'none';
854
+ // Update frame display
855
+ if (data.image) {
856
+ const img = new Image();
857
+ img.onload = function() {
858
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
859
+ ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
860
+ };
861
+ img.src = data.image;
862
+ }
863
+
864
+ frameEl.textContent = data.frame_count;
865
+ rewardEl.textContent = data.reward.toFixed(2);
866
+
867
+ // Update AI FPS display and hide loading indicator once AI starts
868
+ if (data.ai_fps !== undefined && data.ai_fps !== null) {
869
+ // Ensure FPS value is reasonable
870
+ const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
871
+ aiFpsEl.textContent = aiFps.toFixed(1);
872
+
873
+ // Color code AI FPS for performance indication
874
+ if (aiFps >= 8) {
875
+ aiFpsEl.style.color = '#00ff00'; // Green for good performance
876
+ } else if (aiFps >= 5) {
877
+ aiFpsEl.style.color = '#ffff00'; // Yellow for moderate performance
878
+ } else if (aiFps > 0) {
879
+ aiFpsEl.style.color = '#ff0000'; // Red for poor performance
880
+ } else {
881
+ aiFpsEl.style.color = '#888888'; // Gray for inactive
882
+ }
883
+
884
+ // Hide loading indicator once AI inference starts working
885
+ if (aiFps > 0 && gameStarted) {
886
+ loadingEl.style.display = 'none';
887
+ gameStatusEl.textContent = 'Playing';
888
+ gameStatusEl.style.color = '#00ff00';
889
+ }
890
+ }
891
+ }
892
+ };
893
+
894
+ ws.onclose = function(event) {
895
+ statusEl.textContent = 'Disconnected';
896
+ statusEl.style.color = '#ff0000';
897
+ setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
898
+ };
899
+
900
+ ws.onerror = function(event) {
901
+ statusEl.textContent = 'Error';
902
+ statusEl.style.color = '#ff0000';
903
+ };
904
+ }
905
+
906
+ function sendKeyState() {
907
+ if (ws && ws.readyState === WebSocket.OPEN) {
908
+ ws.send(JSON.stringify({
909
+ type: 'keys',
910
+ keys: Array.from(pressedKeys)
911
+ }));
912
+ }
913
+ }
914
+
915
+ function startGame() {
916
+ if (ws && ws.readyState === WebSocket.OPEN) {
917
+ ws.send(JSON.stringify({
918
+ type: 'start'
919
+ }));
920
+ gameStarted = true;
921
+ gameStatusEl.textContent = 'Starting AI...';
922
+ gameStatusEl.style.color = '#ffff00';
923
+ loadingEl.style.display = 'block';
924
+ console.log('Game started');
925
+ }
926
+ }
927
+
928
+ function pauseGame() {
929
+ if (ws && ws.readyState === WebSocket.OPEN) {
930
+ ws.send(JSON.stringify({
931
+ type: 'pause'
932
+ }));
933
+ gameStarted = false;
934
+ gameStatusEl.textContent = 'Paused - Click to Resume';
935
+ gameStatusEl.style.color = '#ffff00';
936
+ console.log('Game paused');
937
+ }
938
+ }
939
+
940
+ function updateKeyDisplay() {
941
+ // Reset all key displays
942
+ Object.values(keyDisplayMap).forEach(id => {
943
+ const el = document.getElementById(id);
944
+ if (el) el.classList.remove('key-pressed');
945
+ });
946
+
947
+ // Highlight pressed keys
948
+ pressedKeys.forEach(key => {
949
+ const displayId = keyDisplayMap[key];
950
+ if (displayId) {
951
+ const el = document.getElementById(displayId);
952
+ if (el) el.classList.add('key-pressed');
953
+ }
954
+ });
955
+ }
956
+
957
+ // Focus canvas and handle keyboard events
958
+ canvas.addEventListener('click', () => {
959
+ canvas.focus();
960
+ if (!gameStarted) {
961
+ // Queue start locally and send immediately if WS is open
962
+ gameStarted = true;
963
+ gameStatusEl.textContent = 'Starting AI...';
964
+ gameStatusEl.style.color = '#ffff00';
965
+ loadingEl.style.display = 'block';
966
+ if (ws && ws.readyState === WebSocket.OPEN) {
967
+ ws.send(JSON.stringify({ type: 'start' }));
968
+ }
969
+ }
970
+ });
971
+
972
+ canvas.addEventListener('keydown', (event) => {
973
+ event.preventDefault();
974
+
975
+ // Handle special keys
976
+ if (event.code === 'Enter') {
977
+ if (ws && ws.readyState === WebSocket.OPEN) {
978
+ ws.send(JSON.stringify({
979
+ type: 'reset'
980
+ }));
981
+ console.log('Environment reset requested');
982
+ }
983
+ // Add to pressedKeys for visual feedback
984
+ pressedKeys.add(event.code);
985
+ updateKeyDisplay();
986
+
987
+ // Remove Enter from pressedKeys after a short delay for visual feedback
988
+ setTimeout(() => {
989
+ pressedKeys.delete(event.code);
990
+ updateKeyDisplay();
991
+ }, 200);
992
+ } else if (event.code === 'Escape') {
993
+ pauseGame();
994
+ // Add to pressedKeys for visual feedback
995
+ pressedKeys.add(event.code);
996
+ updateKeyDisplay();
997
+
998
+ // Remove ESC from pressedKeys after a short delay for visual feedback
999
+ setTimeout(() => {
1000
+ pressedKeys.delete(event.code);
1001
+ updateKeyDisplay();
1002
+ }, 200);
1003
+ } else {
1004
+ // Only send game keys if game is started
1005
+ if (gameStarted) {
1006
+ pressedKeys.add(event.code);
1007
+ updateKeyDisplay();
1008
+ sendKeyState();
1009
+ }
1010
+ }
1011
+ });
1012
+
1013
+ canvas.addEventListener('keyup', (event) => {
1014
+ event.preventDefault();
1015
+
1016
+ // Don't handle special keys release (handled in keydown with timeout)
1017
+ if (event.code !== 'Enter' && event.code !== 'Escape') {
1018
+ if (gameStarted) {
1019
+ pressedKeys.delete(event.code);
1020
+ updateKeyDisplay();
1021
+ sendKeyState();
1022
+ }
1023
+ }
1024
+ });
1025
+
1026
+ // Handle mouse events for clicks
1027
+ canvas.addEventListener('mousedown', (event) => {
1028
+ if (ws && ws.readyState === WebSocket.OPEN) {
1029
+ ws.send(JSON.stringify({
1030
+ type: 'mouse',
1031
+ button: event.button,
1032
+ action: 'down',
1033
+ x: event.offsetX,
1034
+ y: event.offsetY
1035
+ }));
1036
+ }
1037
+ });
1038
+
1039
+ canvas.addEventListener('mouseup', (event) => {
1040
+ if (ws && ws.readyState === WebSocket.OPEN) {
1041
+ ws.send(JSON.stringify({
1042
+ type: 'mouse',
1043
+ button: event.button,
1044
+ action: 'up',
1045
+ x: event.offsetX,
1046
+ y: event.offsetY
1047
+ }));
1048
+ }
1049
+ });
1050
+
1051
+ // Initialize
1052
+ connectWebSocket();
1053
+ canvas.focus();
1054
+ </script>
1055
+ </body>
1056
+ </html>
1057
+ """
1058
+ return html_content
1059
+
1060
+ @app.websocket("/ws")
1061
+ async def websocket_endpoint(websocket: WebSocket):
1062
+ """Handle WebSocket connections for real-time game communication"""
1063
+ await websocket.accept()
1064
+ connected_clients.add(websocket)
1065
+
1066
+ try:
1067
+ while True:
1068
+ # Receive messages from client
1069
+ data = await websocket.receive_text()
1070
+ message = json.loads(data)
1071
+
1072
+ if message['type'] == 'keys':
1073
+ # Update pressed keys
1074
+ game_engine.pressed_keys = set(message['keys'])
1075
+
1076
+ elif message['type'] == 'reset':
1077
+ # Handle environment reset request
1078
+ game_engine.request_reset()
1079
+
1080
+ elif message['type'] == 'start':
1081
+ # Handle game start request
1082
+ game_engine.start_game()
1083
+
1084
+ elif message['type'] == 'pause':
1085
+ # Handle game pause request
1086
+ game_engine.pause_game()
1087
+
1088
+ elif message['type'] == 'mouse':
1089
+ # Handle mouse events
1090
+ if message['action'] == 'down':
1091
+ if message['button'] == 0: # Left click
1092
+ game_engine.l_click = True
1093
+ elif message['button'] == 2: # Right click
1094
+ game_engine.r_click = True
1095
+ elif message['action'] == 'up':
1096
+ if message['button'] == 0: # Left click
1097
+ game_engine.l_click = False
1098
+ elif message['button'] == 2: # Right click
1099
+ game_engine.r_click = False
1100
+
1101
+ # Update mouse position (relative to canvas)
1102
+ game_engine.mouse_x = message.get('x', 0) - 300 # Center at 300px
1103
+ game_engine.mouse_y = message.get('y', 0) - 150 # Center at 150px
1104
+
1105
+ except WebSocketDisconnect:
1106
+ connected_clients.discard(websocket)
1107
+ except Exception as e:
1108
+ logger.error(f"WebSocket error: {e}")
1109
+ connected_clients.discard(websocket)
1110
+
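A minimal client-side sketch of the /ws message protocol handled above, assuming a local run on port 7860 and the third-party websockets package: 'start' begins inference, 'keys' carries the currently held key codes, and incoming 'loading'/'frame' messages report progress and stats.

import asyncio
import json

import websockets  # assumed dependency; any WebSocket client library works

async def drive_ws(url="ws://localhost:7860/ws"):  # URL is an assumption (local run)
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"type": "start"}))                   # start the game
        await ws.send(json.dumps({"type": "keys", "keys": ["KeyW"]}))  # hold W
        for _ in range(3):                                             # read a few messages
            msg = json.loads(await ws.recv())
            print(msg["type"], msg.get("ai_fps", msg.get("status")))

asyncio.run(drive_ws())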
1111
+ if __name__ == "__main__":
1112
+ # For local development
1113
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
1114
+
app.py.backup2 ADDED
@@ -0,0 +1,1112 @@
1
+ """
2
+ Web-based Diamond CSGO AI Player for Hugging Face Spaces
3
+ Uses FastAPI + WebSocket for real-time keyboard input and game streaming
4
+ """
5
+
6
+ # Fix environment variables FIRST, before any other imports
7
+ import os
8
+ import tempfile
9
+
10
+ # Fix OMP_NUM_THREADS immediately (before PyTorch/NumPy imports)
11
+ if "OMP_NUM_THREADS" not in os.environ or not os.environ.get("OMP_NUM_THREADS", "").isdigit():
12
+ os.environ["OMP_NUM_THREADS"] = "2"
13
+
14
+ # Set up cache directories immediately
15
+ temp_dir = tempfile.gettempdir()
16
+ os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
17
+ os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
18
+ os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
19
+
20
+ # Create cache directories
21
+ for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
22
+ cache_path = os.environ[cache_var]
23
+ os.makedirs(cache_path, exist_ok=True)
24
+
25
+ import asyncio
26
+ import base64
27
+ import io
28
+ import json
29
+ import logging
30
+ from pathlib import Path
31
+ from typing import Dict, List, Optional, Set
32
+
33
+ import cv2
34
+ import numpy as np
35
+ import torch
36
+ import uvicorn
37
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
38
+ from fastapi.responses import HTMLResponse
39
+ from fastapi.staticfiles import StaticFiles
40
+ from hydra import compose, initialize
41
+ from hydra.utils import instantiate
42
+ from omegaconf import DictConfig, OmegaConf
43
+ from PIL import Image
44
+
45
+ # Import your modules
46
+ import sys
47
+ from pathlib import Path
48
+
49
+ # Add project root to path for src package imports
50
+ project_root = Path(__file__).parent
51
+ if str(project_root) not in sys.path:
52
+ sys.path.insert(0, str(project_root))
53
+
54
+ from src.agent import Agent
55
+ from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
56
+ from src.envs import WorldModelEnv
57
+ from src.game.web_play_env import WebPlayEnv
58
+ from src.utils import extract_state_dict
59
+ from config_web import web_config
60
+
61
+ # Configure logging
62
+ logging.basicConfig(level=logging.INFO)
63
+ logger = logging.getLogger(__name__)
64
+
65
+ # Global variables
66
+ app = FastAPI(title="Diamond CSGO AI Player")
67
+
68
+ # Set safe defaults for headless CI/Spaces environments
69
+ os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
70
+ os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
71
+ os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
72
+
73
+ # Environment variables already set at top of file
74
+ connected_clients: Set[WebSocket] = set()
75
+
76
+ class WebKeyMap:
77
+ """Map web key codes to pygame-like keys for CSGO actions"""
78
+ WEB_TO_CSGO = {
79
+ 'KeyW': 'w',
80
+ 'KeyA': 'a',
81
+ 'KeyS': 's',
82
+ 'KeyD': 'd',
83
+ 'Space': 'space',
84
+ 'ControlLeft': 'left ctrl',
85
+ 'ShiftLeft': 'left shift',
86
+ 'Digit1': '1',
87
+ 'Digit2': '2',
88
+ 'Digit3': '3',
89
+ 'KeyR': 'r',
90
+ 'ArrowUp': 'camera_up',
91
+ 'ArrowDown': 'camera_down',
92
+ 'ArrowLeft': 'camera_left',
93
+ 'ArrowRight': 'camera_right'
94
+ }
95
+
96
+ class WebGameEngine:
97
+ """Web-compatible game engine that replaces pygame functionality"""
98
+
99
+ def __init__(self):
100
+ self.play_env: Optional[WebPlayEnv] = None
101
+ self.obs = None
102
+ self.running = False
103
+ self.game_started = False
104
+ self.fps = 30 # Display FPS
105
+ self.ai_fps = 40 # AI inference FPS (matching standalone play.py performance)
106
+ self.frame_count = 0
107
+ self.ai_frame_count = 0
108
+ self.last_ai_time = 0
109
+ self.start_time = 0 # Track when AI started for proper FPS calculation
110
+ self.last_frame_send_time = 0 # Track frame sending for optimization
111
+ self.web_fps = 20 # Web display FPS (lower than AI FPS to reduce network overhead)
112
+ self.pressed_keys: Set[str] = set()
113
+ self.mouse_x = 0
114
+ self.mouse_y = 0
115
+ self.l_click = False
116
+ self.r_click = False
117
+ self.should_reset = False
118
+ self.cached_obs = None # Cache last observation for frame skipping
119
+ self.first_inference_done = False # Track if first inference completed
120
+ self.models_ready = False # Track if models are loaded
121
+ self.download_progress = 0 # Track download progress (0-100)
122
+ self.loading_status = "Initializing..." # Loading status message
123
+ self.actor_critic_loaded = False # Track if actor_critic was loaded with trained weights
124
+ import time
125
+ self.time_module = time
126
+
127
+ async def _load_model_from_url_async(self, agent, device):
128
+ """Load model from URL using torch.hub (HF Spaces compatible)"""
129
+ import asyncio
130
+ import concurrent.futures
131
+
132
+ def load_model_weights():
133
+ """Load model weights in thread pool to avoid blocking"""
134
+ try:
135
+ logger.info("Loading model using torch.hub.load_state_dict_from_url...")
136
+ self.loading_status = "Downloading model..."
137
+ self.download_progress = 10
138
+
139
+ model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
140
+
141
+ # Use torch.hub to download and load state dict with custom cache dir
142
+ logger.info(f"Loading state dict from {model_url}")
143
+
144
+ # Set custom cache directory that we have write permissions for
145
+ cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
146
+ os.makedirs(cache_dir, exist_ok=True)
147
+
148
+ # Use torch.hub with custom cache directory
149
+ state_dict = torch.hub.load_state_dict_from_url(
150
+ model_url,
151
+ map_location=device,
152
+ model_dir=cache_dir,
153
+ check_hash=False # Skip hash check for faster loading
154
+ )
155
+
156
+ self.download_progress = 60
157
+ self.loading_status = "Loading model weights into agent..."
158
+ logger.info("State dict loaded, applying to agent...")
159
+
160
+ # Load state dict into agent, but skip actor_critic if not present
161
+ has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
162
+ logger.info(f"Model has actor_critic weights: {has_actor_critic}")
163
+ agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
164
+
165
+ # Track if actor_critic was actually loaded with trained weights
166
+ self.actor_critic_loaded = has_actor_critic
167
+
168
+ self.download_progress = 100
169
+ self.loading_status = "Model loaded successfully!"
170
+ logger.info("All model weights loaded successfully!")
171
+ return True
172
+
173
+ except Exception as e:
174
+ logger.error(f"Failed to load model: {e}")
175
+ import traceback
176
+ traceback.print_exc()
177
+ return False
178
+
179
+ # Run in thread pool to avoid blocking with timeout
180
+ loop = asyncio.get_event_loop()
181
+ try:
182
+ with concurrent.futures.ThreadPoolExecutor() as executor:
183
+ # Add timeout for model loading (5 minutes max)
184
+ future = loop.run_in_executor(executor, load_model_weights)
185
+ success = await asyncio.wait_for(future, timeout=300.0) # 5 minute timeout
186
+ return success
187
+ except asyncio.TimeoutError:
188
+ logger.error("Model loading timed out after 5 minutes")
189
+ self.loading_status = "Model loading timed out - using dummy mode"
190
+ return False
191
+ except Exception as e:
192
+ logger.error(f"Error in model loading executor: {e}")
193
+ self.loading_status = f"Model loading error: {str(e)[:50]}..."
194
+ return False
195
+
196
+ async def initialize_models(self):
197
+ """Initialize the AI models and environment"""
198
+ try:
199
+ import torch
200
+ logger.info("Initializing models...")
201
+
202
+ # Setup environment and paths
203
+ web_config.setup_environment_variables()
204
+ web_config.create_default_configs()
205
+
206
+ config_path = web_config.get_config_path()
207
+ logger.info(f"Using config path: {config_path}")
208
+
209
+ # For Hydra, use relative path from app.py location
210
+ # Since app.py is in project root, config is simply "config"
211
+ relative_config_path = "config"
212
+ logger.info(f"Relative config path: {relative_config_path}")
213
+
214
+ with initialize(version_base="1.3", config_path=relative_config_path):
215
+ cfg = compose(config_name="trainer")
216
+
217
+ # Override config for deployment
218
+ cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
219
+ cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
220
+
221
+ # Use GPU if available, otherwise fall back to CPU
222
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223
+ logger.info(f"Using device: {device}")
224
+
225
+ # Log GPU availability and CUDA info for debugging
226
+ if torch.cuda.is_available():
227
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
228
+ logger.info(f"GPU device count: {torch.cuda.device_count()}")
229
+ logger.info(f"Current GPU: {torch.cuda.get_device_name(0)}")
230
+ logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
231
+ logger.info("🚀 GPU acceleration enabled!")
232
+ else:
233
+ logger.info("CUDA not available, using CPU mode")
234
+
235
+ # Initialize agent first
236
+ num_actions = cfg.env.num_actions
237
+ agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
238
+
239
+ # Get spawn directory
240
+ spawn_dir = web_config.get_spawn_dir()
241
+
242
+ # Try to load checkpoint (remote first, then local, then dummy mode)
243
+ try:
244
+ # First try to load from Hugging Face Hub using torch.hub
245
+ logger.info("Loading model from Hugging Face Hub with torch.hub...")
246
+
247
+ success = await self._load_model_from_url_async(agent, device)
248
+
249
+ if success:
250
+ logger.info("Successfully loaded checkpoint from HF Hub")
251
+ else:
252
+ # Fallback to local checkpoint if available
253
+ logger.error("Failed to load from HF Hub! Check the detailed error above.")
254
+ checkpoint_path = web_config.get_checkpoint_path()
255
+ if checkpoint_path.exists():
256
+ logger.info(f"Loading local checkpoint: {checkpoint_path}")
257
+ self.loading_status = "Loading local checkpoint..."
258
+ agent.load(checkpoint_path)
259
+ logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
260
+ # Assume local checkpoint has actor_critic weights (may need verification)
261
+ self.actor_critic_loaded = True
262
+ else:
263
+ logger.error(f"No local checkpoint found at: {checkpoint_path}")
264
+ raise FileNotFoundError("No model checkpoint available (local or remote)")
265
+
266
+ except Exception as e:
267
+ logger.error(f"Failed to load any checkpoint: {e}")
268
+ self._init_dummy_mode()
269
+ self.actor_critic_loaded = False # No actor_critic in dummy mode
270
+ return True
271
+
272
+ # Initialize world model environment
273
+ try:
274
+ sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
275
+ if agent.upsampler is not None:
276
+ sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
277
+ wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
278
+ wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
279
+ spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
280
+
281
+ # Create play environment
282
+ self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
283
+
284
+ # Verify actor-critic is loaded and ready for inference
285
+ if agent.actor_critic is not None and self.actor_critic_loaded:
286
+ logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
287
+ logger.info(f"Actor-critic device: {agent.actor_critic.device}")
288
+ # Force AI control for web demo
289
+ self.play_env.is_human_player = False
290
+ logger.info("WebPlayEnv set to AI control mode")
291
+ elif agent.actor_critic is not None and not self.actor_critic_loaded:
292
+ logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
293
+ self.play_env.is_human_player = True
294
+ logger.info("WebPlayEnv set to human control mode (no trained weights)")
295
+ else:
296
+ logger.warning("No actor-critic model found - AI inference will not work!")
297
+ self.play_env.is_human_player = True
298
+ logger.info("WebPlayEnv set to human control mode (fallback)")
299
+
300
+ # Enable model compilation for better performance (like standalone play.py)
301
+ # This gives 20-50% speedup but causes 10-30s delay on first inference
302
+ import os
303
+ enable_compile = device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1"
304
+
305
+ if enable_compile:
306
+ logger.info("🚀 Compiling models for faster inference (like standalone play.py)...")
307
+ logger.info("⏱️ First inference will take 10-30s, but subsequent inferences will be much faster!")
308
+ try:
309
+ wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
310
+ if wm_env.upsample_next_obs is not None:
311
+ wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
312
+ logger.info("✅ Model compilation enabled - expect 20-50% speedup!")
313
+ except Exception as e:
314
+ logger.warning(f"Model compilation failed: {e}")
315
+ enable_compile = False
316
+
317
+ if not enable_compile:
318
+ logger.info("Model compilation disabled. Set ENABLE_TORCH_COMPILE=1 for better performance.")
319
+
320
+ # Reset environment
321
+ self.obs, _ = self.play_env.reset()
322
+ self.cached_obs = self.obs # Initialize cache
323
+
324
+ logger.info("Models initialized successfully!")
325
+ logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
326
+ self.models_ready = True
327
+ self.loading_status = "Ready!"
328
+ return True
329
+
330
+ except Exception as e:
331
+ logger.error(f"Failed to initialize world model environment: {e}")
332
+ self._init_dummy_mode()
333
+ self.actor_critic_loaded = False # No actor_critic in dummy mode
334
+ self.models_ready = True
335
+ self.loading_status = "Using dummy mode"
336
+ return True
337
+
338
+ except Exception as e:
339
+ logger.error(f"Failed to initialize models: {e}")
340
+ import traceback
341
+ traceback.print_exc()
342
+ self._init_dummy_mode()
343
+ self.actor_critic_loaded = False # No actor_critic in dummy mode
344
+ self.models_ready = True
345
+ self.loading_status = "Error - using dummy mode"
346
+ return True
347
+
348
+ def _init_dummy_mode(self):
349
+ """Initialize dummy mode for testing without models"""
350
+ logger.info("Initializing dummy mode...")
351
+
352
+ # Create a test observation
353
+ height, width = 150, 600
354
+ img_array = np.zeros((height, width, 3), dtype=np.uint8)
355
+
356
+ # Add test pattern
357
+ for y in range(height):
358
+ for x in range(width):
359
+ img_array[y, x, 0] = (x % 256) # Red gradient
360
+ img_array[y, x, 1] = (y % 256) # Green gradient
361
+ img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
362
+
363
+ # Convert to torch tensor in expected format [-1, 1]
364
+ tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
365
+ tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
366
+ tensor = tensor.unsqueeze(0) # Add batch dimension
367
+
368
+ self.obs = tensor
369
+ self.play_env = None # No real environment in dummy mode
370
+ logger.info("Dummy mode initialized with test pattern")
371
+
372
+
373
+ def step_environment(self):
374
+ """Step the environment with current input state (with intelligent frame skipping)"""
375
+ if self.play_env is None:
376
+ # Dummy mode - just return current observation
377
+ return self.obs, 0.0, False, False, {"mode": "dummy"}
378
+
379
+ try:
380
+ # Check if reset is requested
381
+ if self.should_reset:
382
+ self.reset_environment()
383
+ self.should_reset = False
384
+ self.last_ai_time = self.time_module.time() # Reset AI timer
385
+ return self.obs, 0.0, False, False, {"reset": True}
386
+
387
+ # Intelligent frame skipping: only run AI inference at target FPS
388
+ current_time = self.time_module.time()
389
+ time_since_last_ai = current_time - self.last_ai_time
390
+ should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
391
+
392
+ if should_run_ai:
393
+ # Show loading indicator for first inference (can be slow)
394
+ if not self.first_inference_done:
395
+ logger.info("Running first AI inference (may take 5-15 seconds)...")
396
+
397
+ # Run AI inference
398
+ inference_start = self.time_module.time()
399
+ next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
400
+ pressed_keys=self.pressed_keys,
401
+ mouse_x=self.mouse_x,
402
+ mouse_y=self.mouse_y,
403
+ l_click=self.l_click,
404
+ r_click=self.r_click
405
+ )
406
+ inference_time = self.time_module.time() - inference_start
407
+
408
+ # Log first inference completion
409
+ if not self.first_inference_done:
410
+ self.first_inference_done = True
411
+ logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
412
+
413
+ # Cache the new observation and update timing
414
+ self.cached_obs = next_obs
415
+ self.last_ai_time = current_time
416
+ self.ai_frame_count += 1
417
+
418
+ # Add AI performance info
419
+ info = info or {}
420
+ info["ai_inference"] = True
421
+
422
+ # Calculate proper AI FPS: frames / elapsed time since start
423
+ elapsed_time = current_time - self.start_time
424
+ if elapsed_time > 0 and self.ai_frame_count > 0:
425
+ ai_fps = self.ai_frame_count / elapsed_time
426
+ # Cap at reasonable maximum (shouldn't exceed 100 FPS for AI inference)
427
+ info["ai_fps"] = min(ai_fps, 100.0)
428
+ else:
429
+ info["ai_fps"] = 0
430
+
431
+ info["inference_time"] = inference_time
432
+
433
+ return next_obs, reward, done, truncated, info
434
+ else:
435
+ # Use cached observation for smoother display without AI overhead
436
+ obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
437
+
438
+ # Calculate AI FPS for cached frames too
439
+ elapsed_time = current_time - self.start_time
440
+ if elapsed_time > 0 and self.ai_frame_count > 0:
441
+ ai_fps = min(self.ai_frame_count / elapsed_time, 100.0) # Cap at 100 FPS
442
+ else:
443
+ ai_fps = 0
444
+
445
+ return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}
446
+
447
+ except Exception as e:
448
+ logger.error(f"Error stepping environment: {e}")
449
+ obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
450
+ return obs_to_return, 0.0, False, False, {"error": str(e)}
451
+
452
+ def reset_environment(self):
453
+ """Reset the environment"""
454
+ try:
455
+ if self.play_env is not None:
456
+ self.obs, _ = self.play_env.reset()
457
+ self.cached_obs = self.obs # Update cache
458
+ logger.info("Environment reset successfully")
459
+ else:
460
+ # Dummy mode - recreate test pattern
461
+ self._init_dummy_mode()
462
+ self.cached_obs = self.obs # Update cache
463
+ logger.info("Dummy environment reset")
464
+ except Exception as e:
465
+ logger.error(f"Error resetting environment: {e}")
466
+
467
+ def request_reset(self):
468
+ """Request environment reset on next step"""
469
+ self.should_reset = True
470
+ logger.info("Environment reset requested")
471
+
472
+ def start_game(self):
473
+ """Start the game"""
474
+ self.game_started = True
475
+ self.start_time = self.time_module.time() # Reset start time for FPS calculation
476
+ self.ai_frame_count = 0 # Reset AI frame count
477
+ logger.info("Game started")
478
+
479
+ def pause_game(self):
480
+ """Pause/stop the game"""
481
+ self.game_started = False
482
+ logger.info("Game paused")
483
+
484
+ def obs_to_base64(self, obs: torch.Tensor) -> str:
485
+ """Convert observation tensor to base64 image for web display (optimized for speed)"""
486
+ if obs is None:
487
+ return ""
488
+
489
+ try:
490
+ # Handle observation tensor conversion based on dimensions
491
+ if obs.ndim == 4 and obs.size(0) == 1:
492
+ # 4D tensor with batch dimension [1, C, H, W] -> [C, H, W]
493
+ img_tensor = obs[0]
494
+ elif obs.ndim == 3:
495
+ # 3D tensor [C, H, W]
496
+ img_tensor = obs
497
+ elif obs.ndim == 2:
498
+ # 2D tensor - likely an error, return empty string
499
+ logger.warning(f"Unexpected 2D observation tensor: {obs.shape}")
500
+ return ""
501
+ else:
502
+ logger.warning(f"Unexpected observation dimensions: {obs.shape}")
503
+ return ""
504
+
505
+ # Convert to numpy with proper range conversion
506
+ img_array = img_tensor.mul(127.5).add_(127.5).clamp_(0, 255).byte()
507
+ img_array = img_array.permute(1, 2, 0).cpu().numpy()
508
+
509
+ # Direct resize with OpenCV (much faster than PIL)
510
+ img_array = cv2.resize(img_array, (600, 150), interpolation=cv2.INTER_CUBIC)
511
+
512
+ # Note: img_array is already in RGB format from PyTorch tensor, no conversion needed
513
+
514
+ # Optimized JPEG encoding with OpenCV (faster than PIL)
515
+ success, buffer = cv2.imencode('.jpg', img_array, [cv2.IMWRITE_JPEG_QUALITY, 95])
516
+ if success:
517
+ img_str = base64.b64encode(buffer).decode()
518
+ return f"data:image/jpeg;base64,{img_str}"
519
+ else:
520
+ logger.warning("Frame encoding failed, using fallback")
521
+ return ""
522
+
523
+ except Exception as e:
524
+ logger.error(f"Error converting observation to base64: {e}")
525
+ return ""
526
+
527
+ async def game_loop(self):
528
+ """Main game loop that runs continuously"""
529
+ self.running = True
530
+
531
+ while self.running:
532
+ try:
533
+ # Check if models are ready
534
+ if not self.models_ready:
535
+ # Send loading status to clients
536
+ if connected_clients:
537
+ loading_data = {
538
+ 'type': 'loading',
539
+ 'status': self.loading_status,
540
+ 'progress': self.download_progress,
541
+ 'ready': False
542
+ }
543
+ disconnected = set()
544
+ for client in connected_clients.copy():
545
+ try:
546
+ await client.send_text(json.dumps(loading_data))
547
+ except:
548
+ disconnected.add(client)
549
+ connected_clients.difference_update(disconnected)
550
+
551
+ await asyncio.sleep(0.5) # Check every 500ms during loading
552
+ continue
553
+
554
+ # Only step environment if game is started
555
+ if not self.game_started:
556
+ # Game not started - just send current observation without stepping
557
+ should_send_frame = True if (self.obs is not None and connected_clients) else False
558
+ # Don't modify self.obs when game isn't started!
559
+ await asyncio.sleep(0.1)
560
+ else:
561
+ # Game is started - step environment
562
+ should_send_frame = True
563
+ if self.play_env is None:
564
+ await asyncio.sleep(0.1)
565
+ continue
566
+
567
+ # Step environment with current input state
568
+ next_obs, reward, done, truncated, info = self.step_environment()
569
+
570
+ if done or truncated:
571
+ # Auto-reset when episode ends
572
+ self.reset_environment()
573
+ else:
574
+ self.obs = next_obs
575
+
576
+ # Send frame to all connected clients with smart throttling for performance
577
+ current_time = self.time_module.time()
578
+ time_since_last_frame_send = current_time - self.last_frame_send_time
579
+ should_send_web_frame = time_since_last_frame_send >= (1.0 / self.web_fps)
580
+
581
+ if should_send_frame and should_send_web_frame and connected_clients and self.obs is not None:
582
+ # Set default values for when game isn't running
583
+ if not self.game_started:
584
+ reward = 0.0
585
+ info = {"waiting": True}
586
+ # If game is started, reward and info should be set above
587
+
588
+ # Convert observation to base64
589
+ image_data = self.obs_to_base64(self.obs)
590
+
591
+ # Debug logging for first few frames
592
+ if self.frame_count < 5:
593
+ logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
594
+ f"image_data_length={len(image_data) if image_data else 0}, "
595
+ f"game_started={self.game_started}")
596
+
597
+ frame_data = {
598
+ 'type': 'frame',
599
+ 'image': image_data,
600
+ 'frame_count': self.frame_count,
601
+ 'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
602
+ 'info': str(info) if info else "",
603
+ 'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
604
+ 'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False,
605
+ 'web_fps': self.web_fps, # Add web FPS for monitoring
606
+ 'ai_target_fps': self.ai_fps # Add target AI FPS for monitoring
607
+ }
608
+
609
+ # Send to all connected clients
610
+ disconnected = set()
611
+ for client in connected_clients.copy():
612
+ try:
613
+ await client.send_text(json.dumps(frame_data))
614
+ except:
615
+ disconnected.add(client)
616
+
617
+ # Remove disconnected clients
618
+ connected_clients.difference_update(disconnected)
619
+
620
+ # Update frame send timing
621
+ self.last_frame_send_time = current_time
622
+
623
+ self.frame_count += 1
624
+ await asyncio.sleep(1.0 / self.fps) # Control FPS
625
+
626
+ except Exception as e:
627
+ logger.error(f"Error in game loop: {e}")
628
+ await asyncio.sleep(0.1)
629
+
630
+ # Global game engine instance
631
+ game_engine = WebGameEngine()
632
+
633
+ @app.on_event("startup")
634
+ async def startup_event():
635
+ """Initialize models when the app starts"""
636
+ # Start the game loop immediately (it will handle loading state)
637
+ asyncio.create_task(game_engine.game_loop())
638
+
639
+ # Initialize models in background (non-blocking)
640
+ asyncio.create_task(game_engine.initialize_models())
641
+
642
+ @app.get("/performance")
643
+ async def get_performance_stats():
644
+ """Get current performance statistics"""
645
+ current_time = game_engine.time_module.time()
646
+ elapsed_time = current_time - game_engine.start_time if game_engine.start_time > 0 else 0
647
+
648
+ return {
649
+ "ai_fps_current": game_engine.ai_frame_count / elapsed_time if elapsed_time > 0 else 0,
650
+ "ai_fps_target": game_engine.ai_fps,
651
+ "web_fps_target": game_engine.web_fps,
652
+ "display_fps_target": game_engine.fps,
653
+ "models_ready": game_engine.models_ready,
654
+ "actor_critic_loaded": game_engine.actor_critic_loaded,
655
+ "game_started": game_engine.game_started,
656
+ "connected_clients": len(connected_clients),
657
+ "total_ai_frames": game_engine.ai_frame_count,
658
+ "total_display_frames": game_engine.frame_count,
659
+ "elapsed_time": elapsed_time,
660
+ "torch_compile_enabled": os.environ.get("ENABLE_TORCH_COMPILE", "1") == "1",
661
+ "device": "cuda" if torch.cuda.is_available() else "cpu"
662
+ }
663
+
664
+ @app.get("/", response_class=HTMLResponse)
665
+ async def get_homepage():
666
+ """Serve the main game interface"""
667
+ html_content = """
668
+ <!DOCTYPE html>
669
+ <html>
670
+ <head>
671
+ <title>Physics-informed BEV World Model</title>
672
+ <style>
673
+ body {
674
+ margin: 0;
675
+ padding: 20px;
676
+ background: #1a1a1a;
677
+ color: white;
678
+ font-family: 'Courier New', monospace;
679
+ text-align: center;
680
+ }
681
+ #gameCanvas {
682
+ border: 2px solid #00ff00;
683
+ background: #000;
684
+ margin: 20px auto;
685
+ display: block;
686
+ }
687
+ #controls {
688
+ margin: 20px;
689
+ display: grid;
690
+ grid-template-columns: 1fr 1fr;
691
+ gap: 20px;
692
+ max-width: 800px;
693
+ margin: 20px auto;
694
+ }
695
+ .control-section {
696
+ background: #2a2a2a;
697
+ padding: 15px;
698
+ border-radius: 8px;
699
+ border: 1px solid #444;
700
+ }
701
+ .key-display {
702
+ background: #333;
703
+ border: 1px solid #555;
704
+ padding: 5px 10px;
705
+ margin: 2px;
706
+ border-radius: 4px;
707
+ display: inline-block;
708
+ min-width: 30px;
709
+ }
710
+ .key-pressed {
711
+ background: #00ff00;
712
+ color: #000;
713
+ }
714
+ #status {
715
+ margin: 10px;
716
+ padding: 10px;
717
+ background: #2a2a2a;
718
+ border-radius: 4px;
719
+ }
720
+ .info {
721
+ color: #00ff00;
722
+ margin: 5px 0;
723
+ }
724
+ </style>
725
+ </head>
726
+ <body>
727
+ <h1>🎮 Physics-informed BEV World Model</h1>
728
+ <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
729
+ <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
730
+
731
+ <!-- Model Download Progress -->
732
+ <div id="downloadSection" style="display: none; margin: 20px;">
733
+ <p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
734
+ <div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
735
+ <div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
736
+ </div>
737
+ <p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
738
+ </div>
739
+
740
+ <canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>
741
+
742
+ <div id="status">
743
+ <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
744
+ <div class="info">Game: <span id="gameStatus">Click to Start</span></div>
745
+ <div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
746
+ <div class="info">Reward: <span id="reward">0</span></div>
747
+ </div>
748
+
749
+ <div id="controls">
750
+ <div class="control-section">
751
+ <h3>Movement</h3>
752
+ <div>
753
+ <span class="key-display" id="key-w">W</span> Forward<br>
754
+ <span class="key-display" id="key-a">A</span> Left
755
+ <span class="key-display" id="key-s">S</span> Back
756
+ <span class="key-display" id="key-d">D</span> Right<br>
757
+ <span class="key-display" id="key-space">Space</span> Jump
758
+ <span class="key-display" id="key-ctrl">Ctrl</span> Crouch
759
+ <span class="key-display" id="key-shift">Shift</span> Walk
760
+ </div>
761
+ </div>
762
+
763
+ <div class="control-section">
764
+ <h3>Actions</h3>
765
+ <div>
766
+ <span class="key-display" id="key-1">1</span> Weapon 1<br>
767
+ <span class="key-display" id="key-2">2</span> Weapon 2
768
+ <span class="key-display" id="key-3">3</span> Weapon 3<br>
769
+ <span class="key-display" id="key-r">R</span> Reload<br>
770
+ <span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
771
+ <span class="key-display" id="key-enter">Enter</span> Reset Game<br>
772
+ <span class="key-display" id="key-esc">Esc</span> Pause/Quit
773
+ </div>
774
+ </div>
775
+ </div>
776
+
777
+ <script>
778
+ const canvas = document.getElementById('gameCanvas');
779
+ const ctx = canvas.getContext('2d');
780
+ const statusEl = document.getElementById('connectionStatus');
781
+ const gameStatusEl = document.getElementById('gameStatus');
782
+ const frameEl = document.getElementById('frameCount');
783
+ const aiFpsEl = document.getElementById('aiFps');
784
+ const rewardEl = document.getElementById('reward');
785
+ const loadingEl = document.getElementById('loadingIndicator');
786
+ const downloadSectionEl = document.getElementById('downloadSection');
787
+ const downloadStatusEl = document.getElementById('downloadStatus');
788
+ const progressBarEl = document.getElementById('progressBar');
789
+ const progressTextEl = document.getElementById('progressText');
790
+
791
+ let ws = null;
792
+ let pressedKeys = new Set();
793
+ let gameStarted = false;
794
+
795
+ // Key mapping
796
+ const keyDisplayMap = {
797
+ 'KeyW': 'key-w',
798
+ 'KeyA': 'key-a',
799
+ 'KeyS': 'key-s',
800
+ 'KeyD': 'key-d',
801
+ 'Space': 'key-space',
802
+ 'ControlLeft': 'key-ctrl',
803
+ 'ShiftLeft': 'key-shift',
804
+ 'Digit1': 'key-1',
805
+ 'Digit2': 'key-2',
806
+ 'Digit3': 'key-3',
807
+ 'KeyR': 'key-r',
808
+ 'ArrowUp': 'key-arrows',
809
+ 'ArrowDown': 'key-arrows',
810
+ 'ArrowLeft': 'key-arrows',
811
+ 'ArrowRight': 'key-arrows',
812
+ 'Enter': 'key-enter',
813
+ 'Escape': 'key-esc'
814
+ };
815
+
816
+ function connectWebSocket() {
817
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
818
+ const wsUrl = `${protocol}//${window.location.host}/ws`;
819
+
820
+ ws = new WebSocket(wsUrl);
821
+
822
+ ws.onopen = function(event) {
823
+ statusEl.textContent = 'Connected';
824
+ statusEl.style.color = '#00ff00';
825
+ // If user already clicked to start before WS was ready, send start now
826
+ if (gameStarted) {
827
+ ws.send(JSON.stringify({ type: 'start' }));
828
+ }
829
+ };
830
+
831
+ ws.onmessage = function(event) {
832
+ const data = JSON.parse(event.data);
833
+
834
+ if (data.type === 'loading') {
835
+ // Handle loading status
836
+ downloadSectionEl.style.display = 'block';
837
+ downloadStatusEl.textContent = data.status;
838
+
839
+ if (data.progress !== undefined) {
840
+ progressBarEl.style.width = data.progress + '%';
841
+ progressTextEl.textContent = data.progress + '% - ' + data.status;
842
+ } else {
843
+ progressTextEl.textContent = data.status;
844
+ }
845
+
846
+ gameStatusEl.textContent = 'Loading Models...';
847
+ gameStatusEl.style.color = '#ffaa00';
848
+
849
+ } else if (data.type === 'frame') {
850
+ // Hide loading indicators once we get frames
851
+ downloadSectionEl.style.display = 'none';
852
+ // Update frame display
853
+ if (data.image) {
854
+ const img = new Image();
855
+ img.onload = function() {
856
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
857
+ ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
858
+ };
859
+ img.src = data.image;
860
+ }
861
+
862
+ frameEl.textContent = data.frame_count;
863
+ rewardEl.textContent = data.reward.toFixed(2);
864
+
865
+ // Update AI FPS display and hide loading indicator once AI starts
866
+ if (data.ai_fps !== undefined && data.ai_fps !== null) {
867
+ // Ensure FPS value is reasonable
868
+ const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
869
+ aiFpsEl.textContent = aiFps.toFixed(1);
870
+
871
+ // Color code AI FPS for performance indication
872
+ if (aiFps >= 8) {
873
+ aiFpsEl.style.color = '#00ff00'; // Green for good performance
874
+ } else if (aiFps >= 5) {
875
+ aiFpsEl.style.color = '#ffff00'; // Yellow for moderate performance
876
+ } else if (aiFps > 0) {
877
+ aiFpsEl.style.color = '#ff0000'; // Red for poor performance
878
+ } else {
879
+ aiFpsEl.style.color = '#888888'; // Gray for inactive
880
+ }
881
+
882
+ // Hide loading indicator once AI inference starts working
883
+ if (aiFps > 0 && gameStarted) {
884
+ loadingEl.style.display = 'none';
885
+ gameStatusEl.textContent = 'Playing';
886
+ gameStatusEl.style.color = '#00ff00';
887
+ }
888
+ }
889
+ }
890
+ };
891
+
892
+ ws.onclose = function(event) {
893
+ statusEl.textContent = 'Disconnected';
894
+ statusEl.style.color = '#ff0000';
895
+ setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
896
+ };
897
+
898
+ ws.onerror = function(event) {
899
+ statusEl.textContent = 'Error';
900
+ statusEl.style.color = '#ff0000';
901
+ };
902
+ }
903
+
904
+ function sendKeyState() {
905
+ if (ws && ws.readyState === WebSocket.OPEN) {
906
+ ws.send(JSON.stringify({
907
+ type: 'keys',
908
+ keys: Array.from(pressedKeys)
909
+ }));
910
+ }
911
+ }
912
+
913
+ function startGame() {
914
+ if (ws && ws.readyState === WebSocket.OPEN) {
915
+ ws.send(JSON.stringify({
916
+ type: 'start'
917
+ }));
918
+ gameStarted = true;
919
+ gameStatusEl.textContent = 'Starting AI...';
920
+ gameStatusEl.style.color = '#ffff00';
921
+ loadingEl.style.display = 'block';
922
+ console.log('Game started');
923
+ }
924
+ }
925
+
926
+ function pauseGame() {
927
+ if (ws && ws.readyState === WebSocket.OPEN) {
928
+ ws.send(JSON.stringify({
929
+ type: 'pause'
930
+ }));
931
+ gameStarted = false;
932
+ gameStatusEl.textContent = 'Paused - Click to Resume';
933
+ gameStatusEl.style.color = '#ffff00';
934
+ console.log('Game paused');
935
+ }
936
+ }
937
+
938
+ function updateKeyDisplay() {
939
+ // Reset all key displays
940
+ Object.values(keyDisplayMap).forEach(id => {
941
+ const el = document.getElementById(id);
942
+ if (el) el.classList.remove('key-pressed');
943
+ });
944
+
945
+ // Highlight pressed keys
946
+ pressedKeys.forEach(key => {
947
+ const displayId = keyDisplayMap[key];
948
+ if (displayId) {
949
+ const el = document.getElementById(displayId);
950
+ if (el) el.classList.add('key-pressed');
951
+ }
952
+ });
953
+ }
954
+
955
+ // Focus canvas and handle keyboard events
956
+ canvas.addEventListener('click', () => {
957
+ canvas.focus();
958
+ if (!gameStarted) {
959
+ // Queue start locally and send immediately if WS is open
960
+ gameStarted = true;
961
+ gameStatusEl.textContent = 'Starting AI...';
962
+ gameStatusEl.style.color = '#ffff00';
963
+ loadingEl.style.display = 'block';
964
+ if (ws && ws.readyState === WebSocket.OPEN) {
965
+ ws.send(JSON.stringify({ type: 'start' }));
966
+ }
967
+ }
968
+ });
969
+
970
+ canvas.addEventListener('keydown', (event) => {
971
+ event.preventDefault();
972
+
973
+ // Handle special keys
974
+ if (event.code === 'Enter') {
975
+ if (ws && ws.readyState === WebSocket.OPEN) {
976
+ ws.send(JSON.stringify({
977
+ type: 'reset'
978
+ }));
979
+ console.log('Environment reset requested');
980
+ }
981
+ // Add to pressedKeys for visual feedback
982
+ pressedKeys.add(event.code);
983
+ updateKeyDisplay();
984
+
985
+ // Remove Enter from pressedKeys after a short delay for visual feedback
986
+ setTimeout(() => {
987
+ pressedKeys.delete(event.code);
988
+ updateKeyDisplay();
989
+ }, 200);
990
+ } else if (event.code === 'Escape') {
991
+ pauseGame();
992
+ // Add to pressedKeys for visual feedback
993
+ pressedKeys.add(event.code);
994
+ updateKeyDisplay();
995
+
996
+ // Remove ESC from pressedKeys after a short delay for visual feedback
997
+ setTimeout(() => {
998
+ pressedKeys.delete(event.code);
999
+ updateKeyDisplay();
1000
+ }, 200);
1001
+ } else {
1002
+ // Only send game keys if game is started
1003
+ if (gameStarted) {
1004
+ pressedKeys.add(event.code);
1005
+ updateKeyDisplay();
1006
+ sendKeyState();
1007
+ }
1008
+ }
1009
+ });
1010
+
1011
+ canvas.addEventListener('keyup', (event) => {
1012
+ event.preventDefault();
1013
+
1014
+ // Don't handle special keys release (handled in keydown with timeout)
1015
+ if (event.code !== 'Enter' && event.code !== 'Escape') {
1016
+ if (gameStarted) {
1017
+ pressedKeys.delete(event.code);
1018
+ updateKeyDisplay();
1019
+ sendKeyState();
1020
+ }
1021
+ }
1022
+ });
1023
+
1024
+ // Handle mouse events for clicks
1025
+ canvas.addEventListener('mousedown', (event) => {
1026
+ if (ws && ws.readyState === WebSocket.OPEN) {
1027
+ ws.send(JSON.stringify({
1028
+ type: 'mouse',
1029
+ button: event.button,
1030
+ action: 'down',
1031
+ x: event.offsetX,
1032
+ y: event.offsetY
1033
+ }));
1034
+ }
1035
+ });
1036
+
1037
+ canvas.addEventListener('mouseup', (event) => {
1038
+ if (ws && ws.readyState === WebSocket.OPEN) {
1039
+ ws.send(JSON.stringify({
1040
+ type: 'mouse',
1041
+ button: event.button,
1042
+ action: 'up',
1043
+ x: event.offsetX,
1044
+ y: event.offsetY
1045
+ }));
1046
+ }
1047
+ });
1048
+
1049
+ // Initialize
1050
+ connectWebSocket();
1051
+ canvas.focus();
1052
+ </script>
1053
+ </body>
1054
+ </html>
1055
+ """
1056
+ return html_content
1057
+
1058
+ @app.websocket("/ws")
1059
+ async def websocket_endpoint(websocket: WebSocket):
1060
+ """Handle WebSocket connections for real-time game communication"""
1061
+ await websocket.accept()
1062
+ connected_clients.add(websocket)
1063
+
1064
+ try:
1065
+ while True:
1066
+ # Receive messages from client
1067
+ data = await websocket.receive_text()
1068
+ message = json.loads(data)
1069
+
1070
+ if message['type'] == 'keys':
1071
+ # Update pressed keys
1072
+ game_engine.pressed_keys = set(message['keys'])
1073
+
1074
+ elif message['type'] == 'reset':
1075
+ # Handle environment reset request
1076
+ game_engine.request_reset()
1077
+
1078
+ elif message['type'] == 'start':
1079
+ # Handle game start request
1080
+ game_engine.start_game()
1081
+
1082
+ elif message['type'] == 'pause':
1083
+ # Handle game pause request
1084
+ game_engine.pause_game()
1085
+
1086
+ elif message['type'] == 'mouse':
1087
+ # Handle mouse events
1088
+ if message['action'] == 'down':
1089
+ if message['button'] == 0: # Left click
1090
+ game_engine.l_click = True
1091
+ elif message['button'] == 2: # Right click
1092
+ game_engine.r_click = True
1093
+ elif message['action'] == 'up':
1094
+ if message['button'] == 0: # Left click
1095
+ game_engine.l_click = False
1096
+ elif message['button'] == 2: # Right click
1097
+ game_engine.r_click = False
1098
+
1099
+ # Update mouse position (relative to canvas)
1100
+ game_engine.mouse_x = message.get('x', 0) - 300 # Center at 300px
1101
+ game_engine.mouse_y = message.get('y', 0) - 150 # Center at 150px
1102
+
1103
+ except WebSocketDisconnect:
1104
+ connected_clients.discard(websocket)
1105
+ except Exception as e:
1106
+ logger.error(f"WebSocket error: {e}")
1107
+ connected_clients.discard(websocket)
1108
+
1109
+ if __name__ == "__main__":
1110
+ # For local development
1111
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
1112
+
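The /ws endpoint above speaks a small JSON protocol: the browser sends 'start', 'pause', 'reset', 'keys' and 'mouse' messages, and the server streams back 'loading' and 'frame' messages. A minimal sketch of a headless client for exercising that protocol, assuming the server from app.py is running locally and the third-party websockets package is installed (not part of this commit):

import asyncio
import json

import websockets  # pip install websockets

async def poke_server():
    # Connect to the same endpoint the browser UI uses
    async with websockets.connect("ws://localhost:7860/ws") as ws:
        await ws.send(json.dumps({"type": "start"}))
        await ws.send(json.dumps({"type": "keys", "keys": ["KeyW"]}))
        # Print the first few server messages ('loading' or 'frame')
        for _ in range(3):
            msg = json.loads(await ws.recv())
            print(msg.get("type"), msg.get("frame_count"))

if __name__ == "__main__":
    asyncio.run(poke_server())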
app_config.py ADDED
@@ -0,0 +1,37 @@
1
+ """
2
+ Hugging Face Spaces configuration
3
+ """
4
+
5
+ import os
6
+
7
+ # Hugging Face Spaces configuration
8
+ title = "Diamond CSGO AI Player"
9
+ description = """
10
+ # 🎮 Diamond CSGO AI Player
11
+
12
+ Experience an AI agent trained with diffusion models playing Counter-Strike: Global Offensive!
13
+
14
+ ## How to Play
15
+ 1. Click on the game canvas to focus it
16
+ 2. Use keyboard controls:
17
+ - **WASD** - Movement
18
+ - **Space** - Jump
19
+ - **Ctrl** - Crouch
20
+ - **1,2,3** - Switch weapons
21
+ - **R** - Reload
22
+ - **Arrow Keys** - Camera movement
23
+ - **Mouse clicks** - Fire
24
+ 3. Press **M** to switch between Human/AI control
25
+
26
+ ## Technical Details
27
+ This demo showcases the Diamond framework, which combines:
28
+ - Diffusion models for world modeling
29
+ - Actor-critic reinforcement learning
30
+ - Multi-step planning in imagination
31
+
32
+ The AI learns to play by predicting future game states and optimizing actions through a learned world model.
33
+ """
34
+
35
+ # App settings
36
+ port = int(os.getenv("PORT", 7860))
37
+ host = "0.0.0.0"
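app_config.py only defines plain attributes (title, description, host, port), while the __main__ block of app.py above still hard-codes host="0.0.0.0" and port=7860. A possible wiring that routes the launch through these settings instead, shown only as a sketch:

import uvicorn

import app_config

if __name__ == "__main__":
    # app_config.port already honours the PORT environment variable (default 7860)
    uvicorn.run("app:app", host=app_config.host, port=app_config.port)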
copy.sh ADDED
@@ -0,0 +1,21 @@
1
+ # Essential files to copy to your HF Space directory:
2
+ cp app.py /home/alienware3/Documents/diamond-ai-player/
3
+ cp requirements.txt /home/alienware3/Documents/diamond-ai-player/
4
+ cp Dockerfile /home/alienware3/Documents/diamond-ai-player/
5
+ cp packages.txt /home/alienware3/Documents/diamond-ai-player/
6
+ cp README.md /home/alienware3/Documents/diamond-ai-player/
7
+ cp config_web.py /home/alienware3/Documents/diamond-ai-player/
8
+
9
+ # Copy entire directories
10
+ cp -r src/ /home/alienware3/Documents/diamond-ai-player/
11
+ cp -r config/ /home/alienware3/Documents/diamond-ai-player/
12
+ cp -r csgo/ /home/alienware3/Documents/diamond-ai-player/
13
+
14
+ # Copy your trained model (choose one)
15
+ # cp agent_epoch_00003.pt /home/alienware3/Documents/diamond-ai-player/ # OR
16
+ # cp agent_epoch_00206.pt /home/alienware3/Documents/diamond-ai-player/
17
+
18
+ cd /home/alienware3/Documents/diamond-ai-player/
19
+ git add .
20
+ git commit -m "Fix initial bugs"
21
+ git push origin main
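The script assumes the Space checkout lives at /home/alienware3/Documents/diamond-ai-player and publishes it over git. The same upload can also be done with huggingface_hub directly; a sketch, with the Space id left as a placeholder:

from huggingface_hub import HfApi

api = HfApi()
# Push the prepared folder to the Space in one call (run `huggingface-cli login` first)
api.upload_folder(
    folder_path="/home/alienware3/Documents/diamond-ai-player",
    repo_id="your-username/diamond-ai-player",  # placeholder Space id
    repo_type="space",
    commit_message="Fix initial bugs",
)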
debug_app.py ADDED
@@ -0,0 +1,309 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Debug version of the web app to isolate the black screen issue
4
+ """
5
+
6
+ import asyncio
7
+ import base64
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import sys
13
+ from pathlib import Path
14
+ from typing import Dict, List, Optional, Set
15
+
16
+ import numpy as np
17
+ import torch
18
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
19
+ from fastapi.responses import HTMLResponse
20
+ from PIL import Image
21
+
22
+ # Add src to Python path
23
+ src_path = Path(__file__).parent / "src"
24
+ sys.path.insert(0, str(src_path))
25
+
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.DEBUG)
28
+ logger = logging.getLogger(__name__)
29
+
30
+ # Simple debug app
31
+ app = FastAPI(title="Diamond CSGO AI Player - Debug")
32
+
33
+ class DebugGameEngine:
34
+ """Simple debug version to test image rendering"""
35
+
36
+ def __init__(self):
37
+ self.obs = None
38
+ self.frame_count = 0
39
+ self.initialized = False
40
+
41
+ def create_test_observation(self):
42
+ """Create a test observation for debugging"""
43
+ # Create a test image with some pattern
44
+ height, width = 150, 600
45
+
46
+ # Create RGB pattern
47
+ img_array = np.zeros((height, width, 3), dtype=np.uint8)
48
+
49
+ # Add some pattern to make it visible
50
+ for y in range(height):
51
+ for x in range(width):
52
+ img_array[y, x, 0] = (x % 256) # Red gradient
53
+ img_array[y, x, 1] = (y % 256) # Green gradient
54
+ img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
55
+
56
+ # Convert to torch tensor in the expected format [-1, 1]
57
+ tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
58
+ tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
59
+ tensor = tensor.unsqueeze(0) # Add batch dimension
60
+
61
+ logger.info(f"Created test observation: {tensor.shape}, range: {tensor.min():.3f} to {tensor.max():.3f}")
62
+ return tensor
63
+
64
+ def obs_to_base64(self, obs: torch.Tensor) -> str:
65
+ """Convert observation tensor to base64 image for web display"""
66
+ if obs is None:
67
+ logger.warning("Observation is None")
68
+ return ""
69
+
70
+ try:
71
+ logger.debug(f"Converting obs: shape={obs.shape}, dtype={obs.dtype}")
72
+
73
+ # Convert tensor to PIL Image
74
+ if obs.ndim == 4 and obs.size(0) == 1:
75
+ img_array = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
76
+ else:
77
+ img_array = obs.add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
78
+
79
+ logger.debug(f"Image array: shape={img_array.shape}, range={img_array.min()} to {img_array.max()}")
80
+
81
+ img = Image.fromarray(img_array)
82
+
83
+ # Resize for web display
84
+ img = img.resize((600, 300), Image.BICUBIC)
85
+
86
+ # Convert to base64
87
+ buffer = io.BytesIO()
88
+ img.save(buffer, format='PNG')
89
+ img_str = base64.b64encode(buffer.getvalue()).decode()
90
+
91
+ logger.debug(f"Successfully converted to base64, length: {len(img_str)}")
92
+ return f"data:image/png;base64,{img_str}"
93
+
94
+ except Exception as e:
95
+ logger.error(f"Error converting observation to base64: {e}")
96
+ import traceback
97
+ traceback.print_exc()
98
+ return ""
99
+
100
+ async def initialize(self):
101
+ """Initialize with test data"""
102
+ logger.info("Initializing debug game engine...")
103
+ self.obs = self.create_test_observation()
104
+ self.initialized = True
105
+ logger.info("Debug game engine initialized successfully!")
106
+ return True
107
+
108
+ # Global debug engine
109
+ debug_engine = DebugGameEngine()
110
+ connected_clients: Set[WebSocket] = set()
111
+
112
+ @app.on_event("startup")
113
+ async def startup_event():
114
+ """Initialize debug engine"""
115
+ success = await debug_engine.initialize()
116
+ if success:
117
+ # Start a simple game loop
118
+ asyncio.create_task(debug_game_loop())
119
+
120
+ async def debug_game_loop():
121
+ """Simple debug game loop"""
122
+ while True:
123
+ try:
124
+ if debug_engine.initialized and connected_clients:
125
+ # Send test frame to all connected clients
126
+ frame_data = {
127
+ 'type': 'frame',
128
+ 'image': debug_engine.obs_to_base64(debug_engine.obs),
129
+ 'frame_count': debug_engine.frame_count,
130
+ 'reward': 0.0,
131
+ 'info': f"Debug frame {debug_engine.frame_count}"
132
+ }
133
+
134
+ logger.debug(f"Sending frame {debug_engine.frame_count}")
135
+
136
+ # Send to all connected clients
137
+ disconnected = set()
138
+ for client in connected_clients.copy():
139
+ try:
140
+ await client.send_text(json.dumps(frame_data))
141
+ except Exception as e:
142
+ logger.error(f"Error sending to client: {e}")
143
+ disconnected.add(client)
144
+
145
+ # Remove disconnected clients
146
+ connected_clients.difference_update(disconnected)
147
+
148
+ debug_engine.frame_count += 1
149
+
150
+ await asyncio.sleep(1.0 / 15) # 15 FPS
151
+
152
+ except Exception as e:
153
+ logger.error(f"Error in debug game loop: {e}")
154
+ await asyncio.sleep(0.1)
155
+
156
+ @app.get("/", response_class=HTMLResponse)
157
+ async def get_homepage():
158
+ """Serve debug interface"""
159
+ html_content = """
160
+ <!DOCTYPE html>
161
+ <html>
162
+ <head>
163
+ <title>Diamond CSGO AI - Debug</title>
164
+ <style>
165
+ body {
166
+ margin: 0;
167
+ padding: 20px;
168
+ background: #1a1a1a;
169
+ color: white;
170
+ font-family: monospace;
171
+ text-align: center;
172
+ }
173
+ #gameCanvas {
174
+ border: 2px solid #00ff00;
175
+ background: #000;
176
+ margin: 20px auto;
177
+ display: block;
178
+ }
179
+ #status {
180
+ margin: 10px;
181
+ padding: 10px;
182
+ background: #2a2a2a;
183
+ border-radius: 4px;
184
+ }
185
+ .info {
186
+ color: #00ff00;
187
+ margin: 5px 0;
188
+ }
189
+ </style>
190
+ </head>
191
+ <body>
192
+ <h1>🔧 Diamond CSGO AI - Debug Mode</h1>
193
+ <p>Testing image rendering and WebSocket communication</p>
194
+
195
+ <canvas id="gameCanvas" width="600" height="300"></canvas>
196
+
197
+ <div id="status">
198
+ <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
199
+ <div class="info">Frame: <span id="frameCount">0</span></div>
200
+ <div class="info">Info: <span id="info">Waiting...</span></div>
201
+ </div>
202
+
203
+ <div id="debug">
204
+ <h3>Debug Information</h3>
205
+ <p id="lastUpdate">No updates yet</p>
206
+ <p id="imageInfo">No image data</p>
207
+ </div>
208
+
209
+ <script>
210
+ const canvas = document.getElementById('gameCanvas');
211
+ const ctx = canvas.getContext('2d');
212
+ const statusEl = document.getElementById('connectionStatus');
213
+ const frameEl = document.getElementById('frameCount');
214
+ const infoEl = document.getElementById('info');
215
+ const lastUpdateEl = document.getElementById('lastUpdate');
216
+ const imageInfoEl = document.getElementById('imageInfo');
217
+
218
+ let ws = null;
219
+
220
+ function connectWebSocket() {
221
+ const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
222
+ const wsUrl = `${protocol}//${window.location.host}/ws`;
223
+ console.log('Connecting to:', wsUrl);
224
+
225
+ ws = new WebSocket(wsUrl);
226
+
227
+ ws.onopen = function(event) {
228
+ console.log('WebSocket connected');
229
+ statusEl.textContent = 'Connected';
230
+ statusEl.style.color = '#00ff00';
231
+ };
232
+
233
+ ws.onmessage = function(event) {
234
+ console.log('Received message');
235
+ const data = JSON.parse(event.data);
236
+
237
+ if (data.type === 'frame') {
238
+ console.log('Received frame:', data.frame_count);
239
+ lastUpdateEl.textContent = `Last update: ${new Date().toLocaleTimeString()}`;
240
+
241
+ // Update frame display
242
+ if (data.image && data.image.length > 0) {
243
+ console.log('Loading image, length:', data.image.length);
244
+ imageInfoEl.textContent = `Image data length: ${data.image.length}`;
245
+
246
+ const img = new Image();
247
+ img.onload = function() {
248
+ console.log('Image loaded successfully');
249
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
250
+ ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
251
+ };
252
+ img.onerror = function() {
253
+ console.error('Failed to load image');
254
+ imageInfoEl.textContent = 'Failed to load image';
255
+ };
256
+ img.src = data.image;
257
+ } else {
258
+ console.warn('No image data received');
259
+ imageInfoEl.textContent = 'No image data received';
260
+ }
261
+
262
+ frameEl.textContent = data.frame_count;
263
+ infoEl.textContent = data.info || 'No info';
264
+ }
265
+ };
266
+
267
+ ws.onclose = function(event) {
268
+ console.log('WebSocket disconnected');
269
+ statusEl.textContent = 'Disconnected';
270
+ statusEl.style.color = '#ff0000';
271
+ setTimeout(connectWebSocket, 1000);
272
+ };
273
+
274
+ ws.onerror = function(event) {
275
+ console.error('WebSocket error:', event);
276
+ statusEl.textContent = 'Error';
277
+ statusEl.style.color = '#ff0000';
278
+ };
279
+ }
280
+
281
+ // Initialize
282
+ connectWebSocket();
283
+ </script>
284
+ </body>
285
+ </html>
286
+ """
287
+ return html_content
288
+
289
+ @app.websocket("/ws")
290
+ async def websocket_endpoint(websocket: WebSocket):
291
+ """Handle WebSocket connections"""
292
+ await websocket.accept()
293
+ connected_clients.add(websocket)
294
+ logger.info(f"Client connected. Total clients: {len(connected_clients)}")
295
+
296
+ try:
297
+ while True:
298
+ # Just wait for disconnection
299
+ await websocket.receive_text()
300
+ except WebSocketDisconnect:
301
+ connected_clients.discard(websocket)
302
+ logger.info(f"Client disconnected. Total clients: {len(connected_clients)}")
303
+ except Exception as e:
304
+ logger.error(f"WebSocket error: {e}")
305
+ connected_clients.discard(websocket)
306
+
307
+ if __name__ == "__main__":
308
+ import uvicorn
309
+ uvicorn.run("debug_app:app", host="0.0.0.0", port=7861, reload=False)
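The frames produced by obs_to_base64 above are PNG data URLs ("data:image/png;base64,..."). When the canvas stays black it can help to decode one outside the browser; a small sketch (not part of the commit) that inverts that encoding:

import base64
import io

from PIL import Image

def decode_frame(data_url: str) -> Image.Image:
    # Drop the "data:image/png;base64," prefix, then decode the PNG payload
    payload = data_url.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(payload)))

# e.g. decode_frame(debug_engine.obs_to_base64(debug_engine.obs)).save("frame.png")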
download_models.py ADDED
@@ -0,0 +1,100 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to download models from Hugging Face Hub if not present locally
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ from pathlib import Path
9
+
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ def download_checkpoint_if_needed():
14
+ """Download model checkpoint if not present locally"""
15
+ # Check if we have any local checkpoints
16
+ possible_checkpoints = [
17
+ Path("agent_epoch_00206.pt"),
18
+ Path("agent_epoch_00003.pt"),
19
+ Path("checkpoints/agent_epoch_00206.pt"),
20
+ Path("checkpoints/agent_epoch_00003.pt"),
21
+ ]
22
+
23
+ for ckpt_path in possible_checkpoints:
24
+ if ckpt_path.exists():
25
+ logger.info(f"Found local checkpoint: {ckpt_path}")
26
+ return True
27
+
28
+ logger.info("No local checkpoint found, attempting to download from Hugging Face Hub...")
29
+
30
+ try:
31
+ from huggingface_hub import hf_hub_download
32
+
33
+ # This would download from a hypothetical HF model repository
34
+ # You would need to upload your models to HF Hub first
35
+ # Example:
36
+ # checkpoint_path = hf_hub_download(
37
+ # repo_id="your-username/diamond-csgo-model",
38
+ # filename="agent_epoch_00206.pt",
39
+ # cache_dir="./checkpoints"
40
+ # )
41
+
42
+ logger.warning("Model download not implemented yet.")
43
+ logger.warning("Please ensure you have model checkpoints available locally.")
44
+ return False
45
+
46
+ except ImportError:
47
+ logger.error("huggingface_hub not installed. Cannot download models.")
48
+ return False
49
+ except Exception as e:
50
+ logger.error(f"Failed to download models: {e}")
51
+ return False
52
+
53
+ def setup_demo_data():
54
+ """Set up minimal demo data if models are not available"""
55
+ spawn_dir = Path("csgo/spawn/0")
56
+ spawn_dir.mkdir(parents=True, exist_ok=True)
57
+
58
+ # Create minimal dummy files for demo
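+ # NOTE: np.zeros defaults to float64, so the four dummy arrays below total roughly 1 GB on disk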
59
+ import numpy as np
60
+ import json
61
+
62
+ files_to_create = {
63
+ "act.npy": np.zeros((100, 51)), # 100 timesteps, 51 actions
64
+ "low_res.npy": np.zeros((100, 3, 150, 600)), # 100 frames
65
+ "full_res.npy": np.zeros((100, 3, 300, 1200)), # 100 high-res frames
66
+ "next_act.npy": np.zeros((100, 51)),
67
+ }
68
+
69
+ for filename, data in files_to_create.items():
70
+ file_path = spawn_dir / filename
71
+ if not file_path.exists():
72
+ np.save(file_path, data)
73
+ logger.info(f"Created dummy file: {file_path}")
74
+
75
+ # Create info.json
76
+ info_path = spawn_dir / "info.json"
77
+ if not info_path.exists():
78
+ info_data = {
79
+ "episode_length": 100,
80
+ "total_reward": 0.0,
81
+ "demo": True
82
+ }
83
+ with open(info_path, 'w') as f:
84
+ json.dump(info_data, f)
85
+ logger.info(f"Created info file: {info_path}")
86
+
87
+ if __name__ == "__main__":
88
+ logger.info("Setting up Diamond CSGO demo...")
89
+
90
+ # Try to download models
91
+ has_models = download_checkpoint_if_needed()
92
+
93
+ # Set up demo data
94
+ setup_demo_data()
95
+
96
+ if not has_models:
97
+ logger.warning("Running in demo mode without trained models.")
98
+ logger.warning("The AI agent will not function properly without model checkpoints.")
99
+
100
+ logger.info("Setup complete!")
low_res.npy ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50e10ae563ce86e98c58de99c2d3a0a4c0f111f6fde2c87323bd1d5eb5210d1a
3
+ size 20288
npy_reader.py ADDED
@@ -0,0 +1,93 @@
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+
5
+ def read_npy_file(file_path):
6
+ """
7
+ Read a .npy file and return the numpy array.
8
+
9
+ Args:
10
+ file_path (str): Path to the .npy file
11
+
12
+ Returns:
13
+ numpy.ndarray: The loaded numpy array
14
+
15
+ Raises:
16
+ FileNotFoundError: If the file doesn't exist
17
+ ValueError: If the file is not a valid .npy file
18
+ """
19
+ try:
20
+ # Check if file exists
21
+ if not os.path.exists(file_path):
22
+ raise FileNotFoundError(f"File not found: {file_path}")
23
+
24
+ # Check if file has .npy extension
25
+ if not file_path.lower().endswith('.npy'):
26
+ print(f"Warning: File {file_path} doesn't have .npy extension")
27
+
28
+ # Load the numpy array
29
+ array = np.load(file_path)
30
+
31
+ print(f"Successfully loaded {file_path}")
32
+ print(f"Array shape: {array.shape}")
33
+ print(f"Array dtype: {array.dtype}")
34
+ print(f"Array size: {array.size}")
35
+
36
+ return array
37
+
38
+ except FileNotFoundError:
+ # Let the FileNotFoundError documented in the docstring propagate unchanged
+ raise
+ except Exception as e:
39
+ raise ValueError(f"Error reading {file_path}: {str(e)}")
40
+
41
+ def print_array_info(array, show_data=False, max_elements=10):
42
+ """
43
+ Print detailed information about a numpy array.
44
+
45
+ Args:
46
+ array (numpy.ndarray): The numpy array to analyze
47
+ show_data (bool): Whether to show actual data values
48
+ max_elements (int): Maximum number of elements to display
49
+ """
50
+ print("\n" + "="*50)
51
+ print("ARRAY INFORMATION")
52
+ print("="*50)
53
+ print(f"Shape: {array.shape}")
54
+ print(f"Dtype: {array.dtype}")
55
+ print(f"Size: {array.size}")
56
+ print(f"Dimensions: {array.ndim}")
57
+ print(f"Memory usage: {array.nbytes} bytes")
58
+
59
+ if array.size > 0:
60
+ print(f"Min value: {np.min(array)}")
61
+ print(f"Max value: {np.max(array)}")
62
+ print(f"Mean value: {np.mean(array)}")
63
+ print(f"Std deviation: {np.std(array)}")
64
+
65
+ if show_data:
66
+ print("\nArray data:")
67
+ if array.size <= max_elements:
68
+ print(array)
69
+ else:
70
+ print(f"First {max_elements} elements:")
71
+ print(array.flat[:max_elements])
72
+ print("...")
73
+
74
+ def main():
75
+ """
76
+ Main function to demonstrate usage.
77
+ """
78
+
79
+ file_path = "/home/alienware3/Documents/diamond/low_res.npy"
80
+
81
+ try:
82
+ # Read the NPY file
83
+ array = read_npy_file(file_path)
84
+
85
+ # Print detailed information
86
+ print_array_info(array, show_data=True, max_elements=20)
87
+
88
+ except Exception as e:
89
+ print(f"Error: {e}")
90
+ sys.exit(1)
91
+
92
+ if __name__ == "__main__":
93
+ main()
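main() above inspects a hard-coded absolute path. To point the reader at an arbitrary file, a small argv-based wrapper is one option (a sketch, not part of the commit):

import sys

from npy_reader import read_npy_file, print_array_info

if __name__ == "__main__":
    # Fall back to the repo's low_res.npy when no path is given
    target = sys.argv[1] if len(sys.argv) > 1 else "low_res.npy"
    print_array_info(read_npy_file(target), show_data=True, max_elements=20)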
run_web_demo.py ADDED
@@ -0,0 +1,48 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Script to run the web demo locally for testing
4
+ """
5
+
6
+ import logging
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add project root to Python path
11
+ root_path = Path(__file__).parent
12
+ sys.path.insert(0, str(root_path))
13
+
14
+ # Configure logging
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
18
+ )
19
+ logger = logging.getLogger(__name__)
20
+
21
+ def main():
22
+ """Run the web demo"""
23
+ logger.info("Starting Diamond CSGO AI Player web demo...")
24
+
25
+ try:
26
+ import uvicorn
27
+ from app import app
28
+
29
+ # Run the server
30
+ uvicorn.run(
31
+ app,
32
+ host="0.0.0.0",
33
+ port=7860,
34
+ log_level="info",
35
+ reload=False # Disable reload for production
36
+ )
37
+
38
+ except ImportError as e:
39
+ logger.error(f"Missing dependencies: {e}")
40
+ logger.error("Please install requirements: pip install -r requirements.txt")
41
+ sys.exit(1)
42
+
43
+ except Exception as e:
44
+ logger.error(f"Failed to start server: {e}")
45
+ sys.exit(1)
46
+
47
+ if __name__ == "__main__":
48
+ main()
test_web_app.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify the web app can be imported and basic functionality works
4
+ """
5
+
6
+ import logging
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ # Add src to Python path
11
+ src_path = Path(__file__).parent / "src"
12
+ sys.path.insert(0, str(src_path))
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ def test_imports():
18
+ """Test if all required modules can be imported"""
19
+ logger.info("Testing imports...")
20
+
21
+ try:
22
+ # Test core modules
23
+ from config_web import web_config
24
+ logger.info("Web config imported")
25
+
26
+ from src.csgo.web_action_processing import WebCSGOAction
27
+ logger.info("Web action processing imported")
28
+
29
+ from src.game.web_play_env import WebPlayEnv
30
+ logger.info("Web play environment imported")
31
+
32
+ # Test web framework
33
+ import fastapi
34
+ import uvicorn
35
+ logger.info("Web framework imported")
36
+
37
+ # Test app
38
+ from app import app, WebGameEngine
39
+ logger.info("Main app imported")
40
+
41
+ return True
42
+
43
+ except Exception as e:
44
+ logger.error(f"Import failed: {e}")
45
+ return False
46
+
47
+ def test_config():
48
+ """Test configuration setup"""
49
+ logger.info("Testing configuration...")
50
+
51
+ try:
52
+ from config_web import web_config
53
+
54
+ # Test path resolution
55
+ config_path = web_config.get_config_path()
56
+ logger.info(f"Config path: {config_path}")
57
+
58
+ spawn_dir = web_config.get_spawn_dir()
59
+ logger.info(f"Spawn directory: {spawn_dir}")
60
+
61
+ checkpoint_path = web_config.get_checkpoint_path()
62
+ logger.info(f"Checkpoint path: {checkpoint_path}")
63
+
64
+ return True
65
+
66
+ except Exception as e:
67
+ logger.error(f"Configuration test failed: {e}")
68
+ return False
69
+
70
+ def test_action_processing():
71
+ """Test web action processing"""
72
+ logger.info("Testing action processing...")
73
+
74
+ try:
75
+ from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
76
+
77
+ # Test key mapping
78
+ test_keys = {'KeyW', 'KeyA', 'Space'}
79
+ action_names = web_keys_to_csgo_action_names(test_keys)
80
+ logger.info(f"Key mapping: {test_keys} -> {action_names}")
81
+
82
+ # Test action creation
83
+ action = WebCSGOAction(
84
+ key_names=action_names,
85
+ mouse_x=10,
86
+ mouse_y=5,
87
+ l_click=False,
88
+ r_click=False
89
+ )
90
+ logger.info(f"Action created: {action}")
91
+
92
+ return True
93
+
94
+ except Exception as e:
95
+ logger.error(f"Action processing test failed: {e}")
96
+ return False
97
+
98
+ def main():
99
+ """Run all tests"""
100
+ logger.info("Starting Diamond CSGO web app tests...")
101
+
102
+ tests = [
103
+ ("Imports", test_imports),
104
+ ("Configuration", test_config),
105
+ ("Action Processing", test_action_processing),
106
+ ]
107
+
108
+ passed = 0
109
+ total = len(tests)
110
+
111
+ for name, test_func in tests:
112
+ logger.info(f"\n--- Testing {name} ---")
113
+ if test_func():
114
+ logger.info(f"PASS: {name} test passed")
115
+ passed += 1
116
+ else:
117
+ logger.error(f"FAIL: {name} test failed")
118
+
119
+ logger.info(f"\n=== Test Results ===")
120
+ logger.info(f"Passed: {passed}/{total}")
121
+
122
+ if passed == total:
123
+ logger.info("All tests passed! The web app should work correctly.")
124
+ return True
125
+ else:
126
+ logger.error("Some tests failed. Please check the errors above.")
127
+ return False
128
+
129
+ if __name__ == "__main__":
130
+ success = main()
131
+ sys.exit(0 if success else 1)