Spaces:
Sleeping
Sleeping
Commit
·
02c6351
1
Parent(s):
d294854
Fix initial bugs
Browse files- =0.16.0 +117 -0
- =4.0.0 +0 -0
- DEPLOYMENT.md +193 -0
- HF_SPACES_CACHE_FIX.md +84 -0
- HF_SPACES_DEPLOYMENT_GUIDE.md +201 -0
- HF_SPACES_GPU_FIX.md +155 -0
- __pycache__/app.cpython-310.pyc +0 -0
- __pycache__/config_web.cpython-310.pyc +0 -0
- __pycache__/debug_app.cpython-310.pyc +0 -0
- __pycache__/hf_space_config.cpython-310.pyc +0 -0
- app.py.backup +1114 -0
- app.py.backup2 +1112 -0
- app_config.py +37 -0
- copy.sh +21 -0
- debug_app.py +309 -0
- download_models.py +100 -0
- low_res.npy +3 -0
- npy_reader.py +93 -0
- run_web_demo.py +48 -0
- test_web_app.py +131 -0
=0.16.0
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Collecting gradio
|
| 2 |
+
Downloading gradio-5.44.1-py3-none-any.whl.metadata (16 kB)
|
| 3 |
+
Requirement already satisfied: huggingface-hub in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (0.26.1)
|
| 4 |
+
Collecting aiofiles<25.0,>=22.0 (from gradio)
|
| 5 |
+
Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
|
| 6 |
+
Requirement already satisfied: anyio<5.0,>=3.0 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (4.9.0)
|
| 7 |
+
Collecting brotli>=1.1.0 (from gradio)
|
| 8 |
+
Downloading Brotli-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.5 kB)
|
| 9 |
+
Collecting fastapi<1.0,>=0.115.2 (from gradio)
|
| 10 |
+
Downloading fastapi-0.116.1-py3-none-any.whl.metadata (28 kB)
|
| 11 |
+
Collecting ffmpy (from gradio)
|
| 12 |
+
Downloading ffmpy-0.6.1-py3-none-any.whl.metadata (2.9 kB)
|
| 13 |
+
Collecting gradio-client==1.12.1 (from gradio)
|
| 14 |
+
Downloading gradio_client-1.12.1-py3-none-any.whl.metadata (7.1 kB)
|
| 15 |
+
Collecting groovy~=0.1 (from gradio)
|
| 16 |
+
Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
|
| 17 |
+
Requirement already satisfied: httpx<1.0,>=0.24.1 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (0.28.1)
|
| 18 |
+
Collecting huggingface-hub
|
| 19 |
+
Using cached huggingface_hub-0.34.4-py3-none-any.whl.metadata (14 kB)
|
| 20 |
+
Requirement already satisfied: jinja2<4.0 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (3.1.6)
|
| 21 |
+
Requirement already satisfied: markupsafe<4.0,>=2.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (3.0.1)
|
| 22 |
+
Requirement already satisfied: numpy<3.0,>=1.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (1.26.0)
|
| 23 |
+
Collecting orjson~=3.0 (from gradio)
|
| 24 |
+
Downloading orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (41 kB)
|
| 25 |
+
Requirement already satisfied: packaging in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (25.0)
|
| 26 |
+
Requirement already satisfied: pandas<3.0,>=1.0 in /home/alienware3/.local/lib/python3.10/site-packages (from gradio) (2.3.0)
|
| 27 |
+
Requirement already satisfied: pillow<12.0,>=8.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (10.3.0)
|
| 28 |
+
Collecting pydantic<2.12,>=2.0 (from gradio)
|
| 29 |
+
Downloading pydantic-2.11.7-py3-none-any.whl.metadata (67 kB)
|
| 30 |
+
Collecting pydub (from gradio)
|
| 31 |
+
Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
|
| 32 |
+
Collecting python-multipart>=0.0.18 (from gradio)
|
| 33 |
+
Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
|
| 34 |
+
Requirement already satisfied: pyyaml<7.0,>=5.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (6.0.2)
|
| 35 |
+
Collecting ruff>=0.9.3 (from gradio)
|
| 36 |
+
Downloading ruff-0.12.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
|
| 37 |
+
Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
|
| 38 |
+
Downloading safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)
|
| 39 |
+
Collecting semantic-version~=2.0 (from gradio)
|
| 40 |
+
Downloading semantic_version-2.10.0-py2.py3-none-any.whl.metadata (9.7 kB)
|
| 41 |
+
Collecting starlette<1.0,>=0.40.0 (from gradio)
|
| 42 |
+
Downloading starlette-0.47.3-py3-none-any.whl.metadata (6.2 kB)
|
| 43 |
+
Collecting tomlkit<0.14.0,>=0.12.0 (from gradio)
|
| 44 |
+
Downloading tomlkit-0.13.3-py3-none-any.whl.metadata (2.8 kB)
|
| 45 |
+
Collecting typer<1.0,>=0.12 (from gradio)
|
| 46 |
+
Downloading typer-0.17.4-py3-none-any.whl.metadata (15 kB)
|
| 47 |
+
Requirement already satisfied: typing-extensions~=4.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio) (4.12.2)
|
| 48 |
+
Collecting uvicorn>=0.14.0 (from gradio)
|
| 49 |
+
Downloading uvicorn-0.35.0-py3-none-any.whl.metadata (6.5 kB)
|
| 50 |
+
Requirement already satisfied: fsspec in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from gradio-client==1.12.1->gradio) (2024.9.0)
|
| 51 |
+
Collecting websockets<16.0,>=10.0 (from gradio-client==1.12.1->gradio)
|
| 52 |
+
Downloading websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
|
| 53 |
+
Requirement already satisfied: filelock in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from huggingface-hub) (3.16.1)
|
| 54 |
+
Requirement already satisfied: requests in /home/alienware3/.local/lib/python3.10/site-packages (from huggingface-hub) (2.32.4)
|
| 55 |
+
Requirement already satisfied: tqdm>=4.42.1 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from huggingface-hub) (4.66.4)
|
| 56 |
+
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub)
|
| 57 |
+
Using cached hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.7 kB)
|
| 58 |
+
Requirement already satisfied: exceptiongroup>=1.0.2 in /home/alienware3/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (1.3.0)
|
| 59 |
+
Requirement already satisfied: idna>=2.8 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (3.10)
|
| 60 |
+
Requirement already satisfied: sniffio>=1.1 in /home/alienware3/.local/lib/python3.10/site-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)
|
| 61 |
+
Requirement already satisfied: certifi in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from httpx<1.0,>=0.24.1->gradio) (2024.8.30)
|
| 62 |
+
Requirement already satisfied: httpcore==1.* in /home/alienware3/.local/lib/python3.10/site-packages (from httpx<1.0,>=0.24.1->gradio) (1.0.9)
|
| 63 |
+
Requirement already satisfied: h11>=0.16 in /home/alienware3/.local/lib/python3.10/site-packages (from httpcore==1.*->httpx<1.0,>=0.24.1->gradio) (0.16.0)
|
| 64 |
+
Requirement already satisfied: python-dateutil>=2.8.2 in /home/alienware3/.local/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2.9.0.post0)
|
| 65 |
+
Requirement already satisfied: pytz>=2020.1 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2024.2)
|
| 66 |
+
Requirement already satisfied: tzdata>=2022.7 in /home/alienware3/.local/lib/python3.10/site-packages (from pandas<3.0,>=1.0->gradio) (2025.2)
|
| 67 |
+
Collecting annotated-types>=0.6.0 (from pydantic<2.12,>=2.0->gradio)
|
| 68 |
+
Downloading annotated_types-0.7.0-py3-none-any.whl.metadata (15 kB)
|
| 69 |
+
Collecting pydantic-core==2.33.2 (from pydantic<2.12,>=2.0->gradio)
|
| 70 |
+
Downloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
|
| 71 |
+
Collecting typing-inspection>=0.4.0 (from pydantic<2.12,>=2.0->gradio)
|
| 72 |
+
Downloading typing_inspection-0.4.1-py3-none-any.whl.metadata (2.6 kB)
|
| 73 |
+
Requirement already satisfied: click>=8.0.0 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio) (8.1.7)
|
| 74 |
+
Collecting shellingham>=1.3.0 (from typer<1.0,>=0.12->gradio)
|
| 75 |
+
Downloading shellingham-1.5.4-py2.py3-none-any.whl.metadata (3.5 kB)
|
| 76 |
+
Requirement already satisfied: rich>=10.11.0 in /home/alienware3/.local/lib/python3.10/site-packages (from typer<1.0,>=0.12->gradio) (14.0.0)
|
| 77 |
+
Requirement already satisfied: charset_normalizer<4,>=2 in /home/alienware3/.local/lib/python3.10/site-packages (from requests->huggingface-hub) (3.3.2)
|
| 78 |
+
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from requests->huggingface-hub) (2.2.3)
|
| 79 |
+
Requirement already satisfied: six>=1.5 in /home/alienware3/miniconda3/envs/diamond/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.16.0)
|
| 80 |
+
Requirement already satisfied: markdown-it-py>=2.2.0 in /home/alienware3/.local/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)
|
| 81 |
+
Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/alienware3/.local/lib/python3.10/site-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.19.2)
|
| 82 |
+
Requirement already satisfied: mdurl~=0.1 in /home/alienware3/.local/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)
|
| 83 |
+
Downloading gradio-5.44.1-py3-none-any.whl (60.2 MB)
|
| 84 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 60.2/60.2 MB 98.5 MB/s eta 0:00:00
|
| 85 |
+
Downloading gradio_client-1.12.1-py3-none-any.whl (324 kB)
|
| 86 |
+
Using cached huggingface_hub-0.34.4-py3-none-any.whl (561 kB)
|
| 87 |
+
Downloading aiofiles-24.1.0-py3-none-any.whl (15 kB)
|
| 88 |
+
Downloading Brotli-1.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.0 MB)
|
| 89 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.0/3.0 MB 100.4 MB/s eta 0:00:00
|
| 90 |
+
Downloading fastapi-0.116.1-py3-none-any.whl (95 kB)
|
| 91 |
+
Downloading groovy-0.1.2-py3-none-any.whl (14 kB)
|
| 92 |
+
Using cached hf_xet-1.1.9-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.2 MB)
|
| 93 |
+
Downloading orjson-3.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (132 kB)
|
| 94 |
+
Downloading pydantic-2.11.7-py3-none-any.whl (444 kB)
|
| 95 |
+
Downloading pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
|
| 96 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.0/2.0 MB 94.6 MB/s eta 0:00:00
|
| 97 |
+
Downloading python_multipart-0.0.20-py3-none-any.whl (24 kB)
|
| 98 |
+
Downloading ruff-0.12.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.3 MB)
|
| 99 |
+
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.3/13.3 MB 104.4 MB/s eta 0:00:00
|
| 100 |
+
Downloading safehttpx-0.1.6-py3-none-any.whl (8.7 kB)
|
| 101 |
+
Downloading semantic_version-2.10.0-py2.py3-none-any.whl (15 kB)
|
| 102 |
+
Downloading starlette-0.47.3-py3-none-any.whl (72 kB)
|
| 103 |
+
Downloading tomlkit-0.13.3-py3-none-any.whl (38 kB)
|
| 104 |
+
Downloading typer-0.17.4-py3-none-any.whl (46 kB)
|
| 105 |
+
Downloading uvicorn-0.35.0-py3-none-any.whl (66 kB)
|
| 106 |
+
Downloading ffmpy-0.6.1-py3-none-any.whl (5.5 kB)
|
| 107 |
+
Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)
|
| 108 |
+
Downloading annotated_types-0.7.0-py3-none-any.whl (13 kB)
|
| 109 |
+
Downloading shellingham-1.5.4-py2.py3-none-any.whl (9.8 kB)
|
| 110 |
+
Downloading typing_inspection-0.4.1-py3-none-any.whl (14 kB)
|
| 111 |
+
Downloading websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (181 kB)
|
| 112 |
+
Installing collected packages: pydub, brotli, websockets, uvicorn, typing-inspection, tomlkit, shellingham, semantic-version, ruff, python-multipart, pydantic-core, orjson, hf-xet, groovy, ffmpy, annotated-types, aiofiles, pydantic, huggingface-hub, typer, starlette, safehttpx, gradio-client, fastapi, gradio
|
| 113 |
+
Attempting uninstall: huggingface-hub
|
| 114 |
+
Found existing installation: huggingface-hub 0.26.1
|
| 115 |
+
Uninstalling huggingface-hub-0.26.1:
|
| 116 |
+
Successfully uninstalled huggingface-hub-0.26.1
|
| 117 |
+
Successfully installed aiofiles-24.1.0 annotated-types-0.7.0 brotli-1.1.0 fastapi-0.116.1 ffmpy-0.6.1 gradio-5.44.1 gradio-client-1.12.1 groovy-0.1.2 hf-xet-1.1.9 huggingface-hub-0.34.4 orjson-3.11.3 pydantic-2.11.7 pydantic-core-2.33.2 pydub-0.25.1 python-multipart-0.0.20 ruff-0.12.12 safehttpx-0.1.6 semantic-version-2.10.0 shellingham-1.5.4 starlette-0.47.3 tomlkit-0.13.3 typer-0.17.4 typing-inspection-0.4.1 uvicorn-0.35.0 websockets-15.0.1
|
=4.0.0
ADDED
|
File without changes
|
DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diamond CSGO AI Player - Deployment Guide
|
| 2 |
+
|
| 3 |
+
## 🚀 Deploying to Hugging Face Spaces
|
| 4 |
+
|
| 5 |
+
### Prerequisites
|
| 6 |
+
1. Hugging Face account
|
| 7 |
+
2. Model checkpoint files (`agent_epoch_00206.pt` or similar)
|
| 8 |
+
3. Git and Git LFS installed
|
| 9 |
+
|
| 10 |
+
### Step 1: Prepare Repository
|
| 11 |
+
|
| 12 |
+
1. **Clone/Fork this repository**
|
| 13 |
+
2. **Install Git LFS** (for large model files):
|
| 14 |
+
```bash
|
| 15 |
+
git lfs install
|
| 16 |
+
git lfs track "*.pt"
|
| 17 |
+
git add .gitattributes
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
3. **Add your model checkpoint**:
|
| 21 |
+
```bash
|
| 22 |
+
# Copy your trained model to the project root
|
| 23 |
+
cp /path/to/your/agent_epoch_00206.pt .
|
| 24 |
+
git add agent_epoch_00206.pt
|
| 25 |
+
git commit -m "Add trained model checkpoint"
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
### Step 2: Create Hugging Face Space
|
| 29 |
+
|
| 30 |
+
1. Go to [Hugging Face Spaces](https://huggingface.co/spaces)
|
| 31 |
+
2. Click "Create new Space"
|
| 32 |
+
3. Configure:
|
| 33 |
+
- **Space name**: `diamond-csgo-ai` (or your choice)
|
| 34 |
+
- **License**: Your preferred license
|
| 35 |
+
- **Space SDK**: `Docker`
|
| 36 |
+
- **Space hardware**:
|
| 37 |
+
- `CPU basic` (free) - for demo/testing
|
| 38 |
+
- `GPU T4 small` (paid) - for better performance
|
| 39 |
+
|
| 40 |
+
### Step 3: Upload Code
|
| 41 |
+
|
| 42 |
+
```bash
|
| 43 |
+
# Clone your new space
|
| 44 |
+
git clone https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
|
| 45 |
+
cd YOUR_SPACE_NAME
|
| 46 |
+
|
| 47 |
+
# Copy all project files
|
| 48 |
+
cp -r /path/to/diamond/* .
|
| 49 |
+
|
| 50 |
+
# Commit and push
|
| 51 |
+
git add .
|
| 52 |
+
git commit -m "Initial Diamond CSGO AI deployment"
|
| 53 |
+
git push
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
### Step 4: Configuration Files
|
| 57 |
+
|
| 58 |
+
Ensure these files are in your space root:
|
| 59 |
+
|
| 60 |
+
- `app.py` - Main FastAPI application
|
| 61 |
+
- `requirements.txt` - Python dependencies
|
| 62 |
+
- `Dockerfile` - Container configuration
|
| 63 |
+
- `README.md` - Space description
|
| 64 |
+
- `packages.txt` - System packages (if needed)
|
| 65 |
+
|
| 66 |
+
### Step 5: Model Setup
|
| 67 |
+
|
| 68 |
+
If your model is too large for Git (>100MB), use Git LFS or download from Hub:
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
# In your app.py, add model downloading:
|
| 72 |
+
from huggingface_hub import hf_hub_download
|
| 73 |
+
|
| 74 |
+
def download_model():
|
| 75 |
+
return hf_hub_download(
|
| 76 |
+
repo_id="YOUR_USERNAME/YOUR_MODEL_REPO",
|
| 77 |
+
filename="agent_epoch_00206.pt"
|
| 78 |
+
)
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## 🔧 Local Testing
|
| 82 |
+
|
| 83 |
+
Before deploying, test locally:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
# Install dependencies
|
| 87 |
+
pip install -r requirements.txt
|
| 88 |
+
|
| 89 |
+
# Run tests
|
| 90 |
+
python test_web_app.py
|
| 91 |
+
|
| 92 |
+
# Start local server
|
| 93 |
+
python run_web_demo.py
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
Visit `http://localhost:7860` to test the interface.
|
| 97 |
+
|
| 98 |
+
## ⚙️ Configuration Options
|
| 99 |
+
|
| 100 |
+
### Hardware Requirements
|
| 101 |
+
|
| 102 |
+
| Tier | CPU | RAM | GPU | Performance |
|
| 103 |
+
|------|-----|-----|-----|-------------|
|
| 104 |
+
| Free | 2 vCPU | 16GB | None | Basic demo |
|
| 105 |
+
| Basic GPU | 4 vCPU | 16GB | T4 | Good performance |
|
| 106 |
+
| Premium | 8 vCPU | 32GB | A10G | Best experience |
|
| 107 |
+
|
| 108 |
+
### Environment Variables
|
| 109 |
+
|
| 110 |
+
Add these in your Space settings:
|
| 111 |
+
- `CUDA_VISIBLE_DEVICES=""` (for CPU-only)
|
| 112 |
+
- `PYTHONPATH="/app/src:/app"`
|
| 113 |
+
|
| 114 |
+
## 🎮 Usage Instructions
|
| 115 |
+
|
| 116 |
+
Once deployed, users can:
|
| 117 |
+
|
| 118 |
+
1. Visit your Space URL
|
| 119 |
+
2. Click on the game canvas
|
| 120 |
+
3. Use keyboard controls:
|
| 121 |
+
- **WASD** - Movement
|
| 122 |
+
- **Space** - Jump
|
| 123 |
+
- **Arrow keys** - Camera
|
| 124 |
+
- **1,2,3** - Weapons
|
| 125 |
+
- **R** - Reload
|
| 126 |
+
- **M** - Switch Human/AI mode
|
| 127 |
+
|
| 128 |
+
## 🐛 Troubleshooting
|
| 129 |
+
|
| 130 |
+
### Common Issues
|
| 131 |
+
|
| 132 |
+
1. **Model not loading**:
|
| 133 |
+
- Check checkpoint file exists
|
| 134 |
+
- Verify file size (<5GB for Spaces)
|
| 135 |
+
- Use Git LFS for large files
|
| 136 |
+
|
| 137 |
+
2. **Import errors**:
|
| 138 |
+
- Check `requirements.txt` is complete
|
| 139 |
+
- Verify Python path in `Dockerfile`
|
| 140 |
+
|
| 141 |
+
3. **Performance issues**:
|
| 142 |
+
- Use GPU hardware tier
|
| 143 |
+
- Reduce model complexity
|
| 144 |
+
- Lower frame rate
|
| 145 |
+
|
| 146 |
+
4. **WebSocket connection failed**:
|
| 147 |
+
- Check firewall settings
|
| 148 |
+
- Verify port 7860 is accessible
|
| 149 |
+
- Try different browser
|
| 150 |
+
|
| 151 |
+
### Debug Mode
|
| 152 |
+
|
| 153 |
+
Enable debug logging:
|
| 154 |
+
```python
|
| 155 |
+
import logging
|
| 156 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## 📊 Monitoring
|
| 160 |
+
|
| 161 |
+
Monitor your Space:
|
| 162 |
+
- View logs in HF Spaces interface
|
| 163 |
+
- Check GPU utilization
|
| 164 |
+
- Monitor user sessions
|
| 165 |
+
|
| 166 |
+
## 🔄 Updates
|
| 167 |
+
|
| 168 |
+
To update your deployed Space:
|
| 169 |
+
```bash
|
| 170 |
+
git pull # Get latest changes
|
| 171 |
+
git add .
|
| 172 |
+
git commit -m "Update to latest version"
|
| 173 |
+
git push # Automatically redeploys
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
## 💡 Tips for Success
|
| 177 |
+
|
| 178 |
+
1. **Start with CPU tier** to test basic functionality
|
| 179 |
+
2. **Use smaller models** for faster loading
|
| 180 |
+
3. **Test thoroughly locally** before deploying
|
| 181 |
+
4. **Monitor resource usage** to optimize costs
|
| 182 |
+
5. **Add usage instructions** in your Space README
|
| 183 |
+
|
| 184 |
+
## 🎯 Next Steps
|
| 185 |
+
|
| 186 |
+
After successful deployment:
|
| 187 |
+
- Share your Space with the community
|
| 188 |
+
- Collect user feedback
|
| 189 |
+
- Iterate on the interface
|
| 190 |
+
- Add new features like replay saving
|
| 191 |
+
- Consider multi-user support
|
| 192 |
+
|
| 193 |
+
Happy deploying! 🚀
|
HF_SPACES_CACHE_FIX.md
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🔧 HF Spaces Cache Permission Fix
|
| 2 |
+
|
| 3 |
+
## ❌ **Problem:**
|
| 4 |
+
```
|
| 5 |
+
ERROR:app:Failed to load model: [Errno 13] Permission denied: '/.cache'
|
| 6 |
+
```
|
| 7 |
+
|
| 8 |
+
HF Spaces containers can't write to the root `/.cache` directory, causing model downloads to fail.
|
| 9 |
+
|
| 10 |
+
## ✅ **Solution Applied:**
|
| 11 |
+
|
| 12 |
+
### 1. **Fixed Cache Directory in app.py**
|
| 13 |
+
- ✅ Set custom cache directory: `/tmp/torch_cache`
|
| 14 |
+
- ✅ Added proper permissions handling
|
| 15 |
+
- ✅ Fixed OMP_NUM_THREADS environment variable issue
|
| 16 |
+
|
| 17 |
+
### 2. **Updated Dockerfile**
|
| 18 |
+
- ✅ Set environment variables to use `/tmp` for caches
|
| 19 |
+
- ✅ Pre-create cache directories
|
| 20 |
+
- ✅ Fixed OMP_NUM_THREADS value
|
| 21 |
+
|
| 22 |
+
### 3. **Key Changes Made:**
|
| 23 |
+
|
| 24 |
+
#### **app.py Changes:**
|
| 25 |
+
```python
|
| 26 |
+
# Fixed cache directory for torch.hub
|
| 27 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 28 |
+
model_url,
|
| 29 |
+
map_location=device,
|
| 30 |
+
model_dir=cache_dir, # Custom cache dir
|
| 31 |
+
check_hash=False # Skip hash check for speed
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Fixed environment variables
|
| 35 |
+
os.environ["OMP_NUM_THREADS"] = "2" # Valid integer
|
| 36 |
+
os.environ["TORCH_HOME"] = "/tmp/torch"
|
| 37 |
+
os.environ["HF_HOME"] = "/tmp/huggingface"
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
#### **Dockerfile Changes:**
|
| 41 |
+
```dockerfile
|
| 42 |
+
ENV OMP_NUM_THREADS=2
|
| 43 |
+
ENV TORCH_HOME=/tmp/torch
|
| 44 |
+
ENV HF_HOME=/tmp/huggingface
|
| 45 |
+
ENV TRANSFORMERS_CACHE=/tmp/transformers
|
| 46 |
+
|
| 47 |
+
RUN mkdir -p /tmp/torch /tmp/huggingface /tmp/transformers
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## 🚀 **Expected Results:**
|
| 51 |
+
- ✅ No more "Permission denied: /.cache" errors
|
| 52 |
+
- ✅ No more "Invalid value for environment variable OMP_NUM_THREADS" warnings
|
| 53 |
+
- ✅ Model downloads work properly on HF Spaces
|
| 54 |
+
- ✅ App starts correctly and clicking works
|
| 55 |
+
|
| 56 |
+
## 📋 **To Deploy:**
|
| 57 |
+
1. **Commit the changes**: `git add . && git commit -m "Fix HF Spaces cache permissions"`
|
| 58 |
+
2. **Push to HF Spaces**: `git push`
|
| 59 |
+
3. **Monitor logs**: Check that download succeeds without permission errors
|
| 60 |
+
4. **Test**: Click the game area - should work now!
|
| 61 |
+
|
| 62 |
+
## 🔍 **Log Messages to Look For:**
|
| 63 |
+
### ✅ **Success:**
|
| 64 |
+
```
|
| 65 |
+
INFO:app:Loading state dict from https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt
|
| 66 |
+
INFO:app:State dict loaded, applying to agent...
|
| 67 |
+
INFO:app:Model has actor_critic weights: False
|
| 68 |
+
INFO:app:Actor-critic model exists but has no trained weights - using dummy mode!
|
| 69 |
+
INFO:app:WebPlayEnv set to human control mode (no trained weights)
|
| 70 |
+
INFO:app:Models initialized successfully!
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### ❌ **If Still Failing:**
|
| 74 |
+
```
|
| 75 |
+
ERROR:app:Failed to load model: [Errno 13] Permission denied
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
## 🎯 **What This Fixes:**
|
| 79 |
+
1. ✅ **Model downloading** - now uses writable `/tmp` directory
|
| 80 |
+
2. ✅ **Environment variables** - OMP_NUM_THREADS is valid
|
| 81 |
+
3. ✅ **Game clicking** - works after model loads (even without actor_critic)
|
| 82 |
+
4. ✅ **HF Spaces compatibility** - follows container best practices
|
| 83 |
+
|
| 84 |
+
The app should now work perfectly on HF Spaces! 🎉
|
HF_SPACES_DEPLOYMENT_GUIDE.md
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 Hugging Face Spaces Deployment - Troubleshooting Guide
|
| 2 |
+
|
| 3 |
+
## ✅ **Your Local Fix Applied**
|
| 4 |
+
Great news! The core issue has been resolved locally. The problem was that the downloaded model doesn't contain `actor_critic` weights, but the code assumed it did. This caused a `NoneType` error when clicking to start the game.
|
| 5 |
+
|
| 6 |
+
**Fixed**: The app now properly detects when `actor_critic` weights are missing and falls back to human control mode instead of crashing.
|
| 7 |
+
|
| 8 |
+
## 🔍 **Potential HF Spaces Issues & Solutions**
|
| 9 |
+
|
| 10 |
+
### **Issue 1: Model Download Timeouts** ⏰
|
| 11 |
+
|
| 12 |
+
**Symptoms:**
|
| 13 |
+
- "Model loading timed out" message
|
| 14 |
+
- App shows loading forever
|
| 15 |
+
- Click doesn't start the game
|
| 16 |
+
|
| 17 |
+
**Root Cause:** HF Spaces network can be slower, 5-minute timeout may not be enough.
|
| 18 |
+
|
| 19 |
+
**Solution:**
|
| 20 |
+
```python
|
| 21 |
+
# In app.py, update the timeout in _load_model_from_url_async():
|
| 22 |
+
success = await asyncio.wait_for(future, timeout=900.0) # 15 minutes instead of 5
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### **Issue 2: Memory Limitations** 💾
|
| 26 |
+
|
| 27 |
+
**Symptoms:**
|
| 28 |
+
- App crashes during model loading
|
| 29 |
+
- "Out of memory" errors in logs
|
| 30 |
+
- Models load but inference fails
|
| 31 |
+
|
| 32 |
+
**Root Cause:** HF Spaces free tier has only 16GB RAM.
|
| 33 |
+
|
| 34 |
+
**Quick Fix:** Force CPU-only mode
|
| 35 |
+
```python
|
| 36 |
+
# Add at the top of app.py
|
| 37 |
+
import os
|
| 38 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Force CPU mode for HF Spaces
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
**Better Solution:** Add memory management
|
| 42 |
+
```python
|
| 43 |
+
# Add memory cleanup after model loading
|
| 44 |
+
import gc
|
| 45 |
+
gc.collect()
|
| 46 |
+
if torch.cuda.is_available():
|
| 47 |
+
torch.cuda.empty_cache()
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### **Issue 3: WebSocket Connection Failures** 🔌
|
| 51 |
+
|
| 52 |
+
**Symptoms:**
|
| 53 |
+
- "Connection Error" or "Disconnected" status
|
| 54 |
+
- Click works but no response
|
| 55 |
+
- Frequent reconnections
|
| 56 |
+
|
| 57 |
+
**Root Cause:** HF Spaces proxy/domain restrictions.
|
| 58 |
+
|
| 59 |
+
**Solution:** Update the WebSocket connection code in the HTML template:
|
| 60 |
+
```javascript
|
| 61 |
+
// Replace the connectWebSocket function in app.py HTML
|
| 62 |
+
function connectWebSocket() {
|
| 63 |
+
const isHFSpaces = window.location.hostname.includes('huggingface.co');
|
| 64 |
+
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 65 |
+
const wsUrl = `${protocol}//${window.location.host}/ws`;
|
| 66 |
+
|
| 67 |
+
ws = new WebSocket(wsUrl);
|
| 68 |
+
|
| 69 |
+
// Longer timeout for HF Spaces
|
| 70 |
+
const timeout = isHFSpaces ? 30000 : 10000;
|
| 71 |
+
|
| 72 |
+
const connectTimer = setTimeout(() => {
|
| 73 |
+
if (ws.readyState !== WebSocket.OPEN) {
|
| 74 |
+
ws.close();
|
| 75 |
+
setTimeout(connectWebSocket, 5000); // Retry after 5s
|
| 76 |
+
}
|
| 77 |
+
}, timeout);
|
| 78 |
+
|
| 79 |
+
ws.onopen = function(event) {
|
| 80 |
+
clearTimeout(connectTimer);
|
| 81 |
+
statusEl.textContent = 'Connected';
|
| 82 |
+
statusEl.style.color = '#00ff00';
|
| 83 |
+
|
| 84 |
+
// Re-send start if user already clicked
|
| 85 |
+
if (gameStarted && !gamePlaying) {
|
| 86 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 87 |
+
}
|
| 88 |
+
};
|
| 89 |
+
}
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### **Issue 4: Actor-Critic Model Missing** 🧠
|
| 93 |
+
|
| 94 |
+
**Already Fixed!** ✅ The app now handles this gracefully:
|
| 95 |
+
- Detects missing `actor_critic` weights
|
| 96 |
+
- Falls back to human control mode
|
| 97 |
+
- Shows proper warning messages
|
| 98 |
+
- Game still works (user can control manually)
|
| 99 |
+
|
| 100 |
+
### **Issue 5: Dockerfile Optimization** 🐳
|
| 101 |
+
|
| 102 |
+
**Update your Dockerfile for HF Spaces:**
|
| 103 |
+
```dockerfile
|
| 104 |
+
# Add these optimizations
|
| 105 |
+
ENV SHM_SIZE=2g
|
| 106 |
+
ENV PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512
|
| 107 |
+
ENV OMP_NUM_THREADS=4
|
| 108 |
+
|
| 109 |
+
# Add health check
|
| 110 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s \
|
| 111 |
+
CMD curl --fail http://localhost:7860/health || exit 1
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## 🚀 **Quick Deployment Checklist**
|
| 115 |
+
|
| 116 |
+
### **Before Deploying:**
|
| 117 |
+
1. ✅ **Test locally with conda**: `conda activate diamond && python run_web_demo.py`
|
| 118 |
+
2. ✅ **Verify the fix works**: Click should now work (even without actor_critic weights)
|
| 119 |
+
3. ✅ **Check model download**: Test internet connectivity for HF model URL
|
| 120 |
+
|
| 121 |
+
### **For HF Spaces Deployment:**
|
| 122 |
+
|
| 123 |
+
1. **Update timeout values:**
|
| 124 |
+
```python
|
| 125 |
+
# In app.py line ~153
|
| 126 |
+
success = await asyncio.wait_for(future, timeout=900.0) # 15 min
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
2. **Add health check endpoint:**
|
| 130 |
+
```python
|
| 131 |
+
@app.get("/health")
|
| 132 |
+
async def health_check():
|
| 133 |
+
return {
|
| 134 |
+
"status": "healthy",
|
| 135 |
+
"models_ready": game_engine.models_ready,
|
| 136 |
+
"actor_critic_loaded": game_engine.actor_critic_loaded
|
| 137 |
+
}
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
3. **Force CPU mode for free tier:**
|
| 141 |
+
```python
|
| 142 |
+
# Add at app.py startup
|
| 143 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
4. **Update Dockerfile** with the optimizations above
|
| 147 |
+
|
| 148 |
+
5. **Test WebSocket connection** - add the improved connection handling
|
| 149 |
+
|
| 150 |
+
## 🔧 **Debugging on HF Spaces**
|
| 151 |
+
|
| 152 |
+
### **Check Logs:**
|
| 153 |
+
1. Go to your Space page on HuggingFace
|
| 154 |
+
2. Click "Logs" tab
|
| 155 |
+
3. Look for these messages:
|
| 156 |
+
- ✅ `"Actor-critic model exists but has no trained weights - using dummy mode!"`
|
| 157 |
+
- ✅ `"WebPlayEnv set to human control mode"`
|
| 158 |
+
- ❌ `"Model loading timed out"`
|
| 159 |
+
- ❌ `"WebSocket error"`
|
| 160 |
+
|
| 161 |
+
### **Test Health Endpoint:**
|
| 162 |
+
- Visit: `https://your-space.hf.space/health`
|
| 163 |
+
- Should return JSON with status info
|
| 164 |
+
|
| 165 |
+
### **Browser Console:**
|
| 166 |
+
- Open Developer Tools (F12)
|
| 167 |
+
- Check for WebSocket connection errors
|
| 168 |
+
- Look for JavaScript errors during click
|
| 169 |
+
|
| 170 |
+
## 🎯 **Expected Behavior After Fixes**
|
| 171 |
+
|
| 172 |
+
1. **App loads** → Shows loading progress bar
|
| 173 |
+
2. **Models initialize** → Either loads actor_critic OR shows "no trained weights"
|
| 174 |
+
3. **User clicks game area** → Game starts immediately (no hanging)
|
| 175 |
+
4. **If actor_critic missing** → User gets manual control (still playable!)
|
| 176 |
+
5. **If actor_critic loaded** → AI takes control automatically
|
| 177 |
+
|
| 178 |
+
## 🆘 **If Issues Persist**
|
| 179 |
+
|
| 180 |
+
**Quick Diagnostic:**
|
| 181 |
+
```python
|
| 182 |
+
# Add this test endpoint to app.py
|
| 183 |
+
@app.get("/debug")
|
| 184 |
+
async def debug_info():
|
| 185 |
+
return {
|
| 186 |
+
"models_ready": game_engine.models_ready,
|
| 187 |
+
"actor_critic_loaded": game_engine.actor_critic_loaded,
|
| 188 |
+
"loading_status": game_engine.loading_status,
|
| 189 |
+
"game_started": game_engine.game_started,
|
| 190 |
+
"obs_shape": str(game_engine.obs.shape) if game_engine.obs is not None else "None",
|
| 191 |
+
"connected_clients": len(connected_clients),
|
| 192 |
+
"cuda_available": torch.cuda.is_available(),
|
| 193 |
+
"device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
|
| 194 |
+
}
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
Visit `/debug` endpoint to see the current state.
|
| 198 |
+
|
| 199 |
+
**Most Common Issue:** If clicking still doesn't work on HF Spaces, it's usually the WebSocket connection. Update the connection handling as described above.
|
| 200 |
+
|
| 201 |
+
The core model/clicking issue is now fixed - the remaining items are deployment optimizations for HF Spaces' specific environment! 🎉
|
HF_SPACES_GPU_FIX.md
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 HF Spaces GPU Acceleration Fix
|
| 2 |
+
|
| 3 |
+
## ❌ **Problem Identified:**
|
| 4 |
+
Your T4 GPU wasn't being used because:
|
| 5 |
+
|
| 6 |
+
1. **Dockerfile disabled CUDA**: `ENV CUDA_VISIBLE_DEVICES=""`
|
| 7 |
+
2. **Environment variable issues**: OMP_NUM_THREADS causing warnings
|
| 8 |
+
3. **App running on CPU**: Despite having T4 GPU hardware
|
| 9 |
+
|
| 10 |
+
## ✅ **Complete Fix Applied:**
|
| 11 |
+
|
| 12 |
+
### **1. Dockerfile Changes**
|
| 13 |
+
```dockerfile
|
| 14 |
+
# REMOVED this line that was disabling GPU:
|
| 15 |
+
# ENV CUDA_VISIBLE_DEVICES=""
|
| 16 |
+
|
| 17 |
+
# Fixed environment variables:
|
| 18 |
+
ENV OMP_NUM_THREADS=2
|
| 19 |
+
ENV MKL_NUM_THREADS=2
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
### **2. App.py Improvements**
|
| 23 |
+
- ✅ **Fixed OMP_NUM_THREADS early**: Set before any imports
|
| 24 |
+
- ✅ **Improved GPU detection**: Better logging and detection
|
| 25 |
+
- ✅ **Cache directories**: Moved setup to very beginning
|
| 26 |
+
|
| 27 |
+
### **3. Environment Variable Priority**
|
| 28 |
+
Environment variables are now set in this order:
|
| 29 |
+
1. **Dockerfile** - Base container settings
|
| 30 |
+
2. **app.py top** - Python-level fixes (before imports)
|
| 31 |
+
3. **HF Spaces** - Runtime overrides
|
| 32 |
+
|
| 33 |
+
## 🎯 **Expected Results After Fix:**
|
| 34 |
+
|
| 35 |
+
### **Before (CPU mode):**
|
| 36 |
+
```
|
| 37 |
+
INFO:app:Using device: cpu
|
| 38 |
+
INFO:app:CUDA not available, using CPU - this is normal for HF Spaces free tier
|
| 39 |
+
CPU 56%
|
| 40 |
+
GPU 0%
|
| 41 |
+
GPU VRAM 0/16 GB
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
### **After (GPU mode):**
|
| 45 |
+
```
|
| 46 |
+
INFO:app:Using device: cuda
|
| 47 |
+
INFO:app:CUDA available: True
|
| 48 |
+
INFO:app:GPU device count: 1
|
| 49 |
+
INFO:app:Current GPU: Tesla T4
|
| 50 |
+
INFO:app:GPU memory: 15.1 GB
|
| 51 |
+
INFO:app:🚀 GPU acceleration enabled!
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### **Performance Improvement:**
|
| 55 |
+
- **CPU usage**: Should drop to ~20-30%
|
| 56 |
+
- **GPU usage**: Should show 10-50% during AI inference
|
| 57 |
+
- **GPU VRAM**: Should show 2-4GB usage
|
| 58 |
+
- **AI FPS**: Should increase from ~2 FPS to 10+ FPS
|
| 59 |
+
|
| 60 |
+
## 📋 **Deployment Steps:**
|
| 61 |
+
|
| 62 |
+
1. **Commit and push changes:**
|
| 63 |
+
```bash
|
| 64 |
+
git add .
|
| 65 |
+
git commit -m "Enable GPU acceleration for HF Spaces T4"
|
| 66 |
+
git push
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
2. **Wait for rebuild** (HF Spaces will restart automatically)
|
| 70 |
+
|
| 71 |
+
3. **Check new logs** for GPU detection:
|
| 72 |
+
```
|
| 73 |
+
INFO:app:🚀 GPU acceleration enabled!
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
4. **Monitor system stats:**
|
| 77 |
+
- GPU usage should now show activity
|
| 78 |
+
- GPU VRAM should show memory allocation
|
| 79 |
+
- Overall performance should be much faster
|
| 80 |
+
|
| 81 |
+
## 🔍 **Debugging Commands:**
|
| 82 |
+
|
| 83 |
+
### **Check CUDA in container:**
|
| 84 |
+
```python
|
| 85 |
+
import torch
|
| 86 |
+
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 87 |
+
print(f"GPU count: {torch.cuda.device_count()}")
|
| 88 |
+
if torch.cuda.is_available():
|
| 89 |
+
print(f"GPU name: {torch.cuda.get_device_name(0)}")
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### **Check environment variables:**
|
| 93 |
+
```python
|
| 94 |
+
import os
|
| 95 |
+
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
|
| 96 |
+
print(f"OMP_NUM_THREADS: {os.environ.get('OMP_NUM_THREADS')}")
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
## 🚨 **If GPU Still Not Working:**
|
| 100 |
+
|
| 101 |
+
### **1. Verify HF Spaces Hardware:**
|
| 102 |
+
- Check your Space settings
|
| 103 |
+
- Ensure "T4 small" or "T4 medium" is selected
|
| 104 |
+
- Free tier doesn't have GPU access
|
| 105 |
+
|
| 106 |
+
### **2. Check Container Logs:**
|
| 107 |
+
Look for these messages:
|
| 108 |
+
- ✅ `"🚀 GPU acceleration enabled!"`
|
| 109 |
+
- ❌ `"CUDA not available"`
|
| 110 |
+
|
| 111 |
+
### **3. Alternative: Force GPU Detection**
|
| 112 |
+
If needed, add this debug code to app.py:
|
| 113 |
+
```python
|
| 114 |
+
# Debug GPU detection
|
| 115 |
+
logger.info(f"Environment CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'Not set')}")
|
| 116 |
+
logger.info(f"PyTorch CUDA compiled: {torch.version.cuda}")
|
| 117 |
+
logger.info(f"PyTorch version: {torch.__version__}")
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
## ⚡ **Performance Optimization Tips:**
|
| 121 |
+
|
| 122 |
+
### **For T4 GPU:**
|
| 123 |
+
1. **Enable model compilation** (optional):
|
| 124 |
+
```bash
|
| 125 |
+
# Set environment variable in HF Spaces settings:
|
| 126 |
+
ENABLE_TORCH_COMPILE=1
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
2. **Increase AI FPS** (if needed):
|
| 130 |
+
```python
|
| 131 |
+
# In app.py, line ~86:
|
| 132 |
+
self.ai_fps = 15 # Increase from 10 to 15
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
3. **Monitor GPU memory**:
|
| 136 |
+
- T4 has 16GB VRAM
|
| 137 |
+
- App should use 2-4GB
|
| 138 |
+
- Leave headroom for other processes
|
| 139 |
+
|
| 140 |
+
## 🎮 **Expected User Experience:**
|
| 141 |
+
|
| 142 |
+
1. **Faster loading**: Models load to GPU memory
|
| 143 |
+
2. **Responsive gameplay**: AI inference runs at 10+ FPS
|
| 144 |
+
3. **Smoother visuals**: Display updates without lag
|
| 145 |
+
4. **Better AI performance**: GPU acceleration improves model inference
|
| 146 |
+
|
| 147 |
+
Your HF Spaces deployment should now fully utilize the T4 GPU! 🚀
|
| 148 |
+
|
| 149 |
+
## 📊 **Monitor These Metrics:**
|
| 150 |
+
- **GPU Utilization**: 10-50% during gameplay
|
| 151 |
+
- **GPU Memory**: 2-4GB allocated
|
| 152 |
+
- **AI FPS**: 10-15 FPS (displayed in web interface)
|
| 153 |
+
- **CPU Usage**: Should decrease to 20-30%
|
| 154 |
+
|
| 155 |
+
The game should feel much more responsive now! 🎉
|
__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (37.2 kB). View file
|
|
|
__pycache__/config_web.cpython-310.pyc
ADDED
|
Binary file (5.48 kB). View file
|
|
|
__pycache__/debug_app.cpython-310.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
__pycache__/hf_space_config.cpython-310.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
app.py.backup
ADDED
|
@@ -0,0 +1,1114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Web-based Diamond CSGO AI Player for Hugging Face Spaces
|
| 3 |
+
Uses FastAPI + WebSocket for real-time keyboard input and game streaming
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Fix environment variables FIRST, before any other imports
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
# Fix OMP_NUM_THREADS immediately (before PyTorch/NumPy imports)
|
| 11 |
+
if "OMP_NUM_THREADS" not in os.environ or not os.environ.get("OMP_NUM_THREADS", "").isdigit():
|
| 12 |
+
os.environ["OMP_NUM_THREADS"] = "2"
|
| 13 |
+
|
| 14 |
+
# Set up cache directories immediately
|
| 15 |
+
temp_dir = tempfile.gettempdir()
|
| 16 |
+
os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
|
| 17 |
+
os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
|
| 18 |
+
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
|
| 19 |
+
|
| 20 |
+
# Create cache directories
|
| 21 |
+
for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
|
| 22 |
+
cache_path = os.environ[cache_var]
|
| 23 |
+
os.makedirs(cache_path, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
import asyncio
|
| 26 |
+
import base64
|
| 27 |
+
import io
|
| 28 |
+
import json
|
| 29 |
+
import logging
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Dict, List, Optional, Set
|
| 32 |
+
|
| 33 |
+
import cv2
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
import uvicorn
|
| 37 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 38 |
+
from fastapi.responses import HTMLResponse
|
| 39 |
+
from fastapi.staticfiles import StaticFiles
|
| 40 |
+
from hydra import compose, initialize
|
| 41 |
+
from hydra.utils import instantiate
|
| 42 |
+
from omegaconf import DictConfig, OmegaConf
|
| 43 |
+
from PIL import Image
|
| 44 |
+
|
| 45 |
+
# Import your modules
|
| 46 |
+
import sys
|
| 47 |
+
from pathlib import Path
|
| 48 |
+
|
| 49 |
+
# Add project root to path for src package imports
|
| 50 |
+
project_root = Path(__file__).parent
|
| 51 |
+
if str(project_root) not in sys.path:
|
| 52 |
+
sys.path.insert(0, str(project_root))
|
| 53 |
+
|
| 54 |
+
from src.agent import Agent
|
| 55 |
+
from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
|
| 56 |
+
from src.envs import WorldModelEnv
|
| 57 |
+
from src.game.web_play_env import WebPlayEnv
|
| 58 |
+
from src.utils import extract_state_dict
|
| 59 |
+
from config_web import web_config
|
| 60 |
+
|
| 61 |
+
# Configure logging
|
| 62 |
+
logging.basicConfig(level=logging.INFO)
|
| 63 |
+
logger = logging.getLogger(__name__)
|
| 64 |
+
|
| 65 |
+
# Global variables
|
| 66 |
+
app = FastAPI(title="Diamond CSGO AI Player")
|
| 67 |
+
|
| 68 |
+
# Set safe defaults for headless CI/Spaces environments
|
| 69 |
+
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
|
| 70 |
+
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
|
| 71 |
+
os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
|
| 72 |
+
|
| 73 |
+
# Environment variables already set at top of file
|
| 74 |
+
connected_clients: Set[WebSocket] = set()
|
| 75 |
+
|
| 76 |
+
class WebKeyMap:
|
| 77 |
+
"""Map web key codes to pygame-like keys for CSGO actions"""
|
| 78 |
+
WEB_TO_CSGO = {
|
| 79 |
+
'KeyW': 'w',
|
| 80 |
+
'KeyA': 'a',
|
| 81 |
+
'KeyS': 's',
|
| 82 |
+
'KeyD': 'd',
|
| 83 |
+
'Space': 'space',
|
| 84 |
+
'ControlLeft': 'left ctrl',
|
| 85 |
+
'ShiftLeft': 'left shift',
|
| 86 |
+
'Digit1': '1',
|
| 87 |
+
'Digit2': '2',
|
| 88 |
+
'Digit3': '3',
|
| 89 |
+
'KeyR': 'r',
|
| 90 |
+
'ArrowUp': 'camera_up',
|
| 91 |
+
'ArrowDown': 'camera_down',
|
| 92 |
+
'ArrowLeft': 'camera_left',
|
| 93 |
+
'ArrowRight': 'camera_right'
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
class WebGameEngine:
|
| 97 |
+
"""Web-compatible game engine that replaces pygame functionality"""
|
| 98 |
+
|
| 99 |
+
def __init__(self):
|
| 100 |
+
self.play_env: Optional[WebPlayEnv] = None
|
| 101 |
+
self.obs = None
|
| 102 |
+
self.running = False
|
| 103 |
+
self.game_started = False
|
| 104 |
+
self.fps = 30 # Display FPS
|
| 105 |
+
self.ai_fps = 15 # AI inference FPS (matching standalone play.py performance)
|
| 106 |
+
self.frame_count = 0
|
| 107 |
+
self.ai_frame_count = 0
|
| 108 |
+
self.last_ai_time = 0
|
| 109 |
+
self.start_time = 0 # Track when AI started for proper FPS calculation
|
| 110 |
+
self.last_frame_send_time = 0 # Track frame sending for optimization
|
| 111 |
+
self.web_fps = 20 # Web display FPS (lower than AI FPS to reduce network overhead)
|
| 112 |
+
self.pressed_keys: Set[str] = set()
|
| 113 |
+
self.mouse_x = 0
|
| 114 |
+
self.mouse_y = 0
|
| 115 |
+
self.l_click = False
|
| 116 |
+
self.r_click = False
|
| 117 |
+
self.should_reset = False
|
| 118 |
+
self.cached_obs = None # Cache last observation for frame skipping
|
| 119 |
+
self.first_inference_done = False # Track if first inference completed
|
| 120 |
+
self.models_ready = False # Track if models are loaded
|
| 121 |
+
self.download_progress = 0 # Track download progress (0-100)
|
| 122 |
+
self.loading_status = "Initializing..." # Loading status message
|
| 123 |
+
self.actor_critic_loaded = False # Track if actor_critic was loaded with trained weights
|
| 124 |
+
import time
|
| 125 |
+
self.time_module = time
|
| 126 |
+
|
| 127 |
+
async def _load_model_from_url_async(self, agent, device):
|
| 128 |
+
"""Load model from URL using torch.hub (HF Spaces compatible)"""
|
| 129 |
+
import asyncio
|
| 130 |
+
import concurrent.futures
|
| 131 |
+
|
| 132 |
+
def load_model_weights():
|
| 133 |
+
"""Load model weights in thread pool to avoid blocking"""
|
| 134 |
+
try:
|
| 135 |
+
logger.info("Loading model using torch.hub.load_state_dict_from_url...")
|
| 136 |
+
self.loading_status = "Downloading model..."
|
| 137 |
+
self.download_progress = 10
|
| 138 |
+
|
| 139 |
+
model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
|
| 140 |
+
|
| 141 |
+
# Use torch.hub to download and load state dict with custom cache dir
|
| 142 |
+
logger.info(f"Loading state dict from {model_url}")
|
| 143 |
+
|
| 144 |
+
# Set custom cache directory that we have write permissions for
|
| 145 |
+
cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
|
| 146 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 147 |
+
|
| 148 |
+
# Use torch.hub with custom cache directory
|
| 149 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 150 |
+
model_url,
|
| 151 |
+
map_location=device,
|
| 152 |
+
model_dir=cache_dir,
|
| 153 |
+
check_hash=False # Skip hash check for faster loading
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.download_progress = 60
|
| 157 |
+
self.loading_status = "Loading model weights into agent..."
|
| 158 |
+
logger.info("State dict loaded, applying to agent...")
|
| 159 |
+
|
| 160 |
+
# Load state dict into agent, but skip actor_critic if not present
|
| 161 |
+
has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
|
| 162 |
+
logger.info(f"Model has actor_critic weights: {has_actor_critic}")
|
| 163 |
+
agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
|
| 164 |
+
|
| 165 |
+
# Track if actor_critic was actually loaded with trained weights
|
| 166 |
+
self.actor_critic_loaded = has_actor_critic
|
| 167 |
+
|
| 168 |
+
self.download_progress = 100
|
| 169 |
+
self.loading_status = "Model loaded successfully!"
|
| 170 |
+
logger.info("All model weights loaded successfully!")
|
| 171 |
+
return True
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Failed to load model: {e}")
|
| 175 |
+
import traceback
|
| 176 |
+
traceback.print_exc()
|
| 177 |
+
return False
|
| 178 |
+
|
| 179 |
+
# Run in thread pool to avoid blocking with timeout
|
| 180 |
+
loop = asyncio.get_event_loop()
|
| 181 |
+
try:
|
| 182 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 183 |
+
# Add timeout for model loading (5 minutes max)
|
| 184 |
+
future = loop.run_in_executor(executor, load_model_weights)
|
| 185 |
+
success = await asyncio.wait_for(future, timeout=300.0) # 5 minute timeout
|
| 186 |
+
return success
|
| 187 |
+
except asyncio.TimeoutError:
|
| 188 |
+
logger.error("Model loading timed out after 5 minutes")
|
| 189 |
+
self.loading_status = "Model loading timed out - using dummy mode"
|
| 190 |
+
return False
|
| 191 |
+
except Exception as e:
|
| 192 |
+
logger.error(f"Error in model loading executor: {e}")
|
| 193 |
+
self.loading_status = f"Model loading error: {str(e)[:50]}..."
|
| 194 |
+
return False
|
| 195 |
+
|
| 196 |
+
async def initialize_models(self):
|
| 197 |
+
"""Initialize the AI models and environment"""
|
| 198 |
+
try:
|
| 199 |
+
import torch
|
| 200 |
+
logger.info("Initializing models...")
|
| 201 |
+
|
| 202 |
+
# Setup environment and paths
|
| 203 |
+
web_config.setup_environment_variables()
|
| 204 |
+
web_config.create_default_configs()
|
| 205 |
+
|
| 206 |
+
config_path = web_config.get_config_path()
|
| 207 |
+
logger.info(f"Using config path: {config_path}")
|
| 208 |
+
|
| 209 |
+
# For Hydra, use relative path from app.py location
|
| 210 |
+
# Since app.py is in project root, config is simply "config"
|
| 211 |
+
relative_config_path = "config"
|
| 212 |
+
logger.info(f"Relative config path: {relative_config_path}")
|
| 213 |
+
|
| 214 |
+
with initialize(version_base="1.3", config_path=relative_config_path):
|
| 215 |
+
cfg = compose(config_name="trainer")
|
| 216 |
+
|
| 217 |
+
# Override config for deployment
|
| 218 |
+
cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
|
| 219 |
+
cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
|
| 220 |
+
|
| 221 |
+
# Use GPU if available, otherwise fall back to CPU
|
| 222 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 223 |
+
logger.info(f"Using device: {device}")
|
| 224 |
+
|
| 225 |
+
# Log GPU availability and CUDA info for debugging
|
| 226 |
+
if torch.cuda.is_available():
|
| 227 |
+
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
| 228 |
+
logger.info(f"GPU device count: {torch.cuda.device_count()}")
|
| 229 |
+
logger.info(f"Current GPU: {torch.cuda.get_device_name(0)}")
|
| 230 |
+
logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
| 231 |
+
logger.info("🚀 GPU acceleration enabled!")
|
| 232 |
+
else:
|
| 233 |
+
logger.info("CUDA not available, using CPU mode")
|
| 234 |
+
|
| 235 |
+
# Initialize agent first
|
| 236 |
+
num_actions = cfg.env.num_actions
|
| 237 |
+
agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
|
| 238 |
+
|
| 239 |
+
# Get spawn directory
|
| 240 |
+
spawn_dir = web_config.get_spawn_dir()
|
| 241 |
+
|
| 242 |
+
# Try to load checkpoint (remote first, then local, then dummy mode)
|
| 243 |
+
try:
|
| 244 |
+
# First try to load from Hugging Face Hub using torch.hub
|
| 245 |
+
logger.info("Loading model from Hugging Face Hub with torch.hub...")
|
| 246 |
+
|
| 247 |
+
success = await self._load_model_from_url_async(agent, device)
|
| 248 |
+
|
| 249 |
+
if success:
|
| 250 |
+
logger.info("Successfully loaded checkpoint from HF Hub")
|
| 251 |
+
else:
|
| 252 |
+
# Fallback to local checkpoint if available
|
| 253 |
+
logger.error("Failed to load from HF Hub! Check the detailed error above.")
|
| 254 |
+
checkpoint_path = web_config.get_checkpoint_path()
|
| 255 |
+
if checkpoint_path.exists():
|
| 256 |
+
logger.info(f"Loading local checkpoint: {checkpoint_path}")
|
| 257 |
+
self.loading_status = "Loading local checkpoint..."
|
| 258 |
+
agent.load(checkpoint_path)
|
| 259 |
+
logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
|
| 260 |
+
# Assume local checkpoint has actor_critic weights (may need verification)
|
| 261 |
+
self.actor_critic_loaded = True
|
| 262 |
+
else:
|
| 263 |
+
logger.error(f"No local checkpoint found at: {checkpoint_path}")
|
| 264 |
+
raise FileNotFoundError("No model checkpoint available (local or remote)")
|
| 265 |
+
|
| 266 |
+
except Exception as e:
|
| 267 |
+
logger.error(f"Failed to load any checkpoint: {e}")
|
| 268 |
+
self._init_dummy_mode()
|
| 269 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 270 |
+
return True
|
| 271 |
+
|
| 272 |
+
# Initialize world model environment
|
| 273 |
+
try:
|
| 274 |
+
sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
|
| 275 |
+
if agent.upsampler is not None:
|
| 276 |
+
sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
|
| 277 |
+
wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
|
| 278 |
+
wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
|
| 279 |
+
spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
|
| 280 |
+
|
| 281 |
+
# Create play environment
|
| 282 |
+
self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
|
| 283 |
+
|
| 284 |
+
# Verify actor-critic is loaded and ready for inference
|
| 285 |
+
if agent.actor_critic is not None and self.actor_critic_loaded:
|
| 286 |
+
logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
|
| 287 |
+
logger.info(f"Actor-critic device: {agent.actor_critic.device}")
|
| 288 |
+
# Force AI control for web demo
|
| 289 |
+
self.play_env.is_human_player = False
|
| 290 |
+
logger.info("WebPlayEnv set to AI control mode")
|
| 291 |
+
elif agent.actor_critic is not None and not self.actor_critic_loaded:
|
| 292 |
+
logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
|
| 293 |
+
self.play_env.is_human_player = True
|
| 294 |
+
logger.info("WebPlayEnv set to human control mode (no trained weights)")
|
| 295 |
+
else:
|
| 296 |
+
logger.warning("No actor-critic model found - AI inference will not work!")
|
| 297 |
+
self.play_env.is_human_player = True
|
| 298 |
+
logger.info("WebPlayEnv set to human control mode (fallback)")
|
| 299 |
+
|
| 300 |
+
# Enable model compilation for better performance (like standalone play.py)
|
| 301 |
+
# This gives 20-50% speedup but causes 10-30s delay on first inference
|
| 302 |
+
import os
|
| 303 |
+
enable_compile = device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1"
|
| 304 |
+
|
| 305 |
+
if enable_compile:
|
| 306 |
+
logger.info("🚀 Compiling models for faster inference (like standalone play.py)...")
|
| 307 |
+
logger.info("⏱️ First inference will take 10-30s, but subsequent inferences will be much faster!")
|
| 308 |
+
try:
|
| 309 |
+
wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
|
| 310 |
+
if wm_env.upsample_next_obs is not None:
|
| 311 |
+
wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
|
| 312 |
+
logger.info("✅ Model compilation enabled - expect 20-50% speedup!")
|
| 313 |
+
except Exception as e:
|
| 314 |
+
logger.warning(f"Model compilation failed: {e}")
|
| 315 |
+
enable_compile = False
|
| 316 |
+
|
| 317 |
+
if not enable_compile:
|
| 318 |
+
logger.info("Model compilation disabled. Set ENABLE_TORCH_COMPILE=1 for better performance.")
|
| 319 |
+
|
| 320 |
+
# Reset environment
|
| 321 |
+
self.obs, _ = self.play_env.reset()
|
| 322 |
+
self.cached_obs = self.obs # Initialize cache
|
| 323 |
+
|
| 324 |
+
logger.info("Models initialized successfully!")
|
| 325 |
+
logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
|
| 326 |
+
self.models_ready = True
|
| 327 |
+
self.loading_status = "Ready!"
|
| 328 |
+
return True
|
| 329 |
+
|
| 330 |
+
except Exception as e:
|
| 331 |
+
logger.error(f"Failed to initialize world model environment: {e}")
|
| 332 |
+
self._init_dummy_mode()
|
| 333 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 334 |
+
self.models_ready = True
|
| 335 |
+
self.loading_status = "Using dummy mode"
|
| 336 |
+
return True
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"Failed to initialize models: {e}")
|
| 340 |
+
import traceback
|
| 341 |
+
traceback.print_exc()
|
| 342 |
+
self._init_dummy_mode()
|
| 343 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 344 |
+
self.models_ready = True
|
| 345 |
+
self.loading_status = "Error - using dummy mode"
|
| 346 |
+
return True
|
| 347 |
+
|
| 348 |
+
def _init_dummy_mode(self):
|
| 349 |
+
"""Initialize dummy mode for testing without models"""
|
| 350 |
+
logger.info("Initializing dummy mode...")
|
| 351 |
+
|
| 352 |
+
# Create a test observation
|
| 353 |
+
height, width = 150, 600
|
| 354 |
+
img_array = np.zeros((height, width, 3), dtype=np.uint8)
|
| 355 |
+
|
| 356 |
+
# Add test pattern
|
| 357 |
+
for y in range(height):
|
| 358 |
+
for x in range(width):
|
| 359 |
+
img_array[y, x, 0] = (x % 256) # Red gradient
|
| 360 |
+
img_array[y, x, 1] = (y % 256) # Green gradient
|
| 361 |
+
img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
|
| 362 |
+
|
| 363 |
+
# Convert to torch tensor in expected format [-1, 1]
|
| 364 |
+
tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
|
| 365 |
+
tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
|
| 366 |
+
tensor = tensor.unsqueeze(0) # Add batch dimension
|
| 367 |
+
|
| 368 |
+
self.obs = tensor
|
| 369 |
+
self.play_env = None # No real environment in dummy mode
|
| 370 |
+
logger.info("Dummy mode initialized with test pattern")
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def step_environment(self):
|
| 374 |
+
"""Step the environment with current input state (with intelligent frame skipping)"""
|
| 375 |
+
if self.play_env is None:
|
| 376 |
+
# Dummy mode - just return current observation
|
| 377 |
+
return self.obs, 0.0, False, False, {"mode": "dummy"}
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
# Check if reset is requested
|
| 381 |
+
if self.should_reset:
|
| 382 |
+
self.reset_environment()
|
| 383 |
+
self.should_reset = False
|
| 384 |
+
self.last_ai_time = self.time_module.time() # Reset AI timer
|
| 385 |
+
return self.obs, 0.0, False, False, {"reset": True}
|
| 386 |
+
|
| 387 |
+
# Intelligent frame skipping: only run AI inference at target FPS
|
| 388 |
+
current_time = self.time_module.time()
|
| 389 |
+
time_since_last_ai = current_time - self.last_ai_time
|
| 390 |
+
should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
|
| 391 |
+
|
| 392 |
+
if should_run_ai:
|
| 393 |
+
# Show loading indicator for first inference (can be slow)
|
| 394 |
+
if not self.first_inference_done:
|
| 395 |
+
logger.info("Running first AI inference (may take 5-15 seconds)...")
|
| 396 |
+
|
| 397 |
+
# Run AI inference
|
| 398 |
+
inference_start = self.time_module.time()
|
| 399 |
+
next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
|
| 400 |
+
pressed_keys=self.pressed_keys,
|
| 401 |
+
mouse_x=self.mouse_x,
|
| 402 |
+
mouse_y=self.mouse_y,
|
| 403 |
+
l_click=self.l_click,
|
| 404 |
+
r_click=self.r_click
|
| 405 |
+
)
|
| 406 |
+
inference_time = self.time_module.time() - inference_start
|
| 407 |
+
|
| 408 |
+
# Log first inference completion
|
| 409 |
+
if not self.first_inference_done:
|
| 410 |
+
self.first_inference_done = True
|
| 411 |
+
logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
|
| 412 |
+
|
| 413 |
+
# Cache the new observation and update timing
|
| 414 |
+
self.cached_obs = next_obs
|
| 415 |
+
self.last_ai_time = current_time
|
| 416 |
+
self.ai_frame_count += 1
|
| 417 |
+
|
| 418 |
+
# Add AI performance info
|
| 419 |
+
info = info or {}
|
| 420 |
+
info["ai_inference"] = True
|
| 421 |
+
|
| 422 |
+
# Calculate proper AI FPS: frames / elapsed time since start
|
| 423 |
+
elapsed_time = current_time - self.start_time
|
| 424 |
+
if elapsed_time > 0 and self.ai_frame_count > 0:
|
| 425 |
+
ai_fps = self.ai_frame_count / elapsed_time
|
| 426 |
+
# Cap at reasonable maximum (shouldn't exceed 100 FPS for AI inference)
|
| 427 |
+
info["ai_fps"] = min(ai_fps, 100.0)
|
| 428 |
+
else:
|
| 429 |
+
info["ai_fps"] = 0
|
| 430 |
+
|
| 431 |
+
info["inference_time"] = inference_time
|
| 432 |
+
|
| 433 |
+
return next_obs, reward, done, truncated, info
|
| 434 |
+
else:
|
| 435 |
+
# Use cached observation for smoother display without AI overhead
|
| 436 |
+
obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
|
| 437 |
+
|
| 438 |
+
# Calculate AI FPS for cached frames too
|
| 439 |
+
elapsed_time = current_time - self.start_time
|
| 440 |
+
if elapsed_time > 0 and self.ai_frame_count > 0:
|
| 441 |
+
ai_fps = min(self.ai_frame_count / elapsed_time, 100.0) # Cap at 100 FPS
|
| 442 |
+
else:
|
| 443 |
+
ai_fps = 0
|
| 444 |
+
|
| 445 |
+
return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}
|
| 446 |
+
|
| 447 |
+
except Exception as e:
|
| 448 |
+
logger.error(f"Error stepping environment: {e}")
|
| 449 |
+
obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
|
| 450 |
+
return obs_to_return, 0.0, False, False, {"error": str(e)}
|
| 451 |
+
|
| 452 |
+
def reset_environment(self):
|
| 453 |
+
"""Reset the environment"""
|
| 454 |
+
try:
|
| 455 |
+
if self.play_env is not None:
|
| 456 |
+
self.obs, _ = self.play_env.reset()
|
| 457 |
+
self.cached_obs = self.obs # Update cache
|
| 458 |
+
logger.info("Environment reset successfully")
|
| 459 |
+
else:
|
| 460 |
+
# Dummy mode - recreate test pattern
|
| 461 |
+
self._init_dummy_mode()
|
| 462 |
+
self.cached_obs = self.obs # Update cache
|
| 463 |
+
logger.info("Dummy environment reset")
|
| 464 |
+
except Exception as e:
|
| 465 |
+
logger.error(f"Error resetting environment: {e}")
|
| 466 |
+
|
| 467 |
+
def request_reset(self):
|
| 468 |
+
"""Request environment reset on next step"""
|
| 469 |
+
self.should_reset = True
|
| 470 |
+
logger.info("Environment reset requested")
|
| 471 |
+
|
| 472 |
+
def start_game(self):
|
| 473 |
+
"""Start the game"""
|
| 474 |
+
self.game_started = True
|
| 475 |
+
self.start_time = self.time_module.time() # Reset start time for FPS calculation
|
| 476 |
+
self.ai_frame_count = 0 # Reset AI frame count
|
| 477 |
+
logger.info("Game started")
|
| 478 |
+
|
| 479 |
+
def pause_game(self):
|
| 480 |
+
"""Pause/stop the game"""
|
| 481 |
+
self.game_started = False
|
| 482 |
+
logger.info("Game paused")
|
| 483 |
+
|
| 484 |
+
def obs_to_base64(self, obs: torch.Tensor) -> str:
|
| 485 |
+
"""Convert observation tensor to base64 image for web display (optimized for speed)"""
|
| 486 |
+
if obs is None:
|
| 487 |
+
return ""
|
| 488 |
+
|
| 489 |
+
try:
|
| 490 |
+
# Handle observation tensor conversion based on dimensions
|
| 491 |
+
if obs.ndim == 4 and obs.size(0) == 1:
|
| 492 |
+
# 4D tensor with batch dimension [1, C, H, W] -> [C, H, W]
|
| 493 |
+
img_tensor = obs[0]
|
| 494 |
+
elif obs.ndim == 3:
|
| 495 |
+
# 3D tensor [C, H, W]
|
| 496 |
+
img_tensor = obs
|
| 497 |
+
elif obs.ndim == 2:
|
| 498 |
+
# 2D tensor - likely an error, return empty string
|
| 499 |
+
logger.warning(f"Unexpected 2D observation tensor: {obs.shape}")
|
| 500 |
+
return ""
|
| 501 |
+
else:
|
| 502 |
+
logger.warning(f"Unexpected observation dimensions: {obs.shape}")
|
| 503 |
+
return ""
|
| 504 |
+
|
| 505 |
+
# Convert to numpy with proper range conversion
|
| 506 |
+
img_array = img_tensor.mul(127.5).add_(127.5).clamp_(0, 255).byte()
|
| 507 |
+
img_array = img_array.permute(1, 2, 0).cpu().numpy()
|
| 508 |
+
|
| 509 |
+
# Direct resize with OpenCV (much faster than PIL)
|
| 510 |
+
img_array = cv2.resize(img_array, (600, 150), interpolation=cv2.INTER_NEAREST)
|
| 511 |
+
|
| 512 |
+
# Convert BGR to RGB if needed (OpenCV uses BGR)
|
| 513 |
+
if img_array.shape[2] == 3:
|
| 514 |
+
img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2RGB)
|
| 515 |
+
|
| 516 |
+
# Optimized JPEG encoding with OpenCV (faster than PIL)
|
| 517 |
+
success, buffer = cv2.imencode('.jpg', img_array, [cv2.IMWRITE_JPEG_QUALITY, 80])
|
| 518 |
+
if success:
|
| 519 |
+
img_str = base64.b64encode(buffer).decode()
|
| 520 |
+
return f"data:image/jpeg;base64,{img_str}"
|
| 521 |
+
else:
|
| 522 |
+
logger.warning("Frame encoding failed, using fallback")
|
| 523 |
+
return ""
|
| 524 |
+
|
| 525 |
+
except Exception as e:
|
| 526 |
+
logger.error(f"Error converting observation to base64: {e}")
|
| 527 |
+
return ""
|
| 528 |
+
|
| 529 |
+
async def game_loop(self):
|
| 530 |
+
"""Main game loop that runs continuously"""
|
| 531 |
+
self.running = True
|
| 532 |
+
|
| 533 |
+
while self.running:
|
| 534 |
+
try:
|
| 535 |
+
# Check if models are ready
|
| 536 |
+
if not self.models_ready:
|
| 537 |
+
# Send loading status to clients
|
| 538 |
+
if connected_clients:
|
| 539 |
+
loading_data = {
|
| 540 |
+
'type': 'loading',
|
| 541 |
+
'status': self.loading_status,
|
| 542 |
+
'progress': self.download_progress,
|
| 543 |
+
'ready': False
|
| 544 |
+
}
|
| 545 |
+
disconnected = set()
|
| 546 |
+
for client in connected_clients.copy():
|
| 547 |
+
try:
|
| 548 |
+
await client.send_text(json.dumps(loading_data))
|
| 549 |
+
except:
|
| 550 |
+
disconnected.add(client)
|
| 551 |
+
connected_clients.difference_update(disconnected)
|
| 552 |
+
|
| 553 |
+
await asyncio.sleep(0.5) # Check every 500ms during loading
|
| 554 |
+
continue
|
| 555 |
+
|
| 556 |
+
# Only step environment if game is started
|
| 557 |
+
if not self.game_started:
|
| 558 |
+
# Game not started - just send current observation without stepping
|
| 559 |
+
should_send_frame = True if (self.obs is not None and connected_clients) else False
|
| 560 |
+
# Don't modify self.obs when game isn't started!
|
| 561 |
+
await asyncio.sleep(0.1)
|
| 562 |
+
else:
|
| 563 |
+
# Game is started - step environment
|
| 564 |
+
should_send_frame = True
|
| 565 |
+
if self.play_env is None:
|
| 566 |
+
await asyncio.sleep(0.1)
|
| 567 |
+
continue
|
| 568 |
+
|
| 569 |
+
# Step environment with current input state
|
| 570 |
+
next_obs, reward, done, truncated, info = self.step_environment()
|
| 571 |
+
|
| 572 |
+
if done or truncated:
|
| 573 |
+
# Auto-reset when episode ends
|
| 574 |
+
self.reset_environment()
|
| 575 |
+
else:
|
| 576 |
+
self.obs = next_obs
|
| 577 |
+
|
| 578 |
+
# Send frame to all connected clients with smart throttling for performance
|
| 579 |
+
current_time = self.time_module.time()
|
| 580 |
+
time_since_last_frame_send = current_time - self.last_frame_send_time
|
| 581 |
+
should_send_web_frame = time_since_last_frame_send >= (1.0 / self.web_fps)
|
| 582 |
+
|
| 583 |
+
if should_send_frame and should_send_web_frame and connected_clients and self.obs is not None:
|
| 584 |
+
# Set default values for when game isn't running
|
| 585 |
+
if not self.game_started:
|
| 586 |
+
reward = 0.0
|
| 587 |
+
info = {"waiting": True}
|
| 588 |
+
# If game is started, reward and info should be set above
|
| 589 |
+
|
| 590 |
+
# Convert observation to base64
|
| 591 |
+
image_data = self.obs_to_base64(self.obs)
|
| 592 |
+
|
| 593 |
+
# Debug logging for first few frames
|
| 594 |
+
if self.frame_count < 5:
|
| 595 |
+
logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
|
| 596 |
+
f"image_data_length={len(image_data) if image_data else 0}, "
|
| 597 |
+
f"game_started={self.game_started}")
|
| 598 |
+
|
| 599 |
+
frame_data = {
|
| 600 |
+
'type': 'frame',
|
| 601 |
+
'image': image_data,
|
| 602 |
+
'frame_count': self.frame_count,
|
| 603 |
+
'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
|
| 604 |
+
'info': str(info) if info else "",
|
| 605 |
+
'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
|
| 606 |
+
'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False,
|
| 607 |
+
'web_fps': self.web_fps, # Add web FPS for monitoring
|
| 608 |
+
'ai_target_fps': self.ai_fps # Add target AI FPS for monitoring
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
# Send to all connected clients
|
| 612 |
+
disconnected = set()
|
| 613 |
+
for client in connected_clients.copy():
|
| 614 |
+
try:
|
| 615 |
+
await client.send_text(json.dumps(frame_data))
|
| 616 |
+
except:
|
| 617 |
+
disconnected.add(client)
|
| 618 |
+
|
| 619 |
+
# Remove disconnected clients
|
| 620 |
+
connected_clients.difference_update(disconnected)
|
| 621 |
+
|
| 622 |
+
# Update frame send timing
|
| 623 |
+
self.last_frame_send_time = current_time
|
| 624 |
+
|
| 625 |
+
self.frame_count += 1
|
| 626 |
+
await asyncio.sleep(1.0 / self.fps) # Control FPS
|
| 627 |
+
|
| 628 |
+
except Exception as e:
|
| 629 |
+
logger.error(f"Error in game loop: {e}")
|
| 630 |
+
await asyncio.sleep(0.1)
|
| 631 |
+
|
| 632 |
+
# Global game engine instance
|
| 633 |
+
game_engine = WebGameEngine()
|
| 634 |
+
|
| 635 |
+
@app.on_event("startup")
|
| 636 |
+
async def startup_event():
|
| 637 |
+
"""Initialize models when the app starts"""
|
| 638 |
+
# Start the game loop immediately (it will handle loading state)
|
| 639 |
+
asyncio.create_task(game_engine.game_loop())
|
| 640 |
+
|
| 641 |
+
# Initialize models in background (non-blocking)
|
| 642 |
+
asyncio.create_task(game_engine.initialize_models())
|
| 643 |
+
|
| 644 |
+
@app.get("/performance")
|
| 645 |
+
async def get_performance_stats():
|
| 646 |
+
"""Get current performance statistics"""
|
| 647 |
+
current_time = game_engine.time_module.time()
|
| 648 |
+
elapsed_time = current_time - game_engine.start_time if game_engine.start_time > 0 else 0
|
| 649 |
+
|
| 650 |
+
return {
|
| 651 |
+
"ai_fps_current": game_engine.ai_frame_count / elapsed_time if elapsed_time > 0 else 0,
|
| 652 |
+
"ai_fps_target": game_engine.ai_fps,
|
| 653 |
+
"web_fps_target": game_engine.web_fps,
|
| 654 |
+
"display_fps_target": game_engine.fps,
|
| 655 |
+
"models_ready": game_engine.models_ready,
|
| 656 |
+
"actor_critic_loaded": game_engine.actor_critic_loaded,
|
| 657 |
+
"game_started": game_engine.game_started,
|
| 658 |
+
"connected_clients": len(connected_clients),
|
| 659 |
+
"total_ai_frames": game_engine.ai_frame_count,
|
| 660 |
+
"total_display_frames": game_engine.frame_count,
|
| 661 |
+
"elapsed_time": elapsed_time,
|
| 662 |
+
"torch_compile_enabled": os.environ.get("ENABLE_TORCH_COMPILE", "1") == "1",
|
| 663 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu"
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
@app.get("/", response_class=HTMLResponse)
|
| 667 |
+
async def get_homepage():
|
| 668 |
+
"""Serve the main game interface"""
|
| 669 |
+
html_content = """
|
| 670 |
+
<!DOCTYPE html>
|
| 671 |
+
<html>
|
| 672 |
+
<head>
|
| 673 |
+
<title>Physics-informed BEV World Model</title>
|
| 674 |
+
<style>
|
| 675 |
+
body {
|
| 676 |
+
margin: 0;
|
| 677 |
+
padding: 20px;
|
| 678 |
+
background: #1a1a1a;
|
| 679 |
+
color: white;
|
| 680 |
+
font-family: 'Courier New', monospace;
|
| 681 |
+
text-align: center;
|
| 682 |
+
}
|
| 683 |
+
#gameCanvas {
|
| 684 |
+
border: 2px solid #00ff00;
|
| 685 |
+
background: #000;
|
| 686 |
+
margin: 20px auto;
|
| 687 |
+
display: block;
|
| 688 |
+
}
|
| 689 |
+
#controls {
|
| 690 |
+
margin: 20px;
|
| 691 |
+
display: grid;
|
| 692 |
+
grid-template-columns: 1fr 1fr;
|
| 693 |
+
gap: 20px;
|
| 694 |
+
max-width: 800px;
|
| 695 |
+
margin: 20px auto;
|
| 696 |
+
}
|
| 697 |
+
.control-section {
|
| 698 |
+
background: #2a2a2a;
|
| 699 |
+
padding: 15px;
|
| 700 |
+
border-radius: 8px;
|
| 701 |
+
border: 1px solid #444;
|
| 702 |
+
}
|
| 703 |
+
.key-display {
|
| 704 |
+
background: #333;
|
| 705 |
+
border: 1px solid #555;
|
| 706 |
+
padding: 5px 10px;
|
| 707 |
+
margin: 2px;
|
| 708 |
+
border-radius: 4px;
|
| 709 |
+
display: inline-block;
|
| 710 |
+
min-width: 30px;
|
| 711 |
+
}
|
| 712 |
+
.key-pressed {
|
| 713 |
+
background: #00ff00;
|
| 714 |
+
color: #000;
|
| 715 |
+
}
|
| 716 |
+
#status {
|
| 717 |
+
margin: 10px;
|
| 718 |
+
padding: 10px;
|
| 719 |
+
background: #2a2a2a;
|
| 720 |
+
border-radius: 4px;
|
| 721 |
+
}
|
| 722 |
+
.info {
|
| 723 |
+
color: #00ff00;
|
| 724 |
+
margin: 5px 0;
|
| 725 |
+
}
|
| 726 |
+
</style>
|
| 727 |
+
</head>
|
| 728 |
+
<body>
|
| 729 |
+
<h1>🎮 Physics-informed BEV World Model</h1>
|
| 730 |
+
<p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
|
| 731 |
+
<p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
|
| 732 |
+
|
| 733 |
+
<!-- Model Download Progress -->
|
| 734 |
+
<div id="downloadSection" style="display: none; margin: 20px;">
|
| 735 |
+
<p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
|
| 736 |
+
<div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
|
| 737 |
+
<div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
|
| 738 |
+
</div>
|
| 739 |
+
<p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
|
| 740 |
+
</div>
|
| 741 |
+
|
| 742 |
+
<canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>
|
| 743 |
+
|
| 744 |
+
<div id="status">
|
| 745 |
+
<div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
|
| 746 |
+
<div class="info">Game: <span id="gameStatus">Click to Start</span></div>
|
| 747 |
+
<div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
|
| 748 |
+
<div class="info">Reward: <span id="reward">0</span></div>
|
| 749 |
+
</div>
|
| 750 |
+
|
| 751 |
+
<div id="controls">
|
| 752 |
+
<div class="control-section">
|
| 753 |
+
<h3>Movement</h3>
|
| 754 |
+
<div>
|
| 755 |
+
<span class="key-display" id="key-w">W</span> Forward<br>
|
| 756 |
+
<span class="key-display" id="key-a">A</span> Left
|
| 757 |
+
<span class="key-display" id="key-s">S</span> Back
|
| 758 |
+
<span class="key-display" id="key-d">D</span> Right<br>
|
| 759 |
+
<span class="key-display" id="key-space">Space</span> Jump
|
| 760 |
+
<span class="key-display" id="key-ctrl">Ctrl</span> Crouch
|
| 761 |
+
<span class="key-display" id="key-shift">Shift</span> Walk
|
| 762 |
+
</div>
|
| 763 |
+
</div>
|
| 764 |
+
|
| 765 |
+
<div class="control-section">
|
| 766 |
+
<h3>Actions</h3>
|
| 767 |
+
<div>
|
| 768 |
+
<span class="key-display" id="key-1">1</span> Weapon 1<br>
|
| 769 |
+
<span class="key-display" id="key-2">2</span> Weapon 2
|
| 770 |
+
<span class="key-display" id="key-3">3</span> Weapon 3<br>
|
| 771 |
+
<span class="key-display" id="key-r">R</span> Reload<br>
|
| 772 |
+
<span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
|
| 773 |
+
<span class="key-display" id="key-enter">Enter</span> Reset Game<br>
|
| 774 |
+
<span class="key-display" id="key-esc">Esc</span> Pause/Quit
|
| 775 |
+
</div>
|
| 776 |
+
</div>
|
| 777 |
+
</div>
|
| 778 |
+
|
| 779 |
+
<script>
|
| 780 |
+
const canvas = document.getElementById('gameCanvas');
|
| 781 |
+
const ctx = canvas.getContext('2d');
|
| 782 |
+
const statusEl = document.getElementById('connectionStatus');
|
| 783 |
+
const gameStatusEl = document.getElementById('gameStatus');
|
| 784 |
+
const frameEl = document.getElementById('frameCount');
|
| 785 |
+
const aiFpsEl = document.getElementById('aiFps');
|
| 786 |
+
const rewardEl = document.getElementById('reward');
|
| 787 |
+
const loadingEl = document.getElementById('loadingIndicator');
|
| 788 |
+
const downloadSectionEl = document.getElementById('downloadSection');
|
| 789 |
+
const downloadStatusEl = document.getElementById('downloadStatus');
|
| 790 |
+
const progressBarEl = document.getElementById('progressBar');
|
| 791 |
+
const progressTextEl = document.getElementById('progressText');
|
| 792 |
+
|
| 793 |
+
let ws = null;
|
| 794 |
+
let pressedKeys = new Set();
|
| 795 |
+
let gameStarted = false;
|
| 796 |
+
|
| 797 |
+
// Key mapping
|
| 798 |
+
const keyDisplayMap = {
|
| 799 |
+
'KeyW': 'key-w',
|
| 800 |
+
'KeyA': 'key-a',
|
| 801 |
+
'KeyS': 'key-s',
|
| 802 |
+
'KeyD': 'key-d',
|
| 803 |
+
'Space': 'key-space',
|
| 804 |
+
'ControlLeft': 'key-ctrl',
|
| 805 |
+
'ShiftLeft': 'key-shift',
|
| 806 |
+
'Digit1': 'key-1',
|
| 807 |
+
'Digit2': 'key-2',
|
| 808 |
+
'Digit3': 'key-3',
|
| 809 |
+
'KeyR': 'key-r',
|
| 810 |
+
'ArrowUp': 'key-arrows',
|
| 811 |
+
'ArrowDown': 'key-arrows',
|
| 812 |
+
'ArrowLeft': 'key-arrows',
|
| 813 |
+
'ArrowRight': 'key-arrows',
|
| 814 |
+
'Enter': 'key-enter',
|
| 815 |
+
'Escape': 'key-esc'
|
| 816 |
+
};
|
| 817 |
+
|
| 818 |
+
function connectWebSocket() {
|
| 819 |
+
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 820 |
+
const wsUrl = `${protocol}//${window.location.host}/ws`;
|
| 821 |
+
|
| 822 |
+
ws = new WebSocket(wsUrl);
|
| 823 |
+
|
| 824 |
+
ws.onopen = function(event) {
|
| 825 |
+
statusEl.textContent = 'Connected';
|
| 826 |
+
statusEl.style.color = '#00ff00';
|
| 827 |
+
// If user already clicked to start before WS was ready, send start now
|
| 828 |
+
if (gameStarted) {
|
| 829 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 830 |
+
}
|
| 831 |
+
};
|
| 832 |
+
|
| 833 |
+
ws.onmessage = function(event) {
|
| 834 |
+
const data = JSON.parse(event.data);
|
| 835 |
+
|
| 836 |
+
if (data.type === 'loading') {
|
| 837 |
+
// Handle loading status
|
| 838 |
+
downloadSectionEl.style.display = 'block';
|
| 839 |
+
downloadStatusEl.textContent = data.status;
|
| 840 |
+
|
| 841 |
+
if (data.progress !== undefined) {
|
| 842 |
+
progressBarEl.style.width = data.progress + '%';
|
| 843 |
+
progressTextEl.textContent = data.progress + '% - ' + data.status;
|
| 844 |
+
} else {
|
| 845 |
+
progressTextEl.textContent = data.status;
|
| 846 |
+
}
|
| 847 |
+
|
| 848 |
+
gameStatusEl.textContent = 'Loading Models...';
|
| 849 |
+
gameStatusEl.style.color = '#ffaa00';
|
| 850 |
+
|
| 851 |
+
} else if (data.type === 'frame') {
|
| 852 |
+
// Hide loading indicators once we get frames
|
| 853 |
+
downloadSectionEl.style.display = 'none';
|
| 854 |
+
// Update frame display
|
| 855 |
+
if (data.image) {
|
| 856 |
+
const img = new Image();
|
| 857 |
+
img.onload = function() {
|
| 858 |
+
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
| 859 |
+
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
| 860 |
+
};
|
| 861 |
+
img.src = data.image;
|
| 862 |
+
}
|
| 863 |
+
|
| 864 |
+
frameEl.textContent = data.frame_count;
|
| 865 |
+
rewardEl.textContent = data.reward.toFixed(2);
|
| 866 |
+
|
| 867 |
+
// Update AI FPS display and hide loading indicator once AI starts
|
| 868 |
+
if (data.ai_fps !== undefined && data.ai_fps !== null) {
|
| 869 |
+
// Ensure FPS value is reasonable
|
| 870 |
+
const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
|
| 871 |
+
aiFpsEl.textContent = aiFps.toFixed(1);
|
| 872 |
+
|
| 873 |
+
// Color code AI FPS for performance indication
|
| 874 |
+
if (aiFps >= 8) {
|
| 875 |
+
aiFpsEl.style.color = '#00ff00'; // Green for good performance
|
| 876 |
+
} else if (aiFps >= 5) {
|
| 877 |
+
aiFpsEl.style.color = '#ffff00'; // Yellow for moderate performance
|
| 878 |
+
} else if (aiFps > 0) {
|
| 879 |
+
aiFpsEl.style.color = '#ff0000'; // Red for poor performance
|
| 880 |
+
} else {
|
| 881 |
+
aiFpsEl.style.color = '#888888'; // Gray for inactive
|
| 882 |
+
}
|
| 883 |
+
|
| 884 |
+
// Hide loading indicator once AI inference starts working
|
| 885 |
+
if (aiFps > 0 && gameStarted) {
|
| 886 |
+
loadingEl.style.display = 'none';
|
| 887 |
+
gameStatusEl.textContent = 'Playing';
|
| 888 |
+
gameStatusEl.style.color = '#00ff00';
|
| 889 |
+
}
|
| 890 |
+
}
|
| 891 |
+
}
|
| 892 |
+
};
|
| 893 |
+
|
| 894 |
+
ws.onclose = function(event) {
|
| 895 |
+
statusEl.textContent = 'Disconnected';
|
| 896 |
+
statusEl.style.color = '#ff0000';
|
| 897 |
+
setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
|
| 898 |
+
};
|
| 899 |
+
|
| 900 |
+
ws.onerror = function(event) {
|
| 901 |
+
statusEl.textContent = 'Error';
|
| 902 |
+
statusEl.style.color = '#ff0000';
|
| 903 |
+
};
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
function sendKeyState() {
|
| 907 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 908 |
+
ws.send(JSON.stringify({
|
| 909 |
+
type: 'keys',
|
| 910 |
+
keys: Array.from(pressedKeys)
|
| 911 |
+
}));
|
| 912 |
+
}
|
| 913 |
+
}
|
| 914 |
+
|
| 915 |
+
function startGame() {
|
| 916 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 917 |
+
ws.send(JSON.stringify({
|
| 918 |
+
type: 'start'
|
| 919 |
+
}));
|
| 920 |
+
gameStarted = true;
|
| 921 |
+
gameStatusEl.textContent = 'Starting AI...';
|
| 922 |
+
gameStatusEl.style.color = '#ffff00';
|
| 923 |
+
loadingEl.style.display = 'block';
|
| 924 |
+
console.log('Game started');
|
| 925 |
+
}
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
+
function pauseGame() {
|
| 929 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 930 |
+
ws.send(JSON.stringify({
|
| 931 |
+
type: 'pause'
|
| 932 |
+
}));
|
| 933 |
+
gameStarted = false;
|
| 934 |
+
gameStatusEl.textContent = 'Paused - Click to Resume';
|
| 935 |
+
gameStatusEl.style.color = '#ffff00';
|
| 936 |
+
console.log('Game paused');
|
| 937 |
+
}
|
| 938 |
+
}
|
| 939 |
+
|
| 940 |
+
function updateKeyDisplay() {
|
| 941 |
+
// Reset all key displays
|
| 942 |
+
Object.values(keyDisplayMap).forEach(id => {
|
| 943 |
+
const el = document.getElementById(id);
|
| 944 |
+
if (el) el.classList.remove('key-pressed');
|
| 945 |
+
});
|
| 946 |
+
|
| 947 |
+
// Highlight pressed keys
|
| 948 |
+
pressedKeys.forEach(key => {
|
| 949 |
+
const displayId = keyDisplayMap[key];
|
| 950 |
+
if (displayId) {
|
| 951 |
+
const el = document.getElementById(displayId);
|
| 952 |
+
if (el) el.classList.add('key-pressed');
|
| 953 |
+
}
|
| 954 |
+
});
|
| 955 |
+
}
|
| 956 |
+
|
| 957 |
+
// Focus canvas and handle keyboard events
|
| 958 |
+
canvas.addEventListener('click', () => {
|
| 959 |
+
canvas.focus();
|
| 960 |
+
if (!gameStarted) {
|
| 961 |
+
// Queue start locally and send immediately if WS is open
|
| 962 |
+
gameStarted = true;
|
| 963 |
+
gameStatusEl.textContent = 'Starting AI...';
|
| 964 |
+
gameStatusEl.style.color = '#ffff00';
|
| 965 |
+
loadingEl.style.display = 'block';
|
| 966 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 967 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 968 |
+
}
|
| 969 |
+
}
|
| 970 |
+
});
|
| 971 |
+
|
| 972 |
+
canvas.addEventListener('keydown', (event) => {
|
| 973 |
+
event.preventDefault();
|
| 974 |
+
|
| 975 |
+
// Handle special keys
|
| 976 |
+
if (event.code === 'Enter') {
|
| 977 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 978 |
+
ws.send(JSON.stringify({
|
| 979 |
+
type: 'reset'
|
| 980 |
+
}));
|
| 981 |
+
console.log('Environment reset requested');
|
| 982 |
+
}
|
| 983 |
+
// Add to pressedKeys for visual feedback
|
| 984 |
+
pressedKeys.add(event.code);
|
| 985 |
+
updateKeyDisplay();
|
| 986 |
+
|
| 987 |
+
// Remove Enter from pressedKeys after a short delay for visual feedback
|
| 988 |
+
setTimeout(() => {
|
| 989 |
+
pressedKeys.delete(event.code);
|
| 990 |
+
updateKeyDisplay();
|
| 991 |
+
}, 200);
|
| 992 |
+
} else if (event.code === 'Escape') {
|
| 993 |
+
pauseGame();
|
| 994 |
+
// Add to pressedKeys for visual feedback
|
| 995 |
+
pressedKeys.add(event.code);
|
| 996 |
+
updateKeyDisplay();
|
| 997 |
+
|
| 998 |
+
// Remove ESC from pressedKeys after a short delay for visual feedback
|
| 999 |
+
setTimeout(() => {
|
| 1000 |
+
pressedKeys.delete(event.code);
|
| 1001 |
+
updateKeyDisplay();
|
| 1002 |
+
}, 200);
|
| 1003 |
+
} else {
|
| 1004 |
+
// Only send game keys if game is started
|
| 1005 |
+
if (gameStarted) {
|
| 1006 |
+
pressedKeys.add(event.code);
|
| 1007 |
+
updateKeyDisplay();
|
| 1008 |
+
sendKeyState();
|
| 1009 |
+
}
|
| 1010 |
+
}
|
| 1011 |
+
});
|
| 1012 |
+
|
| 1013 |
+
canvas.addEventListener('keyup', (event) => {
|
| 1014 |
+
event.preventDefault();
|
| 1015 |
+
|
| 1016 |
+
// Don't handle special keys release (handled in keydown with timeout)
|
| 1017 |
+
if (event.code !== 'Enter' && event.code !== 'Escape') {
|
| 1018 |
+
if (gameStarted) {
|
| 1019 |
+
pressedKeys.delete(event.code);
|
| 1020 |
+
updateKeyDisplay();
|
| 1021 |
+
sendKeyState();
|
| 1022 |
+
}
|
| 1023 |
+
}
|
| 1024 |
+
});
|
| 1025 |
+
|
| 1026 |
+
// Handle mouse events for clicks
|
| 1027 |
+
canvas.addEventListener('mousedown', (event) => {
|
| 1028 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 1029 |
+
ws.send(JSON.stringify({
|
| 1030 |
+
type: 'mouse',
|
| 1031 |
+
button: event.button,
|
| 1032 |
+
action: 'down',
|
| 1033 |
+
x: event.offsetX,
|
| 1034 |
+
y: event.offsetY
|
| 1035 |
+
}));
|
| 1036 |
+
}
|
| 1037 |
+
});
|
| 1038 |
+
|
| 1039 |
+
canvas.addEventListener('mouseup', (event) => {
|
| 1040 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 1041 |
+
ws.send(JSON.stringify({
|
| 1042 |
+
type: 'mouse',
|
| 1043 |
+
button: event.button,
|
| 1044 |
+
action: 'up',
|
| 1045 |
+
x: event.offsetX,
|
| 1046 |
+
y: event.offsetY
|
| 1047 |
+
}));
|
| 1048 |
+
}
|
| 1049 |
+
});
|
| 1050 |
+
|
| 1051 |
+
// Initialize
|
| 1052 |
+
connectWebSocket();
|
| 1053 |
+
canvas.focus();
|
| 1054 |
+
</script>
|
| 1055 |
+
</body>
|
| 1056 |
+
</html>
|
| 1057 |
+
"""
|
| 1058 |
+
return html_content
|
| 1059 |
+
|
| 1060 |
+
@app.websocket("/ws")
|
| 1061 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 1062 |
+
"""Handle WebSocket connections for real-time game communication"""
|
| 1063 |
+
await websocket.accept()
|
| 1064 |
+
connected_clients.add(websocket)
|
| 1065 |
+
|
| 1066 |
+
try:
|
| 1067 |
+
while True:
|
| 1068 |
+
# Receive messages from client
|
| 1069 |
+
data = await websocket.receive_text()
|
| 1070 |
+
message = json.loads(data)
|
| 1071 |
+
|
| 1072 |
+
if message['type'] == 'keys':
|
| 1073 |
+
# Update pressed keys
|
| 1074 |
+
game_engine.pressed_keys = set(message['keys'])
|
| 1075 |
+
|
| 1076 |
+
elif message['type'] == 'reset':
|
| 1077 |
+
# Handle environment reset request
|
| 1078 |
+
game_engine.request_reset()
|
| 1079 |
+
|
| 1080 |
+
elif message['type'] == 'start':
|
| 1081 |
+
# Handle game start request
|
| 1082 |
+
game_engine.start_game()
|
| 1083 |
+
|
| 1084 |
+
elif message['type'] == 'pause':
|
| 1085 |
+
# Handle game pause request
|
| 1086 |
+
game_engine.pause_game()
|
| 1087 |
+
|
| 1088 |
+
elif message['type'] == 'mouse':
|
| 1089 |
+
# Handle mouse events
|
| 1090 |
+
if message['action'] == 'down':
|
| 1091 |
+
if message['button'] == 0: # Left click
|
| 1092 |
+
game_engine.l_click = True
|
| 1093 |
+
elif message['button'] == 2: # Right click
|
| 1094 |
+
game_engine.r_click = True
|
| 1095 |
+
elif message['action'] == 'up':
|
| 1096 |
+
if message['button'] == 0: # Left click
|
| 1097 |
+
game_engine.l_click = False
|
| 1098 |
+
elif message['button'] == 2: # Right click
|
| 1099 |
+
game_engine.r_click = False
|
| 1100 |
+
|
| 1101 |
+
# Update mouse position (relative to canvas)
|
| 1102 |
+
game_engine.mouse_x = message.get('x', 0) - 300 # Center at 300px
|
| 1103 |
+
game_engine.mouse_y = message.get('y', 0) - 150 # Center at 150px
|
| 1104 |
+
|
| 1105 |
+
except WebSocketDisconnect:
|
| 1106 |
+
connected_clients.discard(websocket)
|
| 1107 |
+
except Exception as e:
|
| 1108 |
+
logger.error(f"WebSocket error: {e}")
|
| 1109 |
+
connected_clients.discard(websocket)
|
| 1110 |
+
|
| 1111 |
+
if __name__ == "__main__":
|
| 1112 |
+
# For local development
|
| 1113 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
|
| 1114 |
+
|
app.py.backup2
ADDED
|
@@ -0,0 +1,1112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Web-based Diamond CSGO AI Player for Hugging Face Spaces
|
| 3 |
+
Uses FastAPI + WebSocket for real-time keyboard input and game streaming
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# Fix environment variables FIRST, before any other imports
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
# Fix OMP_NUM_THREADS immediately (before PyTorch/NumPy imports)
|
| 11 |
+
if "OMP_NUM_THREADS" not in os.environ or not os.environ.get("OMP_NUM_THREADS", "").isdigit():
|
| 12 |
+
os.environ["OMP_NUM_THREADS"] = "2"
|
| 13 |
+
|
| 14 |
+
# Set up cache directories immediately
|
| 15 |
+
temp_dir = tempfile.gettempdir()
|
| 16 |
+
os.environ.setdefault("TORCH_HOME", os.path.join(temp_dir, "torch"))
|
| 17 |
+
os.environ.setdefault("HF_HOME", os.path.join(temp_dir, "huggingface"))
|
| 18 |
+
os.environ.setdefault("TRANSFORMERS_CACHE", os.path.join(temp_dir, "transformers"))
|
| 19 |
+
|
| 20 |
+
# Create cache directories
|
| 21 |
+
for cache_var in ["TORCH_HOME", "HF_HOME", "TRANSFORMERS_CACHE"]:
|
| 22 |
+
cache_path = os.environ[cache_var]
|
| 23 |
+
os.makedirs(cache_path, exist_ok=True)
|
| 24 |
+
|
| 25 |
+
import asyncio
|
| 26 |
+
import base64
|
| 27 |
+
import io
|
| 28 |
+
import json
|
| 29 |
+
import logging
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Dict, List, Optional, Set
|
| 32 |
+
|
| 33 |
+
import cv2
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
import uvicorn
|
| 37 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 38 |
+
from fastapi.responses import HTMLResponse
|
| 39 |
+
from fastapi.staticfiles import StaticFiles
|
| 40 |
+
from hydra import compose, initialize
|
| 41 |
+
from hydra.utils import instantiate
|
| 42 |
+
from omegaconf import DictConfig, OmegaConf
|
| 43 |
+
from PIL import Image
|
| 44 |
+
|
| 45 |
+
# Import your modules
|
| 46 |
+
import sys
|
| 47 |
+
from pathlib import Path
|
| 48 |
+
|
| 49 |
+
# Add project root to path for src package imports
|
| 50 |
+
project_root = Path(__file__).parent
|
| 51 |
+
if str(project_root) not in sys.path:
|
| 52 |
+
sys.path.insert(0, str(project_root))
|
| 53 |
+
|
| 54 |
+
from src.agent import Agent
|
| 55 |
+
from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
|
| 56 |
+
from src.envs import WorldModelEnv
|
| 57 |
+
from src.game.web_play_env import WebPlayEnv
|
| 58 |
+
from src.utils import extract_state_dict
|
| 59 |
+
from config_web import web_config
|
| 60 |
+
|
| 61 |
+
# Configure logging
|
| 62 |
+
logging.basicConfig(level=logging.INFO)
|
| 63 |
+
logger = logging.getLogger(__name__)
|
| 64 |
+
|
| 65 |
+
# Global variables
|
| 66 |
+
app = FastAPI(title="Diamond CSGO AI Player")
|
| 67 |
+
|
| 68 |
+
# Set safe defaults for headless CI/Spaces environments
|
| 69 |
+
os.environ.setdefault("SDL_VIDEODRIVER", "dummy")
|
| 70 |
+
os.environ.setdefault("SDL_AUDIODRIVER", "dummy")
|
| 71 |
+
os.environ.setdefault("PYGAME_HIDE_SUPPORT_PROMPT", "1")
|
| 72 |
+
|
| 73 |
+
# Environment variables already set at top of file
|
| 74 |
+
connected_clients: Set[WebSocket] = set()
|
| 75 |
+
|
| 76 |
+
class WebKeyMap:
|
| 77 |
+
"""Map web key codes to pygame-like keys for CSGO actions"""
|
| 78 |
+
WEB_TO_CSGO = {
|
| 79 |
+
'KeyW': 'w',
|
| 80 |
+
'KeyA': 'a',
|
| 81 |
+
'KeyS': 's',
|
| 82 |
+
'KeyD': 'd',
|
| 83 |
+
'Space': 'space',
|
| 84 |
+
'ControlLeft': 'left ctrl',
|
| 85 |
+
'ShiftLeft': 'left shift',
|
| 86 |
+
'Digit1': '1',
|
| 87 |
+
'Digit2': '2',
|
| 88 |
+
'Digit3': '3',
|
| 89 |
+
'KeyR': 'r',
|
| 90 |
+
'ArrowUp': 'camera_up',
|
| 91 |
+
'ArrowDown': 'camera_down',
|
| 92 |
+
'ArrowLeft': 'camera_left',
|
| 93 |
+
'ArrowRight': 'camera_right'
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
class WebGameEngine:
|
| 97 |
+
"""Web-compatible game engine that replaces pygame functionality"""
|
| 98 |
+
|
| 99 |
+
def __init__(self):
|
| 100 |
+
self.play_env: Optional[WebPlayEnv] = None
|
| 101 |
+
self.obs = None
|
| 102 |
+
self.running = False
|
| 103 |
+
self.game_started = False
|
| 104 |
+
self.fps = 30 # Display FPS
|
| 105 |
+
self.ai_fps = 40 # AI inference FPS (matching standalone play.py performance)
|
| 106 |
+
self.frame_count = 0
|
| 107 |
+
self.ai_frame_count = 0
|
| 108 |
+
self.last_ai_time = 0
|
| 109 |
+
self.start_time = 0 # Track when AI started for proper FPS calculation
|
| 110 |
+
self.last_frame_send_time = 0 # Track frame sending for optimization
|
| 111 |
+
self.web_fps = 20 # Web display FPS (lower than AI FPS to reduce network overhead)
|
| 112 |
+
self.pressed_keys: Set[str] = set()
|
| 113 |
+
self.mouse_x = 0
|
| 114 |
+
self.mouse_y = 0
|
| 115 |
+
self.l_click = False
|
| 116 |
+
self.r_click = False
|
| 117 |
+
self.should_reset = False
|
| 118 |
+
self.cached_obs = None # Cache last observation for frame skipping
|
| 119 |
+
self.first_inference_done = False # Track if first inference completed
|
| 120 |
+
self.models_ready = False # Track if models are loaded
|
| 121 |
+
self.download_progress = 0 # Track download progress (0-100)
|
| 122 |
+
self.loading_status = "Initializing..." # Loading status message
|
| 123 |
+
self.actor_critic_loaded = False # Track if actor_critic was loaded with trained weights
|
| 124 |
+
import time
|
| 125 |
+
self.time_module = time
|
| 126 |
+
|
| 127 |
+
async def _load_model_from_url_async(self, agent, device):
|
| 128 |
+
"""Load model from URL using torch.hub (HF Spaces compatible)"""
|
| 129 |
+
import asyncio
|
| 130 |
+
import concurrent.futures
|
| 131 |
+
|
| 132 |
+
def load_model_weights():
|
| 133 |
+
"""Load model weights in thread pool to avoid blocking"""
|
| 134 |
+
try:
|
| 135 |
+
logger.info("Loading model using torch.hub.load_state_dict_from_url...")
|
| 136 |
+
self.loading_status = "Downloading model..."
|
| 137 |
+
self.download_progress = 10
|
| 138 |
+
|
| 139 |
+
model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
|
| 140 |
+
|
| 141 |
+
# Use torch.hub to download and load state dict with custom cache dir
|
| 142 |
+
logger.info(f"Loading state dict from {model_url}")
|
| 143 |
+
|
| 144 |
+
# Set custom cache directory that we have write permissions for
|
| 145 |
+
cache_dir = os.path.join(tempfile.gettempdir(), "torch_cache")
|
| 146 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 147 |
+
|
| 148 |
+
# Use torch.hub with custom cache directory
|
| 149 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 150 |
+
model_url,
|
| 151 |
+
map_location=device,
|
| 152 |
+
model_dir=cache_dir,
|
| 153 |
+
check_hash=False # Skip hash check for faster loading
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
self.download_progress = 60
|
| 157 |
+
self.loading_status = "Loading model weights into agent..."
|
| 158 |
+
logger.info("State dict loaded, applying to agent...")
|
| 159 |
+
|
| 160 |
+
# Load state dict into agent, but skip actor_critic if not present
|
| 161 |
+
has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
|
| 162 |
+
logger.info(f"Model has actor_critic weights: {has_actor_critic}")
|
| 163 |
+
agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
|
| 164 |
+
|
| 165 |
+
# Track if actor_critic was actually loaded with trained weights
|
| 166 |
+
self.actor_critic_loaded = has_actor_critic
|
| 167 |
+
|
| 168 |
+
self.download_progress = 100
|
| 169 |
+
self.loading_status = "Model loaded successfully!"
|
| 170 |
+
logger.info("All model weights loaded successfully!")
|
| 171 |
+
return True
|
| 172 |
+
|
| 173 |
+
except Exception as e:
|
| 174 |
+
logger.error(f"Failed to load model: {e}")
|
| 175 |
+
import traceback
|
| 176 |
+
traceback.print_exc()
|
| 177 |
+
return False
|
| 178 |
+
|
| 179 |
+
# Run in thread pool to avoid blocking with timeout
|
| 180 |
+
loop = asyncio.get_event_loop()
|
| 181 |
+
try:
|
| 182 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 183 |
+
# Add timeout for model loading (5 minutes max)
|
| 184 |
+
future = loop.run_in_executor(executor, load_model_weights)
|
| 185 |
+
success = await asyncio.wait_for(future, timeout=300.0) # 5 minute timeout
|
| 186 |
+
return success
|
| 187 |
+
except asyncio.TimeoutError:
|
| 188 |
+
logger.error("Model loading timed out after 5 minutes")
|
| 189 |
+
self.loading_status = "Model loading timed out - using dummy mode"
|
| 190 |
+
return False
|
| 191 |
+
except Exception as e:
|
| 192 |
+
logger.error(f"Error in model loading executor: {e}")
|
| 193 |
+
self.loading_status = f"Model loading error: {str(e)[:50]}..."
|
| 194 |
+
return False
|
| 195 |
+
|
| 196 |
+
async def initialize_models(self):
|
| 197 |
+
"""Initialize the AI models and environment"""
|
| 198 |
+
try:
|
| 199 |
+
import torch
|
| 200 |
+
logger.info("Initializing models...")
|
| 201 |
+
|
| 202 |
+
# Setup environment and paths
|
| 203 |
+
web_config.setup_environment_variables()
|
| 204 |
+
web_config.create_default_configs()
|
| 205 |
+
|
| 206 |
+
config_path = web_config.get_config_path()
|
| 207 |
+
logger.info(f"Using config path: {config_path}")
|
| 208 |
+
|
| 209 |
+
# For Hydra, use relative path from app.py location
|
| 210 |
+
# Since app.py is in project root, config is simply "config"
|
| 211 |
+
relative_config_path = "config"
|
| 212 |
+
logger.info(f"Relative config path: {relative_config_path}")
|
| 213 |
+
|
| 214 |
+
with initialize(version_base="1.3", config_path=relative_config_path):
|
| 215 |
+
cfg = compose(config_name="trainer")
|
| 216 |
+
|
| 217 |
+
# Override config for deployment
|
| 218 |
+
cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
|
| 219 |
+
cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
|
| 220 |
+
|
| 221 |
+
# Use GPU if available, otherwise fall back to CPU
|
| 222 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 223 |
+
logger.info(f"Using device: {device}")
|
| 224 |
+
|
| 225 |
+
# Log GPU availability and CUDA info for debugging
|
| 226 |
+
if torch.cuda.is_available():
|
| 227 |
+
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
| 228 |
+
logger.info(f"GPU device count: {torch.cuda.device_count()}")
|
| 229 |
+
logger.info(f"Current GPU: {torch.cuda.get_device_name(0)}")
|
| 230 |
+
logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
|
| 231 |
+
logger.info("🚀 GPU acceleration enabled!")
|
| 232 |
+
else:
|
| 233 |
+
logger.info("CUDA not available, using CPU mode")
|
| 234 |
+
|
| 235 |
+
# Initialize agent first
|
| 236 |
+
num_actions = cfg.env.num_actions
|
| 237 |
+
agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
|
| 238 |
+
|
| 239 |
+
# Get spawn directory
|
| 240 |
+
spawn_dir = web_config.get_spawn_dir()
|
| 241 |
+
|
| 242 |
+
# Try to load checkpoint (remote first, then local, then dummy mode)
|
| 243 |
+
try:
|
| 244 |
+
# First try to load from Hugging Face Hub using torch.hub
|
| 245 |
+
logger.info("Loading model from Hugging Face Hub with torch.hub...")
|
| 246 |
+
|
| 247 |
+
success = await self._load_model_from_url_async(agent, device)
|
| 248 |
+
|
| 249 |
+
if success:
|
| 250 |
+
logger.info("Successfully loaded checkpoint from HF Hub")
|
| 251 |
+
else:
|
| 252 |
+
# Fallback to local checkpoint if available
|
| 253 |
+
logger.error("Failed to load from HF Hub! Check the detailed error above.")
|
| 254 |
+
checkpoint_path = web_config.get_checkpoint_path()
|
| 255 |
+
if checkpoint_path.exists():
|
| 256 |
+
logger.info(f"Loading local checkpoint: {checkpoint_path}")
|
| 257 |
+
self.loading_status = "Loading local checkpoint..."
|
| 258 |
+
agent.load(checkpoint_path)
|
| 259 |
+
logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
|
| 260 |
+
# Assume local checkpoint has actor_critic weights (may need verification)
|
| 261 |
+
self.actor_critic_loaded = True
|
| 262 |
+
else:
|
| 263 |
+
logger.error(f"No local checkpoint found at: {checkpoint_path}")
|
| 264 |
+
raise FileNotFoundError("No model checkpoint available (local or remote)")
|
| 265 |
+
|
| 266 |
+
except Exception as e:
|
| 267 |
+
logger.error(f"Failed to load any checkpoint: {e}")
|
| 268 |
+
self._init_dummy_mode()
|
| 269 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 270 |
+
return True
|
| 271 |
+
|
| 272 |
+
# Initialize world model environment
|
| 273 |
+
try:
|
| 274 |
+
sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
|
| 275 |
+
if agent.upsampler is not None:
|
| 276 |
+
sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
|
| 277 |
+
wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
|
| 278 |
+
wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
|
| 279 |
+
spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
|
| 280 |
+
|
| 281 |
+
# Create play environment
|
| 282 |
+
self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
|
| 283 |
+
|
| 284 |
+
# Verify actor-critic is loaded and ready for inference
|
| 285 |
+
if agent.actor_critic is not None and self.actor_critic_loaded:
|
| 286 |
+
logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
|
| 287 |
+
logger.info(f"Actor-critic device: {agent.actor_critic.device}")
|
| 288 |
+
# Force AI control for web demo
|
| 289 |
+
self.play_env.is_human_player = False
|
| 290 |
+
logger.info("WebPlayEnv set to AI control mode")
|
| 291 |
+
elif agent.actor_critic is not None and not self.actor_critic_loaded:
|
| 292 |
+
logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
|
| 293 |
+
self.play_env.is_human_player = True
|
| 294 |
+
logger.info("WebPlayEnv set to human control mode (no trained weights)")
|
| 295 |
+
else:
|
| 296 |
+
logger.warning("No actor-critic model found - AI inference will not work!")
|
| 297 |
+
self.play_env.is_human_player = True
|
| 298 |
+
logger.info("WebPlayEnv set to human control mode (fallback)")
|
| 299 |
+
|
| 300 |
+
# Enable model compilation for better performance (like standalone play.py)
|
| 301 |
+
# This gives 20-50% speedup but causes 10-30s delay on first inference
|
| 302 |
+
import os
|
| 303 |
+
enable_compile = device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1"
|
| 304 |
+
|
| 305 |
+
if enable_compile:
|
| 306 |
+
logger.info("🚀 Compiling models for faster inference (like standalone play.py)...")
|
| 307 |
+
logger.info("⏱️ First inference will take 10-30s, but subsequent inferences will be much faster!")
|
| 308 |
+
try:
|
| 309 |
+
wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
|
| 310 |
+
if wm_env.upsample_next_obs is not None:
|
| 311 |
+
wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
|
| 312 |
+
logger.info("✅ Model compilation enabled - expect 20-50% speedup!")
|
| 313 |
+
except Exception as e:
|
| 314 |
+
logger.warning(f"Model compilation failed: {e}")
|
| 315 |
+
enable_compile = False
|
| 316 |
+
|
| 317 |
+
if not enable_compile:
|
| 318 |
+
logger.info("Model compilation disabled. Set ENABLE_TORCH_COMPILE=1 for better performance.")
|
| 319 |
+
|
| 320 |
+
# Reset environment
|
| 321 |
+
self.obs, _ = self.play_env.reset()
|
| 322 |
+
self.cached_obs = self.obs # Initialize cache
|
| 323 |
+
|
| 324 |
+
logger.info("Models initialized successfully!")
|
| 325 |
+
logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
|
| 326 |
+
self.models_ready = True
|
| 327 |
+
self.loading_status = "Ready!"
|
| 328 |
+
return True
|
| 329 |
+
|
| 330 |
+
except Exception as e:
|
| 331 |
+
logger.error(f"Failed to initialize world model environment: {e}")
|
| 332 |
+
self._init_dummy_mode()
|
| 333 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 334 |
+
self.models_ready = True
|
| 335 |
+
self.loading_status = "Using dummy mode"
|
| 336 |
+
return True
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"Failed to initialize models: {e}")
|
| 340 |
+
import traceback
|
| 341 |
+
traceback.print_exc()
|
| 342 |
+
self._init_dummy_mode()
|
| 343 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 344 |
+
self.models_ready = True
|
| 345 |
+
self.loading_status = "Error - using dummy mode"
|
| 346 |
+
return True
|
| 347 |
+
|
| 348 |
+
def _init_dummy_mode(self):
|
| 349 |
+
"""Initialize dummy mode for testing without models"""
|
| 350 |
+
logger.info("Initializing dummy mode...")
|
| 351 |
+
|
| 352 |
+
# Create a test observation
|
| 353 |
+
height, width = 150, 600
|
| 354 |
+
img_array = np.zeros((height, width, 3), dtype=np.uint8)
|
| 355 |
+
|
| 356 |
+
# Add test pattern
|
| 357 |
+
for y in range(height):
|
| 358 |
+
for x in range(width):
|
| 359 |
+
img_array[y, x, 0] = (x % 256) # Red gradient
|
| 360 |
+
img_array[y, x, 1] = (y % 256) # Green gradient
|
| 361 |
+
img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
|
| 362 |
+
|
| 363 |
+
# Convert to torch tensor in expected format [-1, 1]
|
| 364 |
+
tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
|
| 365 |
+
tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
|
| 366 |
+
tensor = tensor.unsqueeze(0) # Add batch dimension
|
| 367 |
+
|
| 368 |
+
self.obs = tensor
|
| 369 |
+
self.play_env = None # No real environment in dummy mode
|
| 370 |
+
logger.info("Dummy mode initialized with test pattern")
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def step_environment(self):
|
| 374 |
+
"""Step the environment with current input state (with intelligent frame skipping)"""
|
| 375 |
+
if self.play_env is None:
|
| 376 |
+
# Dummy mode - just return current observation
|
| 377 |
+
return self.obs, 0.0, False, False, {"mode": "dummy"}
|
| 378 |
+
|
| 379 |
+
try:
|
| 380 |
+
# Check if reset is requested
|
| 381 |
+
if self.should_reset:
|
| 382 |
+
self.reset_environment()
|
| 383 |
+
self.should_reset = False
|
| 384 |
+
self.last_ai_time = self.time_module.time() # Reset AI timer
|
| 385 |
+
return self.obs, 0.0, False, False, {"reset": True}
|
| 386 |
+
|
| 387 |
+
# Intelligent frame skipping: only run AI inference at target FPS
|
| 388 |
+
current_time = self.time_module.time()
|
| 389 |
+
time_since_last_ai = current_time - self.last_ai_time
|
| 390 |
+
should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
|
| 391 |
+
|
| 392 |
+
if should_run_ai:
|
| 393 |
+
# Show loading indicator for first inference (can be slow)
|
| 394 |
+
if not self.first_inference_done:
|
| 395 |
+
logger.info("Running first AI inference (may take 5-15 seconds)...")
|
| 396 |
+
|
| 397 |
+
# Run AI inference
|
| 398 |
+
inference_start = self.time_module.time()
|
| 399 |
+
next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
|
| 400 |
+
pressed_keys=self.pressed_keys,
|
| 401 |
+
mouse_x=self.mouse_x,
|
| 402 |
+
mouse_y=self.mouse_y,
|
| 403 |
+
l_click=self.l_click,
|
| 404 |
+
r_click=self.r_click
|
| 405 |
+
)
|
| 406 |
+
inference_time = self.time_module.time() - inference_start
|
| 407 |
+
|
| 408 |
+
# Log first inference completion
|
| 409 |
+
if not self.first_inference_done:
|
| 410 |
+
self.first_inference_done = True
|
| 411 |
+
logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
|
| 412 |
+
|
| 413 |
+
# Cache the new observation and update timing
|
| 414 |
+
self.cached_obs = next_obs
|
| 415 |
+
self.last_ai_time = current_time
|
| 416 |
+
self.ai_frame_count += 1
|
| 417 |
+
|
| 418 |
+
# Add AI performance info
|
| 419 |
+
info = info or {}
|
| 420 |
+
info["ai_inference"] = True
|
| 421 |
+
|
| 422 |
+
# Calculate proper AI FPS: frames / elapsed time since start
|
| 423 |
+
elapsed_time = current_time - self.start_time
|
| 424 |
+
if elapsed_time > 0 and self.ai_frame_count > 0:
|
| 425 |
+
ai_fps = self.ai_frame_count / elapsed_time
|
| 426 |
+
# Cap at reasonable maximum (shouldn't exceed 100 FPS for AI inference)
|
| 427 |
+
info["ai_fps"] = min(ai_fps, 100.0)
|
| 428 |
+
else:
|
| 429 |
+
info["ai_fps"] = 0
|
| 430 |
+
|
| 431 |
+
info["inference_time"] = inference_time
|
| 432 |
+
|
| 433 |
+
return next_obs, reward, done, truncated, info
|
| 434 |
+
else:
|
| 435 |
+
# Use cached observation for smoother display without AI overhead
|
| 436 |
+
obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
|
| 437 |
+
|
| 438 |
+
# Calculate AI FPS for cached frames too
|
| 439 |
+
elapsed_time = current_time - self.start_time
|
| 440 |
+
if elapsed_time > 0 and self.ai_frame_count > 0:
|
| 441 |
+
ai_fps = min(self.ai_frame_count / elapsed_time, 100.0) # Cap at 100 FPS
|
| 442 |
+
else:
|
| 443 |
+
ai_fps = 0
|
| 444 |
+
|
| 445 |
+
return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}
|
| 446 |
+
|
| 447 |
+
except Exception as e:
|
| 448 |
+
logger.error(f"Error stepping environment: {e}")
|
| 449 |
+
obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
|
| 450 |
+
return obs_to_return, 0.0, False, False, {"error": str(e)}
|
| 451 |
+
|
| 452 |
+
def reset_environment(self):
|
| 453 |
+
"""Reset the environment"""
|
| 454 |
+
try:
|
| 455 |
+
if self.play_env is not None:
|
| 456 |
+
self.obs, _ = self.play_env.reset()
|
| 457 |
+
self.cached_obs = self.obs # Update cache
|
| 458 |
+
logger.info("Environment reset successfully")
|
| 459 |
+
else:
|
| 460 |
+
# Dummy mode - recreate test pattern
|
| 461 |
+
self._init_dummy_mode()
|
| 462 |
+
self.cached_obs = self.obs # Update cache
|
| 463 |
+
logger.info("Dummy environment reset")
|
| 464 |
+
except Exception as e:
|
| 465 |
+
logger.error(f"Error resetting environment: {e}")
|
| 466 |
+
|
| 467 |
+
def request_reset(self):
|
| 468 |
+
"""Request environment reset on next step"""
|
| 469 |
+
self.should_reset = True
|
| 470 |
+
logger.info("Environment reset requested")
|
| 471 |
+
|
| 472 |
+
def start_game(self):
|
| 473 |
+
"""Start the game"""
|
| 474 |
+
self.game_started = True
|
| 475 |
+
self.start_time = self.time_module.time() # Reset start time for FPS calculation
|
| 476 |
+
self.ai_frame_count = 0 # Reset AI frame count
|
| 477 |
+
logger.info("Game started")
|
| 478 |
+
|
| 479 |
+
def pause_game(self):
|
| 480 |
+
"""Pause/stop the game"""
|
| 481 |
+
self.game_started = False
|
| 482 |
+
logger.info("Game paused")
|
| 483 |
+
|
| 484 |
+
def obs_to_base64(self, obs: torch.Tensor) -> str:
|
| 485 |
+
"""Convert observation tensor to base64 image for web display (optimized for speed)"""
|
| 486 |
+
if obs is None:
|
| 487 |
+
return ""
|
| 488 |
+
|
| 489 |
+
try:
|
| 490 |
+
# Handle observation tensor conversion based on dimensions
|
| 491 |
+
if obs.ndim == 4 and obs.size(0) == 1:
|
| 492 |
+
# 4D tensor with batch dimension [1, C, H, W] -> [C, H, W]
|
| 493 |
+
img_tensor = obs[0]
|
| 494 |
+
elif obs.ndim == 3:
|
| 495 |
+
# 3D tensor [C, H, W]
|
| 496 |
+
img_tensor = obs
|
| 497 |
+
elif obs.ndim == 2:
|
| 498 |
+
# 2D tensor - likely an error, return empty string
|
| 499 |
+
logger.warning(f"Unexpected 2D observation tensor: {obs.shape}")
|
| 500 |
+
return ""
|
| 501 |
+
else:
|
| 502 |
+
logger.warning(f"Unexpected observation dimensions: {obs.shape}")
|
| 503 |
+
return ""
|
| 504 |
+
|
| 505 |
+
# Convert to numpy with proper range conversion
|
| 506 |
+
img_array = img_tensor.mul(127.5).add_(127.5).clamp_(0, 255).byte()
|
| 507 |
+
img_array = img_array.permute(1, 2, 0).cpu().numpy()
|
| 508 |
+
|
| 509 |
+
# Direct resize with OpenCV (much faster than PIL)
|
| 510 |
+
img_array = cv2.resize(img_array, (600, 150), interpolation=cv2.INTER_CUBIC)
|
| 511 |
+
|
| 512 |
+
# Note: img_array is already in RGB format from PyTorch tensor, no conversion needed
|
| 513 |
+
|
| 514 |
+
# Optimized JPEG encoding with OpenCV (faster than PIL)
|
| 515 |
+
success, buffer = cv2.imencode('.jpg', img_array, [cv2.IMWRITE_JPEG_QUALITY, 95])
|
| 516 |
+
if success:
|
| 517 |
+
img_str = base64.b64encode(buffer).decode()
|
| 518 |
+
return f"data:image/jpeg;base64,{img_str}"
|
| 519 |
+
else:
|
| 520 |
+
logger.warning("Frame encoding failed, using fallback")
|
| 521 |
+
return ""
|
| 522 |
+
|
| 523 |
+
except Exception as e:
|
| 524 |
+
logger.error(f"Error converting observation to base64: {e}")
|
| 525 |
+
return ""
|
| 526 |
+
|
| 527 |
+
async def game_loop(self):
|
| 528 |
+
"""Main game loop that runs continuously"""
|
| 529 |
+
self.running = True
|
| 530 |
+
|
| 531 |
+
while self.running:
|
| 532 |
+
try:
|
| 533 |
+
# Check if models are ready
|
| 534 |
+
if not self.models_ready:
|
| 535 |
+
# Send loading status to clients
|
| 536 |
+
if connected_clients:
|
| 537 |
+
loading_data = {
|
| 538 |
+
'type': 'loading',
|
| 539 |
+
'status': self.loading_status,
|
| 540 |
+
'progress': self.download_progress,
|
| 541 |
+
'ready': False
|
| 542 |
+
}
|
| 543 |
+
disconnected = set()
|
| 544 |
+
for client in connected_clients.copy():
|
| 545 |
+
try:
|
| 546 |
+
await client.send_text(json.dumps(loading_data))
|
| 547 |
+
except:
|
| 548 |
+
disconnected.add(client)
|
| 549 |
+
connected_clients.difference_update(disconnected)
|
| 550 |
+
|
| 551 |
+
await asyncio.sleep(0.5) # Check every 500ms during loading
|
| 552 |
+
continue
|
| 553 |
+
|
| 554 |
+
# Only step environment if game is started
|
| 555 |
+
if not self.game_started:
|
| 556 |
+
# Game not started - just send current observation without stepping
|
| 557 |
+
should_send_frame = True if (self.obs is not None and connected_clients) else False
|
| 558 |
+
# Don't modify self.obs when game isn't started!
|
| 559 |
+
await asyncio.sleep(0.1)
|
| 560 |
+
else:
|
| 561 |
+
# Game is started - step environment
|
| 562 |
+
should_send_frame = True
|
| 563 |
+
if self.play_env is None:
|
| 564 |
+
await asyncio.sleep(0.1)
|
| 565 |
+
continue
|
| 566 |
+
|
| 567 |
+
# Step environment with current input state
|
| 568 |
+
next_obs, reward, done, truncated, info = self.step_environment()
|
| 569 |
+
|
| 570 |
+
if done or truncated:
|
| 571 |
+
# Auto-reset when episode ends
|
| 572 |
+
self.reset_environment()
|
| 573 |
+
else:
|
| 574 |
+
self.obs = next_obs
|
| 575 |
+
|
| 576 |
+
# Send frame to all connected clients with smart throttling for performance
|
| 577 |
+
current_time = self.time_module.time()
|
| 578 |
+
time_since_last_frame_send = current_time - self.last_frame_send_time
|
| 579 |
+
should_send_web_frame = time_since_last_frame_send >= (1.0 / self.web_fps)
|
| 580 |
+
|
| 581 |
+
if should_send_frame and should_send_web_frame and connected_clients and self.obs is not None:
|
| 582 |
+
# Set default values for when game isn't running
|
| 583 |
+
if not self.game_started:
|
| 584 |
+
reward = 0.0
|
| 585 |
+
info = {"waiting": True}
|
| 586 |
+
# If game is started, reward and info should be set above
|
| 587 |
+
|
| 588 |
+
# Convert observation to base64
|
| 589 |
+
image_data = self.obs_to_base64(self.obs)
|
| 590 |
+
|
| 591 |
+
# Debug logging for first few frames
|
| 592 |
+
if self.frame_count < 5:
|
| 593 |
+
logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
|
| 594 |
+
f"image_data_length={len(image_data) if image_data else 0}, "
|
| 595 |
+
f"game_started={self.game_started}")
|
| 596 |
+
|
| 597 |
+
frame_data = {
|
| 598 |
+
'type': 'frame',
|
| 599 |
+
'image': image_data,
|
| 600 |
+
'frame_count': self.frame_count,
|
| 601 |
+
'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
|
| 602 |
+
'info': str(info) if info else "",
|
| 603 |
+
'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
|
| 604 |
+
'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False,
|
| 605 |
+
'web_fps': self.web_fps, # Add web FPS for monitoring
|
| 606 |
+
'ai_target_fps': self.ai_fps # Add target AI FPS for monitoring
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
# Send to all connected clients
|
| 610 |
+
disconnected = set()
|
| 611 |
+
for client in connected_clients.copy():
|
| 612 |
+
try:
|
| 613 |
+
await client.send_text(json.dumps(frame_data))
|
| 614 |
+
except:
|
| 615 |
+
disconnected.add(client)
|
| 616 |
+
|
| 617 |
+
# Remove disconnected clients
|
| 618 |
+
connected_clients.difference_update(disconnected)
|
| 619 |
+
|
| 620 |
+
# Update frame send timing
|
| 621 |
+
self.last_frame_send_time = current_time
|
| 622 |
+
|
| 623 |
+
self.frame_count += 1
|
| 624 |
+
await asyncio.sleep(1.0 / self.fps) # Control FPS
|
| 625 |
+
|
| 626 |
+
except Exception as e:
|
| 627 |
+
logger.error(f"Error in game loop: {e}")
|
| 628 |
+
await asyncio.sleep(0.1)
|
| 629 |
+
|
| 630 |
+
# Global game engine instance
|
| 631 |
+
game_engine = WebGameEngine()
|
| 632 |
+
|
| 633 |
+
@app.on_event("startup")
|
| 634 |
+
async def startup_event():
|
| 635 |
+
"""Initialize models when the app starts"""
|
| 636 |
+
# Start the game loop immediately (it will handle loading state)
|
| 637 |
+
asyncio.create_task(game_engine.game_loop())
|
| 638 |
+
|
| 639 |
+
# Initialize models in background (non-blocking)
|
| 640 |
+
asyncio.create_task(game_engine.initialize_models())
|
| 641 |
+
|
| 642 |
+
@app.get("/performance")
|
| 643 |
+
async def get_performance_stats():
|
| 644 |
+
"""Get current performance statistics"""
|
| 645 |
+
current_time = game_engine.time_module.time()
|
| 646 |
+
elapsed_time = current_time - game_engine.start_time if game_engine.start_time > 0 else 0
|
| 647 |
+
|
| 648 |
+
return {
|
| 649 |
+
"ai_fps_current": game_engine.ai_frame_count / elapsed_time if elapsed_time > 0 else 0,
|
| 650 |
+
"ai_fps_target": game_engine.ai_fps,
|
| 651 |
+
"web_fps_target": game_engine.web_fps,
|
| 652 |
+
"display_fps_target": game_engine.fps,
|
| 653 |
+
"models_ready": game_engine.models_ready,
|
| 654 |
+
"actor_critic_loaded": game_engine.actor_critic_loaded,
|
| 655 |
+
"game_started": game_engine.game_started,
|
| 656 |
+
"connected_clients": len(connected_clients),
|
| 657 |
+
"total_ai_frames": game_engine.ai_frame_count,
|
| 658 |
+
"total_display_frames": game_engine.frame_count,
|
| 659 |
+
"elapsed_time": elapsed_time,
|
| 660 |
+
"torch_compile_enabled": os.environ.get("ENABLE_TORCH_COMPILE", "1") == "1",
|
| 661 |
+
"device": "cuda" if torch.cuda.is_available() else "cpu"
|
| 662 |
+
}
|
| 663 |
+
|
| 664 |
+
@app.get("/", response_class=HTMLResponse)
|
| 665 |
+
async def get_homepage():
|
| 666 |
+
"""Serve the main game interface"""
|
| 667 |
+
html_content = """
|
| 668 |
+
<!DOCTYPE html>
|
| 669 |
+
<html>
|
| 670 |
+
<head>
|
| 671 |
+
<title>Physics-informed BEV World Model</title>
|
| 672 |
+
<style>
|
| 673 |
+
body {
|
| 674 |
+
margin: 0;
|
| 675 |
+
padding: 20px;
|
| 676 |
+
background: #1a1a1a;
|
| 677 |
+
color: white;
|
| 678 |
+
font-family: 'Courier New', monospace;
|
| 679 |
+
text-align: center;
|
| 680 |
+
}
|
| 681 |
+
#gameCanvas {
|
| 682 |
+
border: 2px solid #00ff00;
|
| 683 |
+
background: #000;
|
| 684 |
+
margin: 20px auto;
|
| 685 |
+
display: block;
|
| 686 |
+
}
|
| 687 |
+
#controls {
|
| 688 |
+
margin: 20px;
|
| 689 |
+
display: grid;
|
| 690 |
+
grid-template-columns: 1fr 1fr;
|
| 691 |
+
gap: 20px;
|
| 692 |
+
max-width: 800px;
|
| 693 |
+
margin: 20px auto;
|
| 694 |
+
}
|
| 695 |
+
.control-section {
|
| 696 |
+
background: #2a2a2a;
|
| 697 |
+
padding: 15px;
|
| 698 |
+
border-radius: 8px;
|
| 699 |
+
border: 1px solid #444;
|
| 700 |
+
}
|
| 701 |
+
.key-display {
|
| 702 |
+
background: #333;
|
| 703 |
+
border: 1px solid #555;
|
| 704 |
+
padding: 5px 10px;
|
| 705 |
+
margin: 2px;
|
| 706 |
+
border-radius: 4px;
|
| 707 |
+
display: inline-block;
|
| 708 |
+
min-width: 30px;
|
| 709 |
+
}
|
| 710 |
+
.key-pressed {
|
| 711 |
+
background: #00ff00;
|
| 712 |
+
color: #000;
|
| 713 |
+
}
|
| 714 |
+
#status {
|
| 715 |
+
margin: 10px;
|
| 716 |
+
padding: 10px;
|
| 717 |
+
background: #2a2a2a;
|
| 718 |
+
border-radius: 4px;
|
| 719 |
+
}
|
| 720 |
+
.info {
|
| 721 |
+
color: #00ff00;
|
| 722 |
+
margin: 5px 0;
|
| 723 |
+
}
|
| 724 |
+
</style>
|
| 725 |
+
</head>
|
| 726 |
+
<body>
|
| 727 |
+
<h1>🎮 Physics-informed BEV World Model</h1>
|
| 728 |
+
<p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
|
| 729 |
+
<p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
|
| 730 |
+
|
| 731 |
+
<!-- Model Download Progress -->
|
| 732 |
+
<div id="downloadSection" style="display: none; margin: 20px;">
|
| 733 |
+
<p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
|
| 734 |
+
<div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
|
| 735 |
+
<div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
|
| 736 |
+
</div>
|
| 737 |
+
<p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
|
| 738 |
+
</div>
|
| 739 |
+
|
| 740 |
+
<canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>
|
| 741 |
+
|
| 742 |
+
<div id="status">
|
| 743 |
+
<div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
|
| 744 |
+
<div class="info">Game: <span id="gameStatus">Click to Start</span></div>
|
| 745 |
+
<div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
|
| 746 |
+
<div class="info">Reward: <span id="reward">0</span></div>
|
| 747 |
+
</div>
|
| 748 |
+
|
| 749 |
+
<div id="controls">
|
| 750 |
+
<div class="control-section">
|
| 751 |
+
<h3>Movement</h3>
|
| 752 |
+
<div>
|
| 753 |
+
<span class="key-display" id="key-w">W</span> Forward<br>
|
| 754 |
+
<span class="key-display" id="key-a">A</span> Left
|
| 755 |
+
<span class="key-display" id="key-s">S</span> Back
|
| 756 |
+
<span class="key-display" id="key-d">D</span> Right<br>
|
| 757 |
+
<span class="key-display" id="key-space">Space</span> Jump
|
| 758 |
+
<span class="key-display" id="key-ctrl">Ctrl</span> Crouch
|
| 759 |
+
<span class="key-display" id="key-shift">Shift</span> Walk
|
| 760 |
+
</div>
|
| 761 |
+
</div>
|
| 762 |
+
|
| 763 |
+
<div class="control-section">
|
| 764 |
+
<h3>Actions</h3>
|
| 765 |
+
<div>
|
| 766 |
+
<span class="key-display" id="key-1">1</span> Weapon 1<br>
|
| 767 |
+
<span class="key-display" id="key-2">2</span> Weapon 2
|
| 768 |
+
<span class="key-display" id="key-3">3</span> Weapon 3<br>
|
| 769 |
+
<span class="key-display" id="key-r">R</span> Reload<br>
|
| 770 |
+
<span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
|
| 771 |
+
<span class="key-display" id="key-enter">Enter</span> Reset Game<br>
|
| 772 |
+
<span class="key-display" id="key-esc">Esc</span> Pause/Quit
|
| 773 |
+
</div>
|
| 774 |
+
</div>
|
| 775 |
+
</div>
|
| 776 |
+
|
| 777 |
+
<script>
|
| 778 |
+
const canvas = document.getElementById('gameCanvas');
|
| 779 |
+
const ctx = canvas.getContext('2d');
|
| 780 |
+
const statusEl = document.getElementById('connectionStatus');
|
| 781 |
+
const gameStatusEl = document.getElementById('gameStatus');
|
| 782 |
+
const frameEl = document.getElementById('frameCount');
|
| 783 |
+
const aiFpsEl = document.getElementById('aiFps');
|
| 784 |
+
const rewardEl = document.getElementById('reward');
|
| 785 |
+
const loadingEl = document.getElementById('loadingIndicator');
|
| 786 |
+
const downloadSectionEl = document.getElementById('downloadSection');
|
| 787 |
+
const downloadStatusEl = document.getElementById('downloadStatus');
|
| 788 |
+
const progressBarEl = document.getElementById('progressBar');
|
| 789 |
+
const progressTextEl = document.getElementById('progressText');
|
| 790 |
+
|
| 791 |
+
let ws = null;
|
| 792 |
+
let pressedKeys = new Set();
|
| 793 |
+
let gameStarted = false;
|
| 794 |
+
|
| 795 |
+
// Key mapping
|
| 796 |
+
const keyDisplayMap = {
|
| 797 |
+
'KeyW': 'key-w',
|
| 798 |
+
'KeyA': 'key-a',
|
| 799 |
+
'KeyS': 'key-s',
|
| 800 |
+
'KeyD': 'key-d',
|
| 801 |
+
'Space': 'key-space',
|
| 802 |
+
'ControlLeft': 'key-ctrl',
|
| 803 |
+
'ShiftLeft': 'key-shift',
|
| 804 |
+
'Digit1': 'key-1',
|
| 805 |
+
'Digit2': 'key-2',
|
| 806 |
+
'Digit3': 'key-3',
|
| 807 |
+
'KeyR': 'key-r',
|
| 808 |
+
'ArrowUp': 'key-arrows',
|
| 809 |
+
'ArrowDown': 'key-arrows',
|
| 810 |
+
'ArrowLeft': 'key-arrows',
|
| 811 |
+
'ArrowRight': 'key-arrows',
|
| 812 |
+
'Enter': 'key-enter',
|
| 813 |
+
'Escape': 'key-esc'
|
| 814 |
+
};
|
| 815 |
+
|
| 816 |
+
function connectWebSocket() {
|
| 817 |
+
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 818 |
+
const wsUrl = `${protocol}//${window.location.host}/ws`;
|
| 819 |
+
|
| 820 |
+
ws = new WebSocket(wsUrl);
|
| 821 |
+
|
| 822 |
+
ws.onopen = function(event) {
|
| 823 |
+
statusEl.textContent = 'Connected';
|
| 824 |
+
statusEl.style.color = '#00ff00';
|
| 825 |
+
// If user already clicked to start before WS was ready, send start now
|
| 826 |
+
if (gameStarted) {
|
| 827 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 828 |
+
}
|
| 829 |
+
};
|
| 830 |
+
|
| 831 |
+
ws.onmessage = function(event) {
|
| 832 |
+
const data = JSON.parse(event.data);
|
| 833 |
+
|
| 834 |
+
if (data.type === 'loading') {
|
| 835 |
+
// Handle loading status
|
| 836 |
+
downloadSectionEl.style.display = 'block';
|
| 837 |
+
downloadStatusEl.textContent = data.status;
|
| 838 |
+
|
| 839 |
+
if (data.progress !== undefined) {
|
| 840 |
+
progressBarEl.style.width = data.progress + '%';
|
| 841 |
+
progressTextEl.textContent = data.progress + '% - ' + data.status;
|
| 842 |
+
} else {
|
| 843 |
+
progressTextEl.textContent = data.status;
|
| 844 |
+
}
|
| 845 |
+
|
| 846 |
+
gameStatusEl.textContent = 'Loading Models...';
|
| 847 |
+
gameStatusEl.style.color = '#ffaa00';
|
| 848 |
+
|
| 849 |
+
} else if (data.type === 'frame') {
|
| 850 |
+
// Hide loading indicators once we get frames
|
| 851 |
+
downloadSectionEl.style.display = 'none';
|
| 852 |
+
// Update frame display
|
| 853 |
+
if (data.image) {
|
| 854 |
+
const img = new Image();
|
| 855 |
+
img.onload = function() {
|
| 856 |
+
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
| 857 |
+
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
| 858 |
+
};
|
| 859 |
+
img.src = data.image;
|
| 860 |
+
}
|
| 861 |
+
|
| 862 |
+
frameEl.textContent = data.frame_count;
|
| 863 |
+
rewardEl.textContent = data.reward.toFixed(2);
|
| 864 |
+
|
| 865 |
+
// Update AI FPS display and hide loading indicator once AI starts
|
| 866 |
+
if (data.ai_fps !== undefined && data.ai_fps !== null) {
|
| 867 |
+
// Ensure FPS value is reasonable
|
| 868 |
+
const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
|
| 869 |
+
aiFpsEl.textContent = aiFps.toFixed(1);
|
| 870 |
+
|
| 871 |
+
// Color code AI FPS for performance indication
|
| 872 |
+
if (aiFps >= 8) {
|
| 873 |
+
aiFpsEl.style.color = '#00ff00'; // Green for good performance
|
| 874 |
+
} else if (aiFps >= 5) {
|
| 875 |
+
aiFpsEl.style.color = '#ffff00'; // Yellow for moderate performance
|
| 876 |
+
} else if (aiFps > 0) {
|
| 877 |
+
aiFpsEl.style.color = '#ff0000'; // Red for poor performance
|
| 878 |
+
} else {
|
| 879 |
+
aiFpsEl.style.color = '#888888'; // Gray for inactive
|
| 880 |
+
}
|
| 881 |
+
|
| 882 |
+
// Hide loading indicator once AI inference starts working
|
| 883 |
+
if (aiFps > 0 && gameStarted) {
|
| 884 |
+
loadingEl.style.display = 'none';
|
| 885 |
+
gameStatusEl.textContent = 'Playing';
|
| 886 |
+
gameStatusEl.style.color = '#00ff00';
|
| 887 |
+
}
|
| 888 |
+
}
|
| 889 |
+
}
|
| 890 |
+
};
|
| 891 |
+
|
| 892 |
+
ws.onclose = function(event) {
|
| 893 |
+
statusEl.textContent = 'Disconnected';
|
| 894 |
+
statusEl.style.color = '#ff0000';
|
| 895 |
+
setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
|
| 896 |
+
};
|
| 897 |
+
|
| 898 |
+
ws.onerror = function(event) {
|
| 899 |
+
statusEl.textContent = 'Error';
|
| 900 |
+
statusEl.style.color = '#ff0000';
|
| 901 |
+
};
|
| 902 |
+
}
|
| 903 |
+
|
| 904 |
+
function sendKeyState() {
|
| 905 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 906 |
+
ws.send(JSON.stringify({
|
| 907 |
+
type: 'keys',
|
| 908 |
+
keys: Array.from(pressedKeys)
|
| 909 |
+
}));
|
| 910 |
+
}
|
| 911 |
+
}
|
| 912 |
+
|
| 913 |
+
function startGame() {
|
| 914 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 915 |
+
ws.send(JSON.stringify({
|
| 916 |
+
type: 'start'
|
| 917 |
+
}));
|
| 918 |
+
gameStarted = true;
|
| 919 |
+
gameStatusEl.textContent = 'Starting AI...';
|
| 920 |
+
gameStatusEl.style.color = '#ffff00';
|
| 921 |
+
loadingEl.style.display = 'block';
|
| 922 |
+
console.log('Game started');
|
| 923 |
+
}
|
| 924 |
+
}
|
| 925 |
+
|
| 926 |
+
function pauseGame() {
|
| 927 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 928 |
+
ws.send(JSON.stringify({
|
| 929 |
+
type: 'pause'
|
| 930 |
+
}));
|
| 931 |
+
gameStarted = false;
|
| 932 |
+
gameStatusEl.textContent = 'Paused - Click to Resume';
|
| 933 |
+
gameStatusEl.style.color = '#ffff00';
|
| 934 |
+
console.log('Game paused');
|
| 935 |
+
}
|
| 936 |
+
}
|
| 937 |
+
|
| 938 |
+
function updateKeyDisplay() {
|
| 939 |
+
// Reset all key displays
|
| 940 |
+
Object.values(keyDisplayMap).forEach(id => {
|
| 941 |
+
const el = document.getElementById(id);
|
| 942 |
+
if (el) el.classList.remove('key-pressed');
|
| 943 |
+
});
|
| 944 |
+
|
| 945 |
+
// Highlight pressed keys
|
| 946 |
+
pressedKeys.forEach(key => {
|
| 947 |
+
const displayId = keyDisplayMap[key];
|
| 948 |
+
if (displayId) {
|
| 949 |
+
const el = document.getElementById(displayId);
|
| 950 |
+
if (el) el.classList.add('key-pressed');
|
| 951 |
+
}
|
| 952 |
+
});
|
| 953 |
+
}
|
| 954 |
+
|
| 955 |
+
// Focus canvas and handle keyboard events
|
| 956 |
+
canvas.addEventListener('click', () => {
|
| 957 |
+
canvas.focus();
|
| 958 |
+
if (!gameStarted) {
|
| 959 |
+
// Queue start locally and send immediately if WS is open
|
| 960 |
+
gameStarted = true;
|
| 961 |
+
gameStatusEl.textContent = 'Starting AI...';
|
| 962 |
+
gameStatusEl.style.color = '#ffff00';
|
| 963 |
+
loadingEl.style.display = 'block';
|
| 964 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 965 |
+
ws.send(JSON.stringify({ type: 'start' }));
|
| 966 |
+
}
|
| 967 |
+
}
|
| 968 |
+
});
|
| 969 |
+
|
| 970 |
+
canvas.addEventListener('keydown', (event) => {
|
| 971 |
+
event.preventDefault();
|
| 972 |
+
|
| 973 |
+
// Handle special keys
|
| 974 |
+
if (event.code === 'Enter') {
|
| 975 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 976 |
+
ws.send(JSON.stringify({
|
| 977 |
+
type: 'reset'
|
| 978 |
+
}));
|
| 979 |
+
console.log('Environment reset requested');
|
| 980 |
+
}
|
| 981 |
+
// Add to pressedKeys for visual feedback
|
| 982 |
+
pressedKeys.add(event.code);
|
| 983 |
+
updateKeyDisplay();
|
| 984 |
+
|
| 985 |
+
// Remove Enter from pressedKeys after a short delay for visual feedback
|
| 986 |
+
setTimeout(() => {
|
| 987 |
+
pressedKeys.delete(event.code);
|
| 988 |
+
updateKeyDisplay();
|
| 989 |
+
}, 200);
|
| 990 |
+
} else if (event.code === 'Escape') {
|
| 991 |
+
pauseGame();
|
| 992 |
+
// Add to pressedKeys for visual feedback
|
| 993 |
+
pressedKeys.add(event.code);
|
| 994 |
+
updateKeyDisplay();
|
| 995 |
+
|
| 996 |
+
// Remove ESC from pressedKeys after a short delay for visual feedback
|
| 997 |
+
setTimeout(() => {
|
| 998 |
+
pressedKeys.delete(event.code);
|
| 999 |
+
updateKeyDisplay();
|
| 1000 |
+
}, 200);
|
| 1001 |
+
} else {
|
| 1002 |
+
// Only send game keys if game is started
|
| 1003 |
+
if (gameStarted) {
|
| 1004 |
+
pressedKeys.add(event.code);
|
| 1005 |
+
updateKeyDisplay();
|
| 1006 |
+
sendKeyState();
|
| 1007 |
+
}
|
| 1008 |
+
}
|
| 1009 |
+
});
|
| 1010 |
+
|
| 1011 |
+
canvas.addEventListener('keyup', (event) => {
|
| 1012 |
+
event.preventDefault();
|
| 1013 |
+
|
| 1014 |
+
// Don't handle special keys release (handled in keydown with timeout)
|
| 1015 |
+
if (event.code !== 'Enter' && event.code !== 'Escape') {
|
| 1016 |
+
if (gameStarted) {
|
| 1017 |
+
pressedKeys.delete(event.code);
|
| 1018 |
+
updateKeyDisplay();
|
| 1019 |
+
sendKeyState();
|
| 1020 |
+
}
|
| 1021 |
+
}
|
| 1022 |
+
});
|
| 1023 |
+
|
| 1024 |
+
// Handle mouse events for clicks
|
| 1025 |
+
canvas.addEventListener('mousedown', (event) => {
|
| 1026 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 1027 |
+
ws.send(JSON.stringify({
|
| 1028 |
+
type: 'mouse',
|
| 1029 |
+
button: event.button,
|
| 1030 |
+
action: 'down',
|
| 1031 |
+
x: event.offsetX,
|
| 1032 |
+
y: event.offsetY
|
| 1033 |
+
}));
|
| 1034 |
+
}
|
| 1035 |
+
});
|
| 1036 |
+
|
| 1037 |
+
canvas.addEventListener('mouseup', (event) => {
|
| 1038 |
+
if (ws && ws.readyState === WebSocket.OPEN) {
|
| 1039 |
+
ws.send(JSON.stringify({
|
| 1040 |
+
type: 'mouse',
|
| 1041 |
+
button: event.button,
|
| 1042 |
+
action: 'up',
|
| 1043 |
+
x: event.offsetX,
|
| 1044 |
+
y: event.offsetY
|
| 1045 |
+
}));
|
| 1046 |
+
}
|
| 1047 |
+
});
|
| 1048 |
+
|
| 1049 |
+
// Initialize
|
| 1050 |
+
connectWebSocket();
|
| 1051 |
+
canvas.focus();
|
| 1052 |
+
</script>
|
| 1053 |
+
</body>
|
| 1054 |
+
</html>
|
| 1055 |
+
"""
|
| 1056 |
+
return html_content
|
| 1057 |
+
|
| 1058 |
+
@app.websocket("/ws")
|
| 1059 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 1060 |
+
"""Handle WebSocket connections for real-time game communication"""
|
| 1061 |
+
await websocket.accept()
|
| 1062 |
+
connected_clients.add(websocket)
|
| 1063 |
+
|
| 1064 |
+
try:
|
| 1065 |
+
while True:
|
| 1066 |
+
# Receive messages from client
|
| 1067 |
+
data = await websocket.receive_text()
|
| 1068 |
+
message = json.loads(data)
|
| 1069 |
+
|
| 1070 |
+
if message['type'] == 'keys':
|
| 1071 |
+
# Update pressed keys
|
| 1072 |
+
game_engine.pressed_keys = set(message['keys'])
|
| 1073 |
+
|
| 1074 |
+
elif message['type'] == 'reset':
|
| 1075 |
+
# Handle environment reset request
|
| 1076 |
+
game_engine.request_reset()
|
| 1077 |
+
|
| 1078 |
+
elif message['type'] == 'start':
|
| 1079 |
+
# Handle game start request
|
| 1080 |
+
game_engine.start_game()
|
| 1081 |
+
|
| 1082 |
+
elif message['type'] == 'pause':
|
| 1083 |
+
# Handle game pause request
|
| 1084 |
+
game_engine.pause_game()
|
| 1085 |
+
|
| 1086 |
+
elif message['type'] == 'mouse':
|
| 1087 |
+
# Handle mouse events
|
| 1088 |
+
if message['action'] == 'down':
|
| 1089 |
+
if message['button'] == 0: # Left click
|
| 1090 |
+
game_engine.l_click = True
|
| 1091 |
+
elif message['button'] == 2: # Right click
|
| 1092 |
+
game_engine.r_click = True
|
| 1093 |
+
elif message['action'] == 'up':
|
| 1094 |
+
if message['button'] == 0: # Left click
|
| 1095 |
+
game_engine.l_click = False
|
| 1096 |
+
elif message['button'] == 2: # Right click
|
| 1097 |
+
game_engine.r_click = False
|
| 1098 |
+
|
| 1099 |
+
# Update mouse position (relative to canvas)
|
| 1100 |
+
game_engine.mouse_x = message.get('x', 0) - 300 # Center at 300px
|
| 1101 |
+
game_engine.mouse_y = message.get('y', 0) - 150 # Center at 150px
|
| 1102 |
+
|
| 1103 |
+
except WebSocketDisconnect:
|
| 1104 |
+
connected_clients.discard(websocket)
|
| 1105 |
+
except Exception as e:
|
| 1106 |
+
logger.error(f"WebSocket error: {e}")
|
| 1107 |
+
connected_clients.discard(websocket)
|
| 1108 |
+
|
| 1109 |
+
if __name__ == "__main__":
|
| 1110 |
+
# For local development
|
| 1111 |
+
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
|
| 1112 |
+
|
app_config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hugging Face Spaces configuration
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Hugging Face Spaces configuration
|
| 8 |
+
title = "Diamond CSGO AI Player"
|
| 9 |
+
description = """
|
| 10 |
+
# 🎮 Diamond CSGO AI Player
|
| 11 |
+
|
| 12 |
+
Experience an AI agent trained with diffusion models playing Counter-Strike: Global Offensive!
|
| 13 |
+
|
| 14 |
+
## How to Play
|
| 15 |
+
1. Click on the game canvas to focus it
|
| 16 |
+
2. Use keyboard controls:
|
| 17 |
+
- **WASD** - Movement
|
| 18 |
+
- **Space** - Jump
|
| 19 |
+
- **Ctrl** - Crouch
|
| 20 |
+
- **1,2,3** - Switch weapons
|
| 21 |
+
- **R** - Reload
|
| 22 |
+
- **Arrow Keys** - Camera movement
|
| 23 |
+
- **Mouse clicks** - Fire
|
| 24 |
+
3. Press **M** to switch between Human/AI control
|
| 25 |
+
|
| 26 |
+
## Technical Details
|
| 27 |
+
This demo showcases the Diamond framework, which combines:
|
| 28 |
+
- Diffusion models for world modeling
|
| 29 |
+
- Actor-critic reinforcement learning
|
| 30 |
+
- Multi-step planning in imagination
|
| 31 |
+
|
| 32 |
+
The AI learns to play by predicting future game states and optimizing actions through a learned world model.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
# App settings
|
| 36 |
+
port = int(os.getenv("PORT", 7860))
|
| 37 |
+
host = "0.0.0.0"
|
copy.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Essential files to copy to your HF Space directory:
|
| 2 |
+
cp app.py /home/alienware3/Documents/diamond-ai-player/
|
| 3 |
+
cp requirements.txt /home/alienware3/Documents/diamond-ai-player/
|
| 4 |
+
cp Dockerfile /home/alienware3/Documents/diamond-ai-player/
|
| 5 |
+
cp packages.txt /home/alienware3/Documents/diamond-ai-player/
|
| 6 |
+
cp README.md /home/alienware3/Documents/diamond-ai-player/
|
| 7 |
+
cp config_web.py /home/alienware3/Documents/diamond-ai-player/
|
| 8 |
+
|
| 9 |
+
# Copy entire directories
|
| 10 |
+
cp -r src/ /home/alienware3/Documents/diamond-ai-player/
|
| 11 |
+
cp -r config/ /home/alienware3/Documents/diamond-ai-player/
|
| 12 |
+
cp -r csgo/ /home/alienware3/Documents/diamond-ai-player/
|
| 13 |
+
|
| 14 |
+
# Copy your trained model (choose one)
|
| 15 |
+
# cp agent_epoch_00003.pt /home/alienware3/Documents/diamond-ai-player/ # OR
|
| 16 |
+
#cp agent_epoch_00206.pt /home/alienware3/Documents/diamond-ai-player/
|
| 17 |
+
|
| 18 |
+
cd /home/alienware3/Documents/diamond-ai-player/
|
| 19 |
+
git add .
|
| 20 |
+
git commit -m "Fix initial bugs"
|
| 21 |
+
git push origin main
|
debug_app.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Debug version of the web app to isolate the black screen issue
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import base64
|
| 8 |
+
import io
|
| 9 |
+
import json
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Optional, Set
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
| 19 |
+
from fastapi.responses import HTMLResponse
|
| 20 |
+
from PIL import Image
|
| 21 |
+
|
| 22 |
+
# Add src to Python path
|
| 23 |
+
src_path = Path(__file__).parent / "src"
|
| 24 |
+
sys.path.insert(0, str(src_path))
|
| 25 |
+
|
| 26 |
+
# Configure logging
|
| 27 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
# Simple debug app
|
| 31 |
+
app = FastAPI(title="Diamond CSGO AI Player - Debug")
|
| 32 |
+
|
| 33 |
+
class DebugGameEngine:
|
| 34 |
+
"""Simple debug version to test image rendering"""
|
| 35 |
+
|
| 36 |
+
def __init__(self):
|
| 37 |
+
self.obs = None
|
| 38 |
+
self.frame_count = 0
|
| 39 |
+
self.initialized = False
|
| 40 |
+
|
| 41 |
+
def create_test_observation(self):
|
| 42 |
+
"""Create a test observation for debugging"""
|
| 43 |
+
# Create a test image with some pattern
|
| 44 |
+
height, width = 150, 600
|
| 45 |
+
|
| 46 |
+
# Create RGB pattern
|
| 47 |
+
img_array = np.zeros((height, width, 3), dtype=np.uint8)
|
| 48 |
+
|
| 49 |
+
# Add some pattern to make it visible
|
| 50 |
+
for y in range(height):
|
| 51 |
+
for x in range(width):
|
| 52 |
+
img_array[y, x, 0] = (x % 256) # Red gradient
|
| 53 |
+
img_array[y, x, 1] = (y % 256) # Green gradient
|
| 54 |
+
img_array[y, x, 2] = ((x + y) % 256) # Blue pattern
|
| 55 |
+
|
| 56 |
+
# Convert to torch tensor in the expected format [-1, 1]
|
| 57 |
+
tensor = torch.from_numpy(img_array).float().permute(2, 0, 1) # CHW format
|
| 58 |
+
tensor = tensor.div(255).mul(2).sub(1) # Convert to [-1, 1] range
|
| 59 |
+
tensor = tensor.unsqueeze(0) # Add batch dimension
|
| 60 |
+
|
| 61 |
+
logger.info(f"Created test observation: {tensor.shape}, range: {tensor.min():.3f} to {tensor.max():.3f}")
|
| 62 |
+
return tensor
|
| 63 |
+
|
| 64 |
+
def obs_to_base64(self, obs: torch.Tensor) -> str:
|
| 65 |
+
"""Convert observation tensor to base64 image for web display"""
|
| 66 |
+
if obs is None:
|
| 67 |
+
logger.warning("Observation is None")
|
| 68 |
+
return ""
|
| 69 |
+
|
| 70 |
+
try:
|
| 71 |
+
logger.debug(f"Converting obs: shape={obs.shape}, dtype={obs.dtype}")
|
| 72 |
+
|
| 73 |
+
# Convert tensor to PIL Image
|
| 74 |
+
if obs.ndim == 4 and obs.size(0) == 1:
|
| 75 |
+
img_array = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
|
| 76 |
+
else:
|
| 77 |
+
img_array = obs.add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
|
| 78 |
+
|
| 79 |
+
logger.debug(f"Image array: shape={img_array.shape}, range={img_array.min()} to {img_array.max()}")
|
| 80 |
+
|
| 81 |
+
img = Image.fromarray(img_array)
|
| 82 |
+
|
| 83 |
+
# Resize for web display
|
| 84 |
+
img = img.resize((600, 300), Image.BICUBIC)
|
| 85 |
+
|
| 86 |
+
# Convert to base64
|
| 87 |
+
buffer = io.BytesIO()
|
| 88 |
+
img.save(buffer, format='PNG')
|
| 89 |
+
img_str = base64.b64encode(buffer.getvalue()).decode()
|
| 90 |
+
|
| 91 |
+
logger.debug(f"Successfully converted to base64, length: {len(img_str)}")
|
| 92 |
+
return f"data:image/png;base64,{img_str}"
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error(f"Error converting observation to base64: {e}")
|
| 96 |
+
import traceback
|
| 97 |
+
traceback.print_exc()
|
| 98 |
+
return ""
|
| 99 |
+
|
| 100 |
+
async def initialize(self):
|
| 101 |
+
"""Initialize with test data"""
|
| 102 |
+
logger.info("Initializing debug game engine...")
|
| 103 |
+
self.obs = self.create_test_observation()
|
| 104 |
+
self.initialized = True
|
| 105 |
+
logger.info("Debug game engine initialized successfully!")
|
| 106 |
+
return True
|
| 107 |
+
|
| 108 |
+
# Global debug engine
|
| 109 |
+
debug_engine = DebugGameEngine()
|
| 110 |
+
connected_clients: Set[WebSocket] = set()
|
| 111 |
+
|
| 112 |
+
@app.on_event("startup")
|
| 113 |
+
async def startup_event():
|
| 114 |
+
"""Initialize debug engine"""
|
| 115 |
+
success = await debug_engine.initialize()
|
| 116 |
+
if success:
|
| 117 |
+
# Start a simple game loop
|
| 118 |
+
asyncio.create_task(debug_game_loop())
|
| 119 |
+
|
| 120 |
+
async def debug_game_loop():
|
| 121 |
+
"""Simple debug game loop"""
|
| 122 |
+
while True:
|
| 123 |
+
try:
|
| 124 |
+
if debug_engine.initialized and connected_clients:
|
| 125 |
+
# Send test frame to all connected clients
|
| 126 |
+
frame_data = {
|
| 127 |
+
'type': 'frame',
|
| 128 |
+
'image': debug_engine.obs_to_base64(debug_engine.obs),
|
| 129 |
+
'frame_count': debug_engine.frame_count,
|
| 130 |
+
'reward': 0.0,
|
| 131 |
+
'info': f"Debug frame {debug_engine.frame_count}"
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
logger.debug(f"Sending frame {debug_engine.frame_count}")
|
| 135 |
+
|
| 136 |
+
# Send to all connected clients
|
| 137 |
+
disconnected = set()
|
| 138 |
+
for client in connected_clients.copy():
|
| 139 |
+
try:
|
| 140 |
+
await client.send_text(json.dumps(frame_data))
|
| 141 |
+
except Exception as e:
|
| 142 |
+
logger.error(f"Error sending to client: {e}")
|
| 143 |
+
disconnected.add(client)
|
| 144 |
+
|
| 145 |
+
# Remove disconnected clients
|
| 146 |
+
connected_clients.difference_update(disconnected)
|
| 147 |
+
|
| 148 |
+
debug_engine.frame_count += 1
|
| 149 |
+
|
| 150 |
+
await asyncio.sleep(1.0 / 15) # 15 FPS
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
logger.error(f"Error in debug game loop: {e}")
|
| 154 |
+
await asyncio.sleep(0.1)
|
| 155 |
+
|
| 156 |
+
@app.get("/", response_class=HTMLResponse)
|
| 157 |
+
async def get_homepage():
|
| 158 |
+
"""Serve debug interface"""
|
| 159 |
+
html_content = """
|
| 160 |
+
<!DOCTYPE html>
|
| 161 |
+
<html>
|
| 162 |
+
<head>
|
| 163 |
+
<title>Diamond CSGO AI - Debug</title>
|
| 164 |
+
<style>
|
| 165 |
+
body {
|
| 166 |
+
margin: 0;
|
| 167 |
+
padding: 20px;
|
| 168 |
+
background: #1a1a1a;
|
| 169 |
+
color: white;
|
| 170 |
+
font-family: monospace;
|
| 171 |
+
text-align: center;
|
| 172 |
+
}
|
| 173 |
+
#gameCanvas {
|
| 174 |
+
border: 2px solid #00ff00;
|
| 175 |
+
background: #000;
|
| 176 |
+
margin: 20px auto;
|
| 177 |
+
display: block;
|
| 178 |
+
}
|
| 179 |
+
#status {
|
| 180 |
+
margin: 10px;
|
| 181 |
+
padding: 10px;
|
| 182 |
+
background: #2a2a2a;
|
| 183 |
+
border-radius: 4px;
|
| 184 |
+
}
|
| 185 |
+
.info {
|
| 186 |
+
color: #00ff00;
|
| 187 |
+
margin: 5px 0;
|
| 188 |
+
}
|
| 189 |
+
</style>
|
| 190 |
+
</head>
|
| 191 |
+
<body>
|
| 192 |
+
<h1>🔧 Diamond CSGO AI - Debug Mode</h1>
|
| 193 |
+
<p>Testing image rendering and WebSocket communication</p>
|
| 194 |
+
|
| 195 |
+
<canvas id="gameCanvas" width="600" height="300"></canvas>
|
| 196 |
+
|
| 197 |
+
<div id="status">
|
| 198 |
+
<div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
|
| 199 |
+
<div class="info">Frame: <span id="frameCount">0</span></div>
|
| 200 |
+
<div class="info">Info: <span id="info">Waiting...</span></div>
|
| 201 |
+
</div>
|
| 202 |
+
|
| 203 |
+
<div id="debug">
|
| 204 |
+
<h3>Debug Information</h3>
|
| 205 |
+
<p id="lastUpdate">No updates yet</p>
|
| 206 |
+
<p id="imageInfo">No image data</p>
|
| 207 |
+
</div>
|
| 208 |
+
|
| 209 |
+
<script>
|
| 210 |
+
const canvas = document.getElementById('gameCanvas');
|
| 211 |
+
const ctx = canvas.getContext('2d');
|
| 212 |
+
const statusEl = document.getElementById('connectionStatus');
|
| 213 |
+
const frameEl = document.getElementById('frameCount');
|
| 214 |
+
const infoEl = document.getElementById('info');
|
| 215 |
+
const lastUpdateEl = document.getElementById('lastUpdate');
|
| 216 |
+
const imageInfoEl = document.getElementById('imageInfo');
|
| 217 |
+
|
| 218 |
+
let ws = null;
|
| 219 |
+
|
| 220 |
+
function connectWebSocket() {
|
| 221 |
+
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
| 222 |
+
const wsUrl = `${protocol}//${window.location.host}/ws`;
|
| 223 |
+
console.log('Connecting to:', wsUrl);
|
| 224 |
+
|
| 225 |
+
ws = new WebSocket(wsUrl);
|
| 226 |
+
|
| 227 |
+
ws.onopen = function(event) {
|
| 228 |
+
console.log('WebSocket connected');
|
| 229 |
+
statusEl.textContent = 'Connected';
|
| 230 |
+
statusEl.style.color = '#00ff00';
|
| 231 |
+
};
|
| 232 |
+
|
| 233 |
+
ws.onmessage = function(event) {
|
| 234 |
+
console.log('Received message');
|
| 235 |
+
const data = JSON.parse(event.data);
|
| 236 |
+
|
| 237 |
+
if (data.type === 'frame') {
|
| 238 |
+
console.log('Received frame:', data.frame_count);
|
| 239 |
+
lastUpdateEl.textContent = `Last update: ${new Date().toLocaleTimeString()}`;
|
| 240 |
+
|
| 241 |
+
// Update frame display
|
| 242 |
+
if (data.image && data.image.length > 0) {
|
| 243 |
+
console.log('Loading image, length:', data.image.length);
|
| 244 |
+
imageInfoEl.textContent = `Image data length: ${data.image.length}`;
|
| 245 |
+
|
| 246 |
+
const img = new Image();
|
| 247 |
+
img.onload = function() {
|
| 248 |
+
console.log('Image loaded successfully');
|
| 249 |
+
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
| 250 |
+
ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
|
| 251 |
+
};
|
| 252 |
+
img.onerror = function() {
|
| 253 |
+
console.error('Failed to load image');
|
| 254 |
+
imageInfoEl.textContent = 'Failed to load image';
|
| 255 |
+
};
|
| 256 |
+
img.src = data.image;
|
| 257 |
+
} else {
|
| 258 |
+
console.warn('No image data received');
|
| 259 |
+
imageInfoEl.textContent = 'No image data received';
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
frameEl.textContent = data.frame_count;
|
| 263 |
+
infoEl.textContent = data.info || 'No info';
|
| 264 |
+
}
|
| 265 |
+
};
|
| 266 |
+
|
| 267 |
+
ws.onclose = function(event) {
|
| 268 |
+
console.log('WebSocket disconnected');
|
| 269 |
+
statusEl.textContent = 'Disconnected';
|
| 270 |
+
statusEl.style.color = '#ff0000';
|
| 271 |
+
setTimeout(connectWebSocket, 1000);
|
| 272 |
+
};
|
| 273 |
+
|
| 274 |
+
ws.onerror = function(event) {
|
| 275 |
+
console.error('WebSocket error:', event);
|
| 276 |
+
statusEl.textContent = 'Error';
|
| 277 |
+
statusEl.style.color = '#ff0000';
|
| 278 |
+
};
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
// Initialize
|
| 282 |
+
connectWebSocket();
|
| 283 |
+
</script>
|
| 284 |
+
</body>
|
| 285 |
+
</html>
|
| 286 |
+
"""
|
| 287 |
+
return html_content
|
| 288 |
+
|
| 289 |
+
@app.websocket("/ws")
|
| 290 |
+
async def websocket_endpoint(websocket: WebSocket):
|
| 291 |
+
"""Handle WebSocket connections"""
|
| 292 |
+
await websocket.accept()
|
| 293 |
+
connected_clients.add(websocket)
|
| 294 |
+
logger.info(f"Client connected. Total clients: {len(connected_clients)}")
|
| 295 |
+
|
| 296 |
+
try:
|
| 297 |
+
while True:
|
| 298 |
+
# Just wait for disconnection
|
| 299 |
+
await websocket.receive_text()
|
| 300 |
+
except WebSocketDisconnect:
|
| 301 |
+
connected_clients.discard(websocket)
|
| 302 |
+
logger.info(f"Client disconnected. Total clients: {len(connected_clients)}")
|
| 303 |
+
except Exception as e:
|
| 304 |
+
logger.error(f"WebSocket error: {e}")
|
| 305 |
+
connected_clients.discard(websocket)
|
| 306 |
+
|
| 307 |
+
if __name__ == "__main__":
|
| 308 |
+
import uvicorn
|
| 309 |
+
uvicorn.run("debug_app:app", host="0.0.0.0", port=7861, reload=False)
|
download_models.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to download models from Hugging Face Hub if not present locally
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
def download_checkpoint_if_needed():
|
| 14 |
+
"""Download model checkpoint if not present locally"""
|
| 15 |
+
# Check if we have any local checkpoints
|
| 16 |
+
possible_checkpoints = [
|
| 17 |
+
Path("agent_epoch_00206.pt"),
|
| 18 |
+
Path("agent_epoch_00003.pt"),
|
| 19 |
+
Path("checkpoints/agent_epoch_00206.pt"),
|
| 20 |
+
Path("checkpoints/agent_epoch_00003.pt"),
|
| 21 |
+
]
|
| 22 |
+
|
| 23 |
+
for ckpt_path in possible_checkpoints:
|
| 24 |
+
if ckpt_path.exists():
|
| 25 |
+
logger.info(f"Found local checkpoint: {ckpt_path}")
|
| 26 |
+
return True
|
| 27 |
+
|
| 28 |
+
logger.info("No local checkpoint found, attempting to download from Hugging Face Hub...")
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
from huggingface_hub import hf_hub_download
|
| 32 |
+
|
| 33 |
+
# This would download from a hypothetical HF model repository
|
| 34 |
+
# You would need to upload your models to HF Hub first
|
| 35 |
+
# Example:
|
| 36 |
+
# checkpoint_path = hf_hub_download(
|
| 37 |
+
# repo_id="your-username/diamond-csgo-model",
|
| 38 |
+
# filename="agent_epoch_00206.pt",
|
| 39 |
+
# cache_dir="./checkpoints"
|
| 40 |
+
# )
|
| 41 |
+
|
| 42 |
+
logger.warning("Model download not implemented yet.")
|
| 43 |
+
logger.warning("Please ensure you have model checkpoints available locally.")
|
| 44 |
+
return False
|
| 45 |
+
|
| 46 |
+
except ImportError:
|
| 47 |
+
logger.error("huggingface_hub not installed. Cannot download models.")
|
| 48 |
+
return False
|
| 49 |
+
except Exception as e:
|
| 50 |
+
logger.error(f"Failed to download models: {e}")
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
def setup_demo_data():
|
| 54 |
+
"""Set up minimal demo data if models are not available"""
|
| 55 |
+
spawn_dir = Path("csgo/spawn/0")
|
| 56 |
+
spawn_dir.mkdir(parents=True, exist_ok=True)
|
| 57 |
+
|
| 58 |
+
# Create minimal dummy files for demo
|
| 59 |
+
import numpy as np
|
| 60 |
+
import json
|
| 61 |
+
|
| 62 |
+
files_to_create = {
|
| 63 |
+
"act.npy": np.zeros((100, 51)), # 100 timesteps, 51 actions
|
| 64 |
+
"low_res.npy": np.zeros((100, 3, 150, 600)), # 100 frames
|
| 65 |
+
"full_res.npy": np.zeros((100, 3, 300, 1200)), # 100 high-res frames
|
| 66 |
+
"next_act.npy": np.zeros((100, 51)),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
for filename, data in files_to_create.items():
|
| 70 |
+
file_path = spawn_dir / filename
|
| 71 |
+
if not file_path.exists():
|
| 72 |
+
np.save(file_path, data)
|
| 73 |
+
logger.info(f"Created dummy file: {file_path}")
|
| 74 |
+
|
| 75 |
+
# Create info.json
|
| 76 |
+
info_path = spawn_dir / "info.json"
|
| 77 |
+
if not info_path.exists():
|
| 78 |
+
info_data = {
|
| 79 |
+
"episode_length": 100,
|
| 80 |
+
"total_reward": 0.0,
|
| 81 |
+
"demo": True
|
| 82 |
+
}
|
| 83 |
+
with open(info_path, 'w') as f:
|
| 84 |
+
json.dump(info_data, f)
|
| 85 |
+
logger.info(f"Created info file: {info_path}")
|
| 86 |
+
|
| 87 |
+
if __name__ == "__main__":
|
| 88 |
+
logger.info("Setting up Diamond CSGO demo...")
|
| 89 |
+
|
| 90 |
+
# Try to download models
|
| 91 |
+
has_models = download_checkpoint_if_needed()
|
| 92 |
+
|
| 93 |
+
# Set up demo data
|
| 94 |
+
setup_demo_data()
|
| 95 |
+
|
| 96 |
+
if not has_models:
|
| 97 |
+
logger.warning("Running in demo mode without trained models.")
|
| 98 |
+
logger.warning("The AI agent will not function properly without model checkpoints.")
|
| 99 |
+
|
| 100 |
+
logger.info("Setup complete!")
|
low_res.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50e10ae563ce86e98c58de99c2d3a0a4c0f111f6fde2c87323bd1d5eb5210d1a
|
| 3 |
+
size 20288
|
npy_reader.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
def read_npy_file(file_path):
|
| 6 |
+
"""
|
| 7 |
+
Read a .npy file and return the numpy array.
|
| 8 |
+
|
| 9 |
+
Args:
|
| 10 |
+
file_path (str): Path to the .npy file
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
numpy.ndarray: The loaded numpy array
|
| 14 |
+
|
| 15 |
+
Raises:
|
| 16 |
+
FileNotFoundError: If the file doesn't exist
|
| 17 |
+
ValueError: If the file is not a valid .npy file
|
| 18 |
+
"""
|
| 19 |
+
try:
|
| 20 |
+
# Check if file exists
|
| 21 |
+
if not os.path.exists(file_path):
|
| 22 |
+
raise FileNotFoundError(f"File not found: {file_path}")
|
| 23 |
+
|
| 24 |
+
# Check if file has .npy extension
|
| 25 |
+
if not file_path.lower().endswith('.npy'):
|
| 26 |
+
print(f"Warning: File {file_path} doesn't have .npy extension")
|
| 27 |
+
|
| 28 |
+
# Load the numpy array
|
| 29 |
+
array = np.load(file_path)
|
| 30 |
+
|
| 31 |
+
print(f"Successfully loaded {file_path}")
|
| 32 |
+
print(f"Array shape: {array.shape}")
|
| 33 |
+
print(f"Array dtype: {array.dtype}")
|
| 34 |
+
print(f"Array size: {array.size}")
|
| 35 |
+
|
| 36 |
+
return array
|
| 37 |
+
|
| 38 |
+
except Exception as e:
|
| 39 |
+
raise ValueError(f"Error reading {file_path}: {str(e)}")
|
| 40 |
+
|
| 41 |
+
def print_array_info(array, show_data=False, max_elements=10):
|
| 42 |
+
"""
|
| 43 |
+
Print detailed information about a numpy array.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
array (numpy.ndarray): The numpy array to analyze
|
| 47 |
+
show_data (bool): Whether to show actual data values
|
| 48 |
+
max_elements (int): Maximum number of elements to display
|
| 49 |
+
"""
|
| 50 |
+
print("\n" + "="*50)
|
| 51 |
+
print("ARRAY INFORMATION")
|
| 52 |
+
print("="*50)
|
| 53 |
+
print(f"Shape: {array.shape}")
|
| 54 |
+
print(f"Dtype: {array.dtype}")
|
| 55 |
+
print(f"Size: {array.size}")
|
| 56 |
+
print(f"Dimensions: {array.ndim}")
|
| 57 |
+
print(f"Memory usage: {array.nbytes} bytes")
|
| 58 |
+
|
| 59 |
+
if array.size > 0:
|
| 60 |
+
print(f"Min value: {np.min(array)}")
|
| 61 |
+
print(f"Max value: {np.max(array)}")
|
| 62 |
+
print(f"Mean value: {np.mean(array)}")
|
| 63 |
+
print(f"Std deviation: {np.std(array)}")
|
| 64 |
+
|
| 65 |
+
if show_data:
|
| 66 |
+
print("\nArray data:")
|
| 67 |
+
if array.size <= max_elements:
|
| 68 |
+
print(array)
|
| 69 |
+
else:
|
| 70 |
+
print(f"First {max_elements} elements:")
|
| 71 |
+
print(array.flat[:max_elements])
|
| 72 |
+
print("...")
|
| 73 |
+
|
| 74 |
+
def main():
|
| 75 |
+
"""
|
| 76 |
+
Main function to demonstrate usage.
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
file_path = "/home/alienware3/Documents/diamond/low_res.npy"
|
| 80 |
+
|
| 81 |
+
try:
|
| 82 |
+
# Read the NPY file
|
| 83 |
+
array = read_npy_file(file_path)
|
| 84 |
+
|
| 85 |
+
# Print detailed information
|
| 86 |
+
print_array_info(array, show_data=True, max_elements=20)
|
| 87 |
+
|
| 88 |
+
except Exception as e:
|
| 89 |
+
print(f"Error: {e}")
|
| 90 |
+
sys.exit(1)
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
main()
|
run_web_demo.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to run the web demo locally for testing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add project root to Python path
|
| 11 |
+
root_path = Path(__file__).parent
|
| 12 |
+
sys.path.insert(0, str(root_path))
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
level=logging.INFO,
|
| 17 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 18 |
+
)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
def main():
|
| 22 |
+
"""Run the web demo"""
|
| 23 |
+
logger.info("Starting Diamond CSGO AI Player web demo...")
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import uvicorn
|
| 27 |
+
from app import app
|
| 28 |
+
|
| 29 |
+
# Run the server
|
| 30 |
+
uvicorn.run(
|
| 31 |
+
app,
|
| 32 |
+
host="0.0.0.0",
|
| 33 |
+
port=7860,
|
| 34 |
+
log_level="info",
|
| 35 |
+
reload=False # Disable reload for production
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
except ImportError as e:
|
| 39 |
+
logger.error(f"Missing dependencies: {e}")
|
| 40 |
+
logger.error("Please install requirements: pip install -r requirements.txt")
|
| 41 |
+
sys.exit(1)
|
| 42 |
+
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.error(f"Failed to start server: {e}")
|
| 45 |
+
sys.exit(1)
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
|
| 48 |
+
main()
|
test_web_app.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script to verify the web app can be imported and basic functionality works
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
# Add src to Python path
|
| 11 |
+
src_path = Path(__file__).parent / "src"
|
| 12 |
+
sys.path.insert(0, str(src_path))
|
| 13 |
+
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
def test_imports():
|
| 18 |
+
"""Test if all required modules can be imported"""
|
| 19 |
+
logger.info("Testing imports...")
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
# Test core modules
|
| 23 |
+
from config_web import web_config
|
| 24 |
+
logger.info("Web config imported")
|
| 25 |
+
|
| 26 |
+
from src.csgo.web_action_processing import WebCSGOAction
|
| 27 |
+
logger.info("Web action processing imported")
|
| 28 |
+
|
| 29 |
+
from src.game.web_play_env import WebPlayEnv
|
| 30 |
+
logger.info("Web play environment imported")
|
| 31 |
+
|
| 32 |
+
# Test web framework
|
| 33 |
+
import fastapi
|
| 34 |
+
import uvicorn
|
| 35 |
+
logger.info("Web framework imported")
|
| 36 |
+
|
| 37 |
+
# Test app
|
| 38 |
+
from app import app, WebGameEngine
|
| 39 |
+
logger.info("Main app imported")
|
| 40 |
+
|
| 41 |
+
return True
|
| 42 |
+
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.error(f"Import failed: {e}")
|
| 45 |
+
return False
|
| 46 |
+
|
| 47 |
+
def test_config():
|
| 48 |
+
"""Test configuration setup"""
|
| 49 |
+
logger.info("Testing configuration...")
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
from config_web import web_config
|
| 53 |
+
|
| 54 |
+
# Test path resolution
|
| 55 |
+
config_path = web_config.get_config_path()
|
| 56 |
+
logger.info(f"Config path: {config_path}")
|
| 57 |
+
|
| 58 |
+
spawn_dir = web_config.get_spawn_dir()
|
| 59 |
+
logger.info(f"Spawn directory: {spawn_dir}")
|
| 60 |
+
|
| 61 |
+
checkpoint_path = web_config.get_checkpoint_path()
|
| 62 |
+
logger.info(f"Checkpoint path: {checkpoint_path}")
|
| 63 |
+
|
| 64 |
+
return True
|
| 65 |
+
|
| 66 |
+
except Exception as e:
|
| 67 |
+
logger.error(f"Configuration test failed: {e}")
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
def test_action_processing():
|
| 71 |
+
"""Test web action processing"""
|
| 72 |
+
logger.info("Testing action processing...")
|
| 73 |
+
|
| 74 |
+
try:
|
| 75 |
+
from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
|
| 76 |
+
|
| 77 |
+
# Test key mapping
|
| 78 |
+
test_keys = {'KeyW', 'KeyA', 'Space'}
|
| 79 |
+
action_names = web_keys_to_csgo_action_names(test_keys)
|
| 80 |
+
logger.info(f"Key mapping: {test_keys} -> {action_names}")
|
| 81 |
+
|
| 82 |
+
# Test action creation
|
| 83 |
+
action = WebCSGOAction(
|
| 84 |
+
key_names=action_names,
|
| 85 |
+
mouse_x=10,
|
| 86 |
+
mouse_y=5,
|
| 87 |
+
l_click=False,
|
| 88 |
+
r_click=False
|
| 89 |
+
)
|
| 90 |
+
logger.info(f"Action created: {action}")
|
| 91 |
+
|
| 92 |
+
return True
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
logger.error(f"Action processing test failed: {e}")
|
| 96 |
+
return False
|
| 97 |
+
|
| 98 |
+
def main():
|
| 99 |
+
"""Run all tests"""
|
| 100 |
+
logger.info("Starting Diamond CSGO web app tests...")
|
| 101 |
+
|
| 102 |
+
tests = [
|
| 103 |
+
("Imports", test_imports),
|
| 104 |
+
("Configuration", test_config),
|
| 105 |
+
("Action Processing", test_action_processing),
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
passed = 0
|
| 109 |
+
total = len(tests)
|
| 110 |
+
|
| 111 |
+
for name, test_func in tests:
|
| 112 |
+
logger.info(f"\n--- Testing {name} ---")
|
| 113 |
+
if test_func():
|
| 114 |
+
logger.info(f"PASS: {name} test passed")
|
| 115 |
+
passed += 1
|
| 116 |
+
else:
|
| 117 |
+
logger.error(f"FAIL: {name} test failed")
|
| 118 |
+
|
| 119 |
+
logger.info(f"\n=== Test Results ===")
|
| 120 |
+
logger.info(f"Passed: {passed}/{total}")
|
| 121 |
+
|
| 122 |
+
if passed == total:
|
| 123 |
+
logger.info("All tests passed! The web app should work correctly.")
|
| 124 |
+
return True
|
| 125 |
+
else:
|
| 126 |
+
logger.error("Some tests failed. Please check the errors above.")
|
| 127 |
+
return False
|
| 128 |
+
|
| 129 |
+
if __name__ == "__main__":
|
| 130 |
+
success = main()
|
| 131 |
+
sys.exit(0 if success else 1)
|