Commit c64c726 · Parent: 0f24197
Initial Diamond CSGO AI deployment

This view is limited to 50 files because it contains too many changes.
- Dockerfile +38 -0
- README.md +76 -11
- app.py +969 -0
- config/agent/csgo.yaml +34 -0
- config/env/csgo.yaml +7 -0
- config/trainer.yaml +9 -0
- config/world_model_env/fast.yaml +17 -0
- config_web.py +208 -0
- csgo/spawn/0/act.npy +3 -0
- csgo/spawn/0/full_res.npy +3 -0
- csgo/spawn/0/info.json +1 -0
- csgo/spawn/0/low_res.npy +3 -0
- csgo/spawn/0/next_act.npy +3 -0
- packages.txt +3 -0
- requirements.txt +32 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/__pycache__/agent.cpython-310.pyc +0 -0
- src/__pycache__/trainer.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/agent.py +74 -0
- src/coroutines/__init__.py +11 -0
- src/coroutines/__pycache__/__init__.cpython-310.pyc +0 -0
- src/coroutines/__pycache__/collector.cpython-310.pyc +0 -0
- src/coroutines/__pycache__/env_loop.cpython-310.pyc +0 -0
- src/coroutines/collector.py +126 -0
- src/coroutines/env_loop.py +74 -0
- src/csgo/__init__.py +0 -0
- src/csgo/__pycache__/__init__.cpython-310.pyc +0 -0
- src/csgo/__pycache__/action_processing.cpython-310.pyc +0 -0
- src/csgo/__pycache__/keymap.cpython-310.pyc +0 -0
- src/csgo/__pycache__/web_action_processing.cpython-310.pyc +0 -0
- src/csgo/action_processing.py +230 -0
- src/csgo/keymap.py +33 -0
- src/csgo/web_action_processing.py +167 -0
- src/data/__init__.py +6 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/batch.cpython-310.pyc +0 -0
- src/data/__pycache__/batch_sampler.cpython-310.pyc +0 -0
- src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- src/data/__pycache__/episode.cpython-310.pyc +0 -0
- src/data/__pycache__/segment.cpython-310.pyc +0 -0
- src/data/__pycache__/utils.cpython-310.pyc +0 -0
- src/data/batch.py +25 -0
- src/data/batch_sampler.py +72 -0
- src/data/dataset.py +202 -0
- src/data/episode.py +64 -0
- src/data/segment.py +30 -0
- src/data/utils.py +89 -0
- src/envs/__init__.py +2 -0
Dockerfile
ADDED
@@ -0,0 +1,38 @@
+# Use Python 3.9 slim image
+FROM python:3.9-slim
+
+# Set working directory
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    curl \
+    software-properties-common \
+    git \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for better caching
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy source code
+COPY . .
+
+# Create necessary directories
+RUN mkdir -p csgo/spawn config checkpoints cache
+
+# Set environment variables
+ENV PYTHONPATH=/app/src:/app
+ENV CUDA_VISIBLE_DEVICES=""
+
+# Expose port
+EXPOSE 7860
+
+# Health check
+HEALTHCHECK CMD curl --fail http://localhost:7860/ || exit 1
+
+# Run the application
+CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
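The HEALTHCHECK above probes the same homepage route that app.py serves on port 7860. For reference, a minimal Python sketch of an equivalent probe, assuming the app is running locally on that port:

```python
# Minimal health probe equivalent to the Dockerfile's HEALTHCHECK line.
# Assumes the app is reachable at localhost:7860 (as in EXPOSE/CMD above).
import sys
import urllib.request

def check_health(url: str = "http://localhost:7860/") -> bool:
    """Return True if the homepage responds with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.status == 200
    except OSError:
        return False

if __name__ == "__main__":
    sys.exit(0 if check_health() else 1)
```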
README.md
CHANGED
@@ -1,11 +1,76 @@
[11 original lines removed; their content is not shown in this view]
+# Diamond CSGO AI Player 🎮
+
+A web-based demo of the Diamond AI agent playing Counter-Strike: Global Offensive using diffusion models and reinforcement learning.
+
+## Features
+
+- **Real-time Keyboard Input**: Use standard WASD controls and other keys to interact
+- **AI Agent**: Pre-trained agent using diffusion-based world models
+- **Web Interface**: No installation required, play directly in your browser
+- **Live Visualization**: See the AI's perspective and actions in real-time
+
+## Controls
+
+### Movement
+- **W** - Move Forward
+- **A** - Move Left
+- **S** - Move Back
+- **D** - Move Right
+- **Space** - Jump
+- **Ctrl** - Crouch
+- **Shift** - Walk
+
+### Actions
+- **1, 2, 3** - Switch Weapons
+- **R** - Reload
+- **Arrow Keys** - Camera Movement
+- **Left/Right Click** - Primary/Secondary Fire
+
+### Game Controls
+- **M** - Switch between Human/AI control
+- **Enter** - Reset Environment
+
+## How to Play
+
+1. Click on the game canvas to focus it
+2. Use keyboard controls to play
+3. The AI agent will respond to your inputs in real-time
+4. Switch to AI mode to watch the agent play autonomously
+
+## Technical Details
+
+This demo uses:
+- **FastAPI + WebSocket** for real-time communication
+- **PyTorch** for AI model inference
+- **Diffusion Models** for next-frame prediction
+- **World Model Environment** for simulation
+
+The agent was trained using the Diamond framework, which combines:
+- Diffusion-based world models
+- Actor-critic reinforcement learning
+- Multi-step planning and imagination
+
+## Model Information
+
+The AI agent uses several neural networks:
+- **Denoiser**: Diffusion model for generating next observations
+- **Upsampler**: High-resolution image generation
+- **Reward/End Model**: Predicting game outcomes
+- **Actor-Critic**: Action selection and value estimation
+
+## Citation
+
+This work is based on the Diamond framework. If you use this code, please cite:
+
+```bibtex
+@article{diamond2024,
+  title={Diamond: Diffusion for World Modeling},
+  author={[Authors]},
+  journal={[Journal]},
+  year={2024}
+}
+```
+
+## License
+
+See LICENSE file for details.
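The browser controls described in the README are sent to the FastAPI backend over a WebSocket at `/ws` as small JSON messages (`start`, `keys`, `reset`, `pause`, `mouse`), and the server streams back `frame` messages carrying a base64 JPEG; the handlers appear in app.py below. A minimal headless client sketch of that protocol, assuming the `websockets` package (not listed in this commit's requirements) and a locally running server:

```python
# Hypothetical headless client for the /ws protocol implemented in app.py below.
# Assumes `pip install websockets` and the app running on localhost:7860.
import asyncio
import json

import websockets

async def drive_forward(num_frames: int = 90) -> None:
    async with websockets.connect("ws://localhost:7860/ws") as ws:
        await ws.send(json.dumps({"type": "start"}))                   # begin stepping the world model
        await ws.send(json.dumps({"type": "keys", "keys": ["KeyW"]}))  # hold W (move forward)
        received = 0
        while received < num_frames:
            msg = json.loads(await ws.recv())        # 'loading' or 'frame' messages from the server
            if msg.get("type") == "frame":
                received += 1
                print(f"frame {msg['frame_count']}  ai_fps={msg.get('ai_fps', 0):.1f}")
        await ws.send(json.dumps({"type": "keys", "keys": []}))        # release all keys
        await ws.send(json.dumps({"type": "pause"}))

if __name__ == "__main__":
    asyncio.run(drive_forward())
```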
app.py
ADDED
@@ -0,0 +1,969 @@
+"""
+Web-based Diamond CSGO AI Player for Hugging Face Spaces
+Uses FastAPI + WebSocket for real-time keyboard input and game streaming
+"""
+
+import asyncio
+import base64
+import io
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Dict, List, Optional, Set
+
+import cv2
+import numpy as np
+import torch
+import uvicorn
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.responses import HTMLResponse
+from fastapi.staticfiles import StaticFiles
+from hydra import compose, initialize
+from hydra.utils import instantiate
+from omegaconf import DictConfig, OmegaConf
+from PIL import Image
+
+# Import your modules
+from src.agent import Agent
+from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
+from src.envs import WorldModelEnv
+from src.game.web_play_env import WebPlayEnv
+from config_web import web_config
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Global variables
+app = FastAPI(title="Diamond CSGO AI Player")
+connected_clients: Set[WebSocket] = set()
+
+class WebKeyMap:
+    """Map web key codes to pygame-like keys for CSGO actions"""
+    WEB_TO_CSGO = {
+        'KeyW': 'w',
+        'KeyA': 'a',
+        'KeyS': 's',
+        'KeyD': 'd',
+        'Space': 'space',
+        'ControlLeft': 'left ctrl',
+        'ShiftLeft': 'left shift',
+        'Digit1': '1',
+        'Digit2': '2',
+        'Digit3': '3',
+        'KeyR': 'r',
+        'ArrowUp': 'camera_up',
+        'ArrowDown': 'camera_down',
+        'ArrowLeft': 'camera_left',
+        'ArrowRight': 'camera_right'
+    }
+
+class WebGameEngine:
+    """Web-compatible game engine that replaces pygame functionality"""
+
+    def __init__(self):
+        self.play_env: Optional[WebPlayEnv] = None
+        self.obs = None
+        self.running = False
+        self.game_started = False
+        self.fps = 30  # Display FPS
+        self.ai_fps = 10  # AI inference FPS (slower than display for efficiency)
+        self.frame_count = 0
+        self.ai_frame_count = 0
+        self.last_ai_time = 0
+        self.start_time = 0  # Track when AI started for proper FPS calculation
+        self.pressed_keys: Set[str] = set()
+        self.mouse_x = 0
+        self.mouse_y = 0
+        self.l_click = False
+        self.r_click = False
+        self.should_reset = False
+        self.cached_obs = None  # Cache last observation for frame skipping
+        self.first_inference_done = False  # Track if first inference completed
+        self.models_ready = False  # Track if models are loaded
+        self.download_progress = 0  # Track download progress (0-100)
+        self.loading_status = "Initializing..."  # Loading status message
+        import time
+        self.time_module = time
+
+    async def _download_model_async(self, url, filepath):
+        """Download model asynchronously with progress tracking"""
+        import asyncio
+        import concurrent.futures
+        import urllib.request
+        import os
+
+        def download_with_progress():
+            """Download function that runs in thread pool"""
+            def progress_hook(block_num, block_size, total_size):
+                if total_size > 0:
+                    progress = min(100, (block_num * block_size * 100) / total_size)
+                    self.download_progress = int(progress)
+                    if progress % 10 == 0:  # Log every 10%
+                        logger.info(f"Download progress: {self.download_progress}%")
+
+            urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
+            self.download_progress = 100
+
+        # Run download in thread pool to avoid blocking
+        loop = asyncio.get_event_loop()
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            await loop.run_in_executor(executor, download_with_progress)
+
+        logger.info("Model download completed!")
+
+    async def initialize_models(self):
+        """Initialize the AI models and environment"""
+        try:
+            import torch
+            logger.info("Initializing models...")
+
+            # Setup environment and paths
+            web_config.setup_environment_variables()
+            web_config.create_default_configs()
+
+            config_path = web_config.get_config_path()
+            logger.info(f"Using config path: {config_path}")
+
+            # Convert to relative path for Hydra
+            import os
+            relative_config_path = os.path.relpath(config_path)
+            logger.info(f"Relative config path: {relative_config_path}")
+
+            with initialize(version_base="1.3", config_path=relative_config_path):
+                cfg = compose(config_name="trainer")
+
+            # Override config for deployment
+            cfg.agent = OmegaConf.load(config_path / "agent" / "csgo.yaml")
+            cfg.env = OmegaConf.load(config_path / "env" / "csgo.yaml")
+
+            # Use CPU if no GPU available (for free HF spaces)
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            logger.info(f"Using device: {device}")
+
+            # Load model checkpoint
+            checkpoint_path = web_config.get_checkpoint_path()
+            if not checkpoint_path.exists():
+                logger.warning(f"No checkpoint found at {checkpoint_path} - using dummy mode")
+                self._init_dummy_mode()
+                return True
+
+            # Get spawn directory
+            spawn_dir = web_config.get_spawn_dir()
+
+            # Initialize agent
+            num_actions = cfg.env.num_actions
+            agent = Agent(instantiate(cfg.agent, num_actions=num_actions)).to(device).eval()
+
+            # Try to load checkpoint (remote or local)
+            try:
+                # First try to download from Hugging Face Hub using direct URL
+                try:
+                    import torch.hub
+                    import os
+                    logger.info("Downloading model from Hugging Face Hub...")
+
+                    # Direct download URL (change 'blob' to 'resolve' for direct download)
+                    model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
+
+                    # Download to cache directory
+                    cache_dir = "./cache"
+                    os.makedirs(cache_dir, exist_ok=True)
+                    model_cache_path = os.path.join(cache_dir, "agent_epoch_00003.pt")
+
+                    # Download if not cached
+                    if not os.path.exists(model_cache_path):
+                        logger.info(f"Downloading 1.53GB model to {model_cache_path}...")
+                        self.loading_status = "Downloading AI model from Hugging Face Hub..."
+
+                        # Download with progress tracking in a separate thread
+                        await self._download_model_async(model_url, model_cache_path)
+                    else:
+                        logger.info(f"Using cached model from {model_cache_path}")
+                        self.loading_status = "Loading cached model..."
+
+                    # Use the agent's load method which expects a file path
+                    self.loading_status = "Loading model weights..."
+                    agent.load(model_cache_path)
+                    logger.info(f"Successfully loaded checkpoint from HF Hub")
+
+                except Exception as hub_error:
+                    logger.warning(f"Failed to download from HF Hub: {hub_error}")
+
+                    # Fallback to local checkpoint if available
+                    if checkpoint_path.exists():
+                        logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
+                        agent.load(checkpoint_path)
+                        logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
+                    else:
+                        raise FileNotFoundError("No model checkpoint available (local or remote)")
+
+            except Exception as e:
+                logger.error(f"Failed to load any checkpoint: {e}")
+                self._init_dummy_mode()
+                return True
+
+            # Initialize world model environment
+            try:
+                sl = cfg.agent.denoiser.inner_model.num_steps_conditioning
+                if agent.upsampler is not None:
+                    sl = max(sl, cfg.agent.upsampler.inner_model.num_steps_conditioning)
+                wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)
+                wm_env = WorldModelEnv(agent.denoiser, agent.upsampler, agent.rew_end_model,
+                                       spawn_dir, 1, sl, wm_env_cfg, return_denoising_trajectory=True)
+
+                # Create play environment
+                self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
+
+                # Model compilation causes 10-30s delay on first inference, so make it optional
+                # You can enable it by setting ENABLE_TORCH_COMPILE=1 environment variable
+                import os
+                if device.type == "cuda" and os.getenv("ENABLE_TORCH_COMPILE", "0") == "1":
+                    logger.info("Compiling models for faster inference (will cause delay on first inference)...")
+                    try:
+                        wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
+                        if wm_env.upsample_next_obs is not None:
+                            wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
+                        logger.info("Model compilation enabled successfully!")
+                    except Exception as e:
+                        logger.warning(f"Model compilation failed: {e}")
+                else:
+                    logger.info("Model compilation disabled (faster startup). Set ENABLE_TORCH_COMPILE=1 to enable.")
+
+                # Reset environment
+                self.obs, _ = self.play_env.reset()
+                self.cached_obs = self.obs  # Initialize cache
+
+                logger.info("Models initialized successfully!")
+                logger.info(f"Initial observation shape: {self.obs.shape if self.obs is not None else 'None'}")
+                self.models_ready = True
+                self.loading_status = "Ready!"
+                return True
+
+            except Exception as e:
+                logger.error(f"Failed to initialize world model environment: {e}")
+                self._init_dummy_mode()
+                self.models_ready = True
+                self.loading_status = "Using dummy mode"
+                return True
+
+        except Exception as e:
+            logger.error(f"Failed to initialize models: {e}")
+            import traceback
+            traceback.print_exc()
+            self._init_dummy_mode()
+            self.models_ready = True
+            self.loading_status = "Error - using dummy mode"
+            return True
+
+    def _init_dummy_mode(self):
+        """Initialize dummy mode for testing without models"""
+        logger.info("Initializing dummy mode...")
+
+        # Create a test observation
+        height, width = 150, 600
+        img_array = np.zeros((height, width, 3), dtype=np.uint8)
+
+        # Add test pattern
+        for y in range(height):
+            for x in range(width):
+                img_array[y, x, 0] = (x % 256)  # Red gradient
+                img_array[y, x, 1] = (y % 256)  # Green gradient
+                img_array[y, x, 2] = ((x + y) % 256)  # Blue pattern
+
+        # Convert to torch tensor in expected format [-1, 1]
+        tensor = torch.from_numpy(img_array).float().permute(2, 0, 1)  # CHW format
+        tensor = tensor.div(255).mul(2).sub(1)  # Convert to [-1, 1] range
+        tensor = tensor.unsqueeze(0)  # Add batch dimension
+
+        self.obs = tensor
+        self.play_env = None  # No real environment in dummy mode
+        logger.info("Dummy mode initialized with test pattern")
+
+
+    def step_environment(self):
+        """Step the environment with current input state (with intelligent frame skipping)"""
+        if self.play_env is None:
+            # Dummy mode - just return current observation
+            return self.obs, 0.0, False, False, {"mode": "dummy"}
+
+        try:
+            # Check if reset is requested
+            if self.should_reset:
+                self.reset_environment()
+                self.should_reset = False
+                self.last_ai_time = self.time_module.time()  # Reset AI timer
+                return self.obs, 0.0, False, False, {"reset": True}
+
+            # Intelligent frame skipping: only run AI inference at target FPS
+            current_time = self.time_module.time()
+            time_since_last_ai = current_time - self.last_ai_time
+            should_run_ai = time_since_last_ai >= (1.0 / self.ai_fps)
+
+            if should_run_ai:
+                # Show loading indicator for first inference (can be slow)
+                if not self.first_inference_done:
+                    logger.info("Running first AI inference (may take 5-15 seconds)...")
+
+                # Run AI inference
+                inference_start = self.time_module.time()
+                next_obs, reward, done, truncated, info = self.play_env.step_from_web_input(
+                    pressed_keys=self.pressed_keys,
+                    mouse_x=self.mouse_x,
+                    mouse_y=self.mouse_y,
+                    l_click=self.l_click,
+                    r_click=self.r_click
+                )
+                inference_time = self.time_module.time() - inference_start
+
+                # Log first inference completion
+                if not self.first_inference_done:
+                    self.first_inference_done = True
+                    logger.info(f"First AI inference completed in {inference_time:.2f}s - subsequent inferences will be faster!")
+
+                # Cache the new observation and update timing
+                self.cached_obs = next_obs
+                self.last_ai_time = current_time
+                self.ai_frame_count += 1
+
+                # Add AI performance info
+                info = info or {}
+                info["ai_inference"] = True
+
+                # Calculate proper AI FPS: frames / elapsed time since start
+                elapsed_time = current_time - self.start_time
+                if elapsed_time > 0 and self.ai_frame_count > 0:
+                    ai_fps = self.ai_frame_count / elapsed_time
+                    # Cap at reasonable maximum (shouldn't exceed 100 FPS for AI inference)
+                    info["ai_fps"] = min(ai_fps, 100.0)
+                else:
+                    info["ai_fps"] = 0
+
+                info["inference_time"] = inference_time
+
+                return next_obs, reward, done, truncated, info
+            else:
+                # Use cached observation for smoother display without AI overhead
+                obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
+
+                # Calculate AI FPS for cached frames too
+                elapsed_time = current_time - self.start_time
+                if elapsed_time > 0 and self.ai_frame_count > 0:
+                    ai_fps = min(self.ai_frame_count / elapsed_time, 100.0)  # Cap at 100 FPS
+                else:
+                    ai_fps = 0
+
+                return obs_to_return, 0.0, False, False, {"cached": True, "ai_fps": ai_fps}
+
+        except Exception as e:
+            logger.error(f"Error stepping environment: {e}")
+            obs_to_return = self.cached_obs if self.cached_obs is not None else self.obs
+            return obs_to_return, 0.0, False, False, {"error": str(e)}
+
+    def reset_environment(self):
+        """Reset the environment"""
+        try:
+            if self.play_env is not None:
+                self.obs, _ = self.play_env.reset()
+                self.cached_obs = self.obs  # Update cache
+                logger.info("Environment reset successfully")
+            else:
+                # Dummy mode - recreate test pattern
+                self._init_dummy_mode()
+                self.cached_obs = self.obs  # Update cache
+                logger.info("Dummy environment reset")
+        except Exception as e:
+            logger.error(f"Error resetting environment: {e}")
+
+    def request_reset(self):
+        """Request environment reset on next step"""
+        self.should_reset = True
+        logger.info("Environment reset requested")
+
+    def start_game(self):
+        """Start the game"""
+        self.game_started = True
+        self.start_time = self.time_module.time()  # Reset start time for FPS calculation
+        self.ai_frame_count = 0  # Reset AI frame count
+        logger.info("Game started")
+
+    def pause_game(self):
+        """Pause/stop the game"""
+        self.game_started = False
+        logger.info("Game paused")
+
+    def obs_to_base64(self, obs: torch.Tensor) -> str:
+        """Convert observation tensor to base64 image for web display"""
+        if obs is None:
+            return ""
+
+        try:
+            # Convert tensor to PIL Image
+            if obs.ndim == 4 and obs.size(0) == 1:
+                img_array = obs[0].add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
+            else:
+                img_array = obs.add(1).div(2).mul(255).byte().permute(1, 2, 0).cpu().numpy()
+
+            img = Image.fromarray(img_array)
+
+            # Resize for web display to match canvas size (optimized)
+            img = img.resize((600, 150), Image.NEAREST)  # NEAREST is faster than BICUBIC
+
+            # Optimized base64 conversion with JPEG for better compression/speed
+            buffer = io.BytesIO()
+            img.save(buffer, format='JPEG', quality=85, optimize=True)  # JPEG is faster than PNG
+            img_str = base64.b64encode(buffer.getvalue()).decode()
+            return f"data:image/jpeg;base64,{img_str}"
+
+        except Exception as e:
+            logger.error(f"Error converting observation to base64: {e}")
+            return ""
+
+    async def game_loop(self):
+        """Main game loop that runs continuously"""
+        self.running = True
+
+        while self.running:
+            try:
+                # Check if models are ready
+                if not self.models_ready:
+                    # Send loading status to clients
+                    if connected_clients:
+                        loading_data = {
+                            'type': 'loading',
+                            'status': self.loading_status,
+                            'progress': self.download_progress,
+                            'ready': False
+                        }
+                        disconnected = set()
+                        for client in connected_clients.copy():
+                            try:
+                                await client.send_text(json.dumps(loading_data))
+                            except:
+                                disconnected.add(client)
+                        connected_clients.difference_update(disconnected)
+
+                    await asyncio.sleep(0.5)  # Check every 500ms during loading
+                    continue
+
+                # Always send frames, but only step environment if game is started
+                should_send_frame = True
+
+                if not self.game_started:
+                    # Game not started - just send current observation without stepping
+                    if self.obs is not None and connected_clients:
+                        should_send_frame = True
+                    else:
+                        should_send_frame = False
+                        await asyncio.sleep(0.1)
+                else:
+                    # Game is started - step environment
+                    if self.play_env is None:
+                        await asyncio.sleep(0.1)
+                        continue
+
+                    # Step environment with current input state
+                    next_obs, reward, done, truncated, info = self.step_environment()
+
+                    if done or truncated:
+                        # Auto-reset when episode ends
+                        self.reset_environment()
+                    else:
+                        self.obs = next_obs
+
+                # Send frame to all connected clients (regardless of game state)
+                if should_send_frame and connected_clients and self.obs is not None:
+                    # Set default values for when game isn't running
+                    if not self.game_started:
+                        reward = 0.0
+                        info = {"waiting": True}
+                    # If game is started, reward and info should be set above
+
+                    # Convert observation to base64
+                    image_data = self.obs_to_base64(self.obs)
+
+                    # Debug logging for first few frames
+                    if self.frame_count < 5:
+                        logger.info(f"Frame {self.frame_count}: obs shape={self.obs.shape if self.obs is not None else 'None'}, "
+                                    f"image_data_length={len(image_data) if image_data else 0}, "
+                                    f"game_started={self.game_started}")
+
+                    frame_data = {
+                        'type': 'frame',
+                        'image': image_data,
+                        'frame_count': self.frame_count,
+                        'reward': float(reward.item()) if hasattr(reward, 'item') else float(reward) if reward is not None else 0.0,
+                        'info': str(info) if info else "",
+                        'ai_fps': info.get('ai_fps', 0) if isinstance(info, dict) else 0,
+                        'is_ai_frame': info.get('ai_inference', False) if isinstance(info, dict) else False
+                    }
+
+                    # Send to all connected clients
+                    disconnected = set()
+                    for client in connected_clients.copy():
+                        try:
+                            await client.send_text(json.dumps(frame_data))
+                        except:
+                            disconnected.add(client)
+
+                    # Remove disconnected clients
+                    connected_clients.difference_update(disconnected)
+
+                self.frame_count += 1
+                await asyncio.sleep(1.0 / self.fps)  # Control FPS
+
+            except Exception as e:
+                logger.error(f"Error in game loop: {e}")
+                await asyncio.sleep(0.1)
+
+# Global game engine instance
+game_engine = WebGameEngine()
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize models when the app starts"""
+    # Start the game loop immediately (it will handle loading state)
+    asyncio.create_task(game_engine.game_loop())
+
+    # Initialize models in background (non-blocking)
+    asyncio.create_task(game_engine.initialize_models())
+
+@app.get("/", response_class=HTMLResponse)
+async def get_homepage():
+    """Serve the main game interface"""
+    html_content = """
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <title>Diamond CSGO AI Player</title>
+        <style>
+            body {
+                margin: 0;
+                padding: 20px;
+                background: #1a1a1a;
+                color: white;
+                font-family: 'Courier New', monospace;
+                text-align: center;
+            }
+            #gameCanvas {
+                border: 2px solid #00ff00;
+                background: #000;
+                margin: 20px auto;
+                display: block;
+            }
+            #controls {
+                margin: 20px;
+                display: grid;
+                grid-template-columns: 1fr 1fr;
+                gap: 20px;
+                max-width: 800px;
+                margin: 20px auto;
+            }
+            .control-section {
+                background: #2a2a2a;
+                padding: 15px;
+                border-radius: 8px;
+                border: 1px solid #444;
+            }
+            .key-display {
+                background: #333;
+                border: 1px solid #555;
+                padding: 5px 10px;
+                margin: 2px;
+                border-radius: 4px;
+                display: inline-block;
+                min-width: 30px;
+            }
+            .key-pressed {
+                background: #00ff00;
+                color: #000;
+            }
+            #status {
+                margin: 10px;
+                padding: 10px;
+                background: #2a2a2a;
+                border-radius: 4px;
+            }
+            .info {
+                color: #00ff00;
+                margin: 5px 0;
+            }
+        </style>
+    </head>
+    <body>
+        <h1>🎮 Diamond CSGO AI Player</h1>
+        <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
+        <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
+
+        <!-- Model Download Progress -->
+        <div id="downloadSection" style="display: none; margin: 20px;">
+            <p id="downloadStatus" style="color: #ffaa00; margin: 10px 0;">📥 Downloading AI model...</p>
+            <div style="background: #333; border-radius: 10px; padding: 3px; width: 100%; max-width: 600px; margin: 0 auto;">
+                <div id="progressBar" style="background: linear-gradient(90deg, #00ff00, #88ff00); height: 20px; border-radius: 7px; width: 0%; transition: width 0.3s;"></div>
+            </div>
+            <p id="progressText" style="color: #aaa; font-size: 14px; margin: 5px 0;">0% - Initializing...</p>
+        </div>
+
+        <canvas id="gameCanvas" width="600" height="150" tabindex="0"></canvas>
+
+        <div id="status">
+            <div class="info">Status: <span id="connectionStatus">Connecting...</span></div>
+            <div class="info">Game: <span id="gameStatus">Click to Start</span></div>
+            <div class="info">Frame: <span id="frameCount">0</span> | AI FPS: <span id="aiFps">0</span></div>
+            <div class="info">Reward: <span id="reward">0</span></div>
+        </div>
+
+        <div id="controls">
+            <div class="control-section">
+                <h3>Movement</h3>
+                <div>
+                    <span class="key-display" id="key-w">W</span> Forward<br>
+                    <span class="key-display" id="key-a">A</span> Left
+                    <span class="key-display" id="key-s">S</span> Back
+                    <span class="key-display" id="key-d">D</span> Right<br>
+                    <span class="key-display" id="key-space">Space</span> Jump
+                    <span class="key-display" id="key-ctrl">Ctrl</span> Crouch
+                    <span class="key-display" id="key-shift">Shift</span> Walk
+                </div>
+            </div>
+
+            <div class="control-section">
+                <h3>Actions</h3>
+                <div>
+                    <span class="key-display" id="key-1">1</span> Weapon 1<br>
+                    <span class="key-display" id="key-2">2</span> Weapon 2
+                    <span class="key-display" id="key-3">3</span> Weapon 3<br>
+                    <span class="key-display" id="key-r">R</span> Reload<br>
+                    <span class="key-display" id="key-arrows">↑↓←→</span> Camera<br>
+                    <span class="key-display" id="key-enter">Enter</span> Reset Game<br>
+                    <span class="key-display" id="key-esc">Esc</span> Pause/Quit
+                </div>
+            </div>
+        </div>
+
+        <script>
+            const canvas = document.getElementById('gameCanvas');
+            const ctx = canvas.getContext('2d');
+            const statusEl = document.getElementById('connectionStatus');
+            const gameStatusEl = document.getElementById('gameStatus');
+            const frameEl = document.getElementById('frameCount');
+            const aiFpsEl = document.getElementById('aiFps');
+            const rewardEl = document.getElementById('reward');
+            const loadingEl = document.getElementById('loadingIndicator');
+            const downloadSectionEl = document.getElementById('downloadSection');
+            const downloadStatusEl = document.getElementById('downloadStatus');
+            const progressBarEl = document.getElementById('progressBar');
+            const progressTextEl = document.getElementById('progressText');
+
+            let ws = null;
+            let pressedKeys = new Set();
+            let gameStarted = false;
+
+            // Key mapping
+            const keyDisplayMap = {
+                'KeyW': 'key-w',
+                'KeyA': 'key-a',
+                'KeyS': 'key-s',
+                'KeyD': 'key-d',
+                'Space': 'key-space',
+                'ControlLeft': 'key-ctrl',
+                'ShiftLeft': 'key-shift',
+                'Digit1': 'key-1',
+                'Digit2': 'key-2',
+                'Digit3': 'key-3',
+                'KeyR': 'key-r',
+                'ArrowUp': 'key-arrows',
+                'ArrowDown': 'key-arrows',
+                'ArrowLeft': 'key-arrows',
+                'ArrowRight': 'key-arrows',
+                'Enter': 'key-enter',
+                'Escape': 'key-esc'
+            };
+
+            function connectWebSocket() {
+                const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+                const wsUrl = `${protocol}//${window.location.host}/ws`;
+
+                ws = new WebSocket(wsUrl);
+
+                ws.onopen = function(event) {
+                    statusEl.textContent = 'Connected';
+                    statusEl.style.color = '#00ff00';
+                };
+
+                ws.onmessage = function(event) {
+                    const data = JSON.parse(event.data);
+
+                    if (data.type === 'loading') {
+                        // Handle loading status
+                        downloadSectionEl.style.display = 'block';
+                        downloadStatusEl.textContent = data.status;
+
+                        if (data.progress !== undefined) {
+                            progressBarEl.style.width = data.progress + '%';
+                            progressTextEl.textContent = data.progress + '% - ' + data.status;
+                        } else {
+                            progressTextEl.textContent = data.status;
+                        }
+
+                        gameStatusEl.textContent = 'Loading Models...';
+                        gameStatusEl.style.color = '#ffaa00';
+
+                    } else if (data.type === 'frame') {
+                        // Hide loading indicators once we get frames
+                        downloadSectionEl.style.display = 'none';
+                        // Update frame display
+                        if (data.image) {
+                            const img = new Image();
+                            img.onload = function() {
+                                ctx.clearRect(0, 0, canvas.width, canvas.height);
+                                ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
+                            };
+                            img.src = data.image;
+                        }
+
+                        frameEl.textContent = data.frame_count;
+                        rewardEl.textContent = data.reward.toFixed(2);
+
+                        // Update AI FPS display and hide loading indicator once AI starts
+                        if (data.ai_fps !== undefined && data.ai_fps !== null) {
+                            // Ensure FPS value is reasonable
+                            const aiFps = Math.min(Math.max(data.ai_fps, 0), 100);
+                            aiFpsEl.textContent = aiFps.toFixed(1);
+
+                            // Color code AI FPS for performance indication
+                            if (aiFps >= 8) {
+                                aiFpsEl.style.color = '#00ff00'; // Green for good performance
+                            } else if (aiFps >= 5) {
+                                aiFpsEl.style.color = '#ffff00'; // Yellow for moderate performance
+                            } else if (aiFps > 0) {
+                                aiFpsEl.style.color = '#ff0000'; // Red for poor performance
+                            } else {
+                                aiFpsEl.style.color = '#888888'; // Gray for inactive
+                            }
+
+                            // Hide loading indicator once AI inference starts working
+                            if (aiFps > 0 && gameStarted) {
+                                loadingEl.style.display = 'none';
+                                gameStatusEl.textContent = 'Playing';
+                                gameStatusEl.style.color = '#00ff00';
+                            }
+                        }
+                    }
+                };
+
+                ws.onclose = function(event) {
+                    statusEl.textContent = 'Disconnected';
+                    statusEl.style.color = '#ff0000';
+                    setTimeout(connectWebSocket, 1000); // Reconnect after 1 second
+                };
+
+                ws.onerror = function(event) {
+                    statusEl.textContent = 'Error';
+                    statusEl.style.color = '#ff0000';
+                };
+            }
+
+            function sendKeyState() {
+                if (ws && ws.readyState === WebSocket.OPEN) {
+                    ws.send(JSON.stringify({
+                        type: 'keys',
+                        keys: Array.from(pressedKeys)
+                    }));
+                }
+            }
+
+            function startGame() {
+                if (ws && ws.readyState === WebSocket.OPEN) {
+                    ws.send(JSON.stringify({
+                        type: 'start'
+                    }));
+                    gameStarted = true;
+                    gameStatusEl.textContent = 'Starting AI...';
+                    gameStatusEl.style.color = '#ffff00';
+                    loadingEl.style.display = 'block';
+                    console.log('Game started');
+                }
+            }
+
+            function pauseGame() {
+                if (ws && ws.readyState === WebSocket.OPEN) {
+                    ws.send(JSON.stringify({
+                        type: 'pause'
+                    }));
+                    gameStarted = false;
+                    gameStatusEl.textContent = 'Paused - Click to Resume';
+                    gameStatusEl.style.color = '#ffff00';
+                    console.log('Game paused');
+                }
+            }
+
+            function updateKeyDisplay() {
+                // Reset all key displays
+                Object.values(keyDisplayMap).forEach(id => {
+                    const el = document.getElementById(id);
+                    if (el) el.classList.remove('key-pressed');
+                });
+
+                // Highlight pressed keys
+                pressedKeys.forEach(key => {
+                    const displayId = keyDisplayMap[key];
+                    if (displayId) {
+                        const el = document.getElementById(displayId);
+                        if (el) el.classList.add('key-pressed');
+                    }
+                });
+            }
+
+            // Focus canvas and handle keyboard events
+            canvas.addEventListener('click', () => {
+                canvas.focus();
+                if (!gameStarted) {
+                    startGame();
+                }
+            });
+
+            canvas.addEventListener('keydown', (event) => {
+                event.preventDefault();
+
+                // Handle special keys
+                if (event.code === 'Enter') {
+                    if (ws && ws.readyState === WebSocket.OPEN) {
+                        ws.send(JSON.stringify({
+                            type: 'reset'
+                        }));
+                        console.log('Environment reset requested');
+                    }
+                    // Add to pressedKeys for visual feedback
+                    pressedKeys.add(event.code);
+                    updateKeyDisplay();
+
+                    // Remove Enter from pressedKeys after a short delay for visual feedback
+                    setTimeout(() => {
+                        pressedKeys.delete(event.code);
+                        updateKeyDisplay();
+                    }, 200);
+                } else if (event.code === 'Escape') {
+                    pauseGame();
+                    // Add to pressedKeys for visual feedback
+                    pressedKeys.add(event.code);
+                    updateKeyDisplay();
+
+                    // Remove ESC from pressedKeys after a short delay for visual feedback
+                    setTimeout(() => {
+                        pressedKeys.delete(event.code);
+                        updateKeyDisplay();
+                    }, 200);
+                } else {
+                    // Only send game keys if game is started
+                    if (gameStarted) {
+                        pressedKeys.add(event.code);
+                        updateKeyDisplay();
+                        sendKeyState();
+                    }
+                }
+            });
+
+            canvas.addEventListener('keyup', (event) => {
+                event.preventDefault();
+
+                // Don't handle special keys release (handled in keydown with timeout)
+                if (event.code !== 'Enter' && event.code !== 'Escape') {
+                    if (gameStarted) {
+                        pressedKeys.delete(event.code);
+                        updateKeyDisplay();
+                        sendKeyState();
+                    }
+                }
+            });
+
+            // Handle mouse events for clicks
+            canvas.addEventListener('mousedown', (event) => {
+                if (ws && ws.readyState === WebSocket.OPEN) {
+                    ws.send(JSON.stringify({
+                        type: 'mouse',
+                        button: event.button,
+                        action: 'down',
+                        x: event.offsetX,
+                        y: event.offsetY
+                    }));
+                }
+            });
+
+            canvas.addEventListener('mouseup', (event) => {
+                if (ws && ws.readyState === WebSocket.OPEN) {
+                    ws.send(JSON.stringify({
+                        type: 'mouse',
+                        button: event.button,
+                        action: 'up',
+                        x: event.offsetX,
+                        y: event.offsetY
+                    }));
+                }
+            });
+
+            // Initialize
+            connectWebSocket();
+            canvas.focus();
+        </script>
+    </body>
+    </html>
+    """
+    return html_content
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    """Handle WebSocket connections for real-time game communication"""
+    await websocket.accept()
+    connected_clients.add(websocket)
+
+    try:
+        while True:
+            # Receive messages from client
+            data = await websocket.receive_text()
+            message = json.loads(data)
+
+            if message['type'] == 'keys':
+                # Update pressed keys
+                game_engine.pressed_keys = set(message['keys'])
+
+            elif message['type'] == 'reset':
+                # Handle environment reset request
+                game_engine.request_reset()
+
+            elif message['type'] == 'start':
+                # Handle game start request
+                game_engine.start_game()
+
+            elif message['type'] == 'pause':
+                # Handle game pause request
+                game_engine.pause_game()
+
+            elif message['type'] == 'mouse':
+                # Handle mouse events
+                if message['action'] == 'down':
+                    if message['button'] == 0:  # Left click
+                        game_engine.l_click = True
+                    elif message['button'] == 2:  # Right click
+                        game_engine.r_click = True
+                elif message['action'] == 'up':
+                    if message['button'] == 0:  # Left click
+                        game_engine.l_click = False
+                    elif message['button'] == 2:  # Right click
+                        game_engine.r_click = False
+
+                # Update mouse position (relative to canvas)
+                game_engine.mouse_x = message.get('x', 0) - 300  # Center at 300px
+                game_engine.mouse_y = message.get('y', 0) - 150  # Center at 150px
+
+    except WebSocketDisconnect:
+        connected_clients.discard(websocket)
+    except Exception as e:
+        logger.error(f"WebSocket error: {e}")
+        connected_clients.discard(websocket)
+
+if __name__ == "__main__":
+    # For local development
+    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
+
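A central trick in `WebGameEngine.step_environment` above is frame skipping: the display loop runs at `fps = 30`, but world-model inference is throttled to `ai_fps = 10` and the last observation is cached and reused in between. A stripped-down sketch of that pattern, with a stand-in `run_inference` callable instead of the real denoiser/upsampler step:

```python
# Stripped-down sketch of the frame-skipping pattern in WebGameEngine.step_environment.
# `run_inference` here is only a stand-in for the expensive world-model step.
import time

class FrameSkipper:
    def __init__(self, ai_fps: float = 10.0):
        self.ai_fps = ai_fps
        self.last_ai_time = 0.0
        self.cached_frame = None

    def step(self, run_inference):
        now = time.time()
        if now - self.last_ai_time >= 1.0 / self.ai_fps:
            self.cached_frame = run_inference()   # expensive: denoiser + upsampler step
            self.last_ai_time = now
            return self.cached_frame, True        # fresh AI frame
        return self.cached_frame, False           # reuse cached frame for smooth display

if __name__ == "__main__":
    skipper = FrameSkipper(ai_fps=10)
    for i in range(90):                           # ~3 seconds of a 30 FPS display loop
        frame, fresh = skipper.step(lambda: f"obs@{i}")
        time.sleep(1 / 30)
```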
config/agent/csgo.yaml
ADDED
@@ -0,0 +1,34 @@
+_target_: agent.AgentConfig
+
+denoiser:
+  _target_: models.diffusion.DenoiserConfig
+  sigma_data: 0.5
+  sigma_offset_noise: 0.1
+  noise_previous_obs: true
+  upsampling_factor: null
+  inner_model:
+    _target_: models.diffusion.InnerModelConfig
+    img_channels: 3
+    num_steps_conditioning: 4
+    cond_channels: 2048
+    depths: [2, 2, 2, 2]
+    channels: [128, 256, 512, 1024]
+    attn_depths: [0, 0, 1, 1]
+
+upsampler:
+  _target_: models.diffusion.DenoiserConfig
+  sigma_data: 0.5
+  sigma_offset_noise: 0.1
+  noise_previous_obs: false
+  upsampling_factor: 5
+  inner_model:
+    _target_: models.diffusion.InnerModelConfig
+    img_channels: 3
+    num_steps_conditioning: 1
+    cond_channels: 2048
+    depths: [2, 2, 2, 2]
+    channels: [64, 64, 128, 256]
+    attn_depths: [0, 0, 0, 1]
+
+rew_end_model: null
+actor_critic: null
config/env/csgo.yaml
ADDED
@@ -0,0 +1,7 @@
+train:
+  id: csgo
+  size: [150, 600]
+  num_actions: 51
+  path_data_low_res: /tmp/dummy_data_low_res
+  path_data_full_res: /tmp/dummy_data_full_res
+  keymap: csgo
config/trainer.yaml
ADDED
@@ -0,0 +1,9 @@
+defaults:
+  - _self_
+  - env: csgo
+  - agent: csgo
+  - world_model_env: fast
+
+static_dataset:
+  path: /tmp/dummy_data_low_res
+  ignore_sample_weights: True
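trainer.yaml is a Hydra config whose `defaults` list pulls in `env: csgo`, `agent: csgo`, and `world_model_env: fast`; app.py composes it with `initialize`/`compose` and builds objects with `instantiate`. A minimal sketch of that composition step, assuming the `config/` directory from this commit sits next to the script and the `src` package (which defines `envs.WorldModelEnvConfig`) is importable:

```python
# Minimal sketch of how app.py composes this config tree with Hydra (see initialize_models).
# Assumes config/ from this commit is adjacent and src/ is on PYTHONPATH for _target_ resolution.
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf

with initialize(version_base="1.3", config_path="config"):
    cfg = compose(config_name="trainer")        # merges env/csgo, agent/csgo, world_model_env/fast

print(OmegaConf.to_yaml(cfg.world_model_env))   # the "fast" sampler settings shown below
wm_env_cfg = instantiate(cfg.world_model_env, num_batches_to_preload=1)  # -> envs.WorldModelEnvConfig
```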
config/world_model_env/fast.yaml
ADDED
@@ -0,0 +1,17 @@
+_target_: envs.WorldModelEnvConfig
+horizon: 1000
+num_batches_to_preload: 1
+diffusion_sampler_next_obs:
+  _target_: models.diffusion.DiffusionSamplerConfig
+  num_steps_denoising: 6  # Balanced: better quality than 3, faster than 10
+  sigma_min: 0.002
+  sigma_max: 5.0
+  rho: 7
+  order: 1
+diffusion_sampler_upsampling:
+  _target_: models.diffusion.DiffusionSamplerConfig
+  num_steps_denoising: 4  # Balanced: better quality than 2, faster than 5
+  sigma_min: 0.002
+  sigma_max: 5.0
+  rho: 7
+  order: 1
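The sampler blocks above set `sigma_min`, `sigma_max`, `rho`, and `num_steps_denoising`; these are the parameters of a Karras-style noise schedule used by EDM-like diffusion samplers. That Diamond uses exactly this schedule is an assumption suggested by the `rho` parameter; the sketch below only shows what those values would produce under the standard formula:

```python
# Sketch of a Karras/EDM-style sigma schedule from the sampler parameters above (assumed, not
# confirmed from this commit): sigma_i = (s_max^(1/rho) + i/(N-1)*(s_min^(1/rho) - s_max^(1/rho)))^rho.
def karras_sigmas(num_steps: int, sigma_min: float, sigma_max: float, rho: float):
    if num_steps == 1:
        return [sigma_max]
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]

# num_steps_denoising: 6, sigma_min: 0.002, sigma_max: 5.0, rho: 7 (next-obs sampler above)
print([round(s, 4) for s in karras_sigmas(6, 0.002, 5.0, 7)])
```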
config_web.py
ADDED
@@ -0,0 +1,208 @@
+"""
+Configuration helper for web deployment
+Handles path resolution and model loading for deployment
+"""
+
+import os
+from pathlib import Path
+from typing import Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+class WebConfig:
+    """Configuration manager for web deployment"""
+
+    def __init__(self, base_path: Optional[Path] = None):
+        if base_path is None:
+            base_path = Path.cwd()
+        self.base_path = Path(base_path)
+
+    def get_config_path(self) -> Path:
+        """Get configuration directory path"""
+        # Try multiple possible locations
+        possible_paths = [
+            self.base_path / "config",
+            self.base_path / "src" / ".." / "config",
+            Path(__file__).parent / "config"
+        ]
+
+        for path in possible_paths:
+            if path.exists():
+                return path.resolve()
+
+        # Create default config directory
+        config_path = self.base_path / "config"
+        config_path.mkdir(exist_ok=True)
+        return config_path
+
+    def get_checkpoint_path(self) -> Path:
+        """Find and return the best available checkpoint"""
+        # Try different possible locations and names
+        possible_checkpoints = [
+            self.base_path / "agent_epoch_00003.pt",
+            self.base_path / "agent_epoch_00003.pt",
+            self.base_path / "checkpoints" / "agent_epoch_00003.pt",
+            self.base_path / "checkpoints" / "agent_epoch_00003.pt",
+            self.base_path / "checkpoints" / "latest.pt",
+        ]
+
+        for ckpt_path in possible_checkpoints:
+            if ckpt_path.exists():
+                logger.info(f"Found checkpoint: {ckpt_path}")
+                return ckpt_path
+
+        # If no checkpoint found, create a dummy message
+        logger.warning("No checkpoint found - you may need to download models")
+        return self.base_path / "checkpoints" / "model_not_found.pt"
+
+    def get_spawn_dir(self) -> Path:
+        """Get spawn data directory"""
+        spawn_dir = self.base_path / "csgo" / "spawn"
+        spawn_dir.mkdir(parents=True, exist_ok=True)
+
+        # Create dummy spawn data if it doesn't exist
+        spawn_subdir = spawn_dir / "0"
+        spawn_subdir.mkdir(exist_ok=True)
+
+        # Create dummy files if they don't exist
+        dummy_files = ["act.npy", "full_res.npy", "info.json", "low_res.npy", "next_act.npy"]
+        for filename in dummy_files:
+            file_path = spawn_subdir / filename
+            if not file_path.exists():
+                if filename.endswith('.npy'):
+                    import numpy as np
+                    np.save(file_path, np.zeros((1, 10)))  # Dummy array
+                elif filename.endswith('.json'):
+                    import json
+                    with open(file_path, 'w') as f:
+                        json.dump({"dummy": True}, f)
+
+        return spawn_dir
+
+    def setup_environment_variables(self):
+        """Set up environment variables for deployment"""
+        # Disable CUDA if not available (for CPU-only deployment)
+        if not self.has_cuda():
+            os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+        # Set Python path
+        python_path = str(self.base_path / "src")
+        current_path = os.environ.get("PYTHONPATH", "")
+        if python_path not in current_path:
+            os.environ["PYTHONPATH"] = f"{python_path}:{current_path}" if current_path else python_path
+
+    def has_cuda(self) -> bool:
+        """Check if CUDA is available"""
+        try:
+            import torch
+            return torch.cuda.is_available()
+
except ImportError:
|
| 101 |
+
return False
|
| 102 |
+
|
| 103 |
+
def create_default_configs(self):
|
| 104 |
+
"""Create default configuration files if they don't exist"""
|
| 105 |
+
config_dir = self.get_config_path()
|
| 106 |
+
|
| 107 |
+
# Create agent config
|
| 108 |
+
agent_dir = config_dir / "agent"
|
| 109 |
+
agent_dir.mkdir(exist_ok=True)
|
| 110 |
+
|
| 111 |
+
agent_config_path = agent_dir / "csgo.yaml"
|
| 112 |
+
if not agent_config_path.exists():
|
| 113 |
+
with open(agent_config_path, 'w') as f:
|
| 114 |
+
f.write("""_target_: agent.AgentConfig
|
| 115 |
+
|
| 116 |
+
denoiser:
|
| 117 |
+
_target_: models.diffusion.DenoiserConfig
|
| 118 |
+
sigma_data: 0.5
|
| 119 |
+
sigma_offset_noise: 0.1
|
| 120 |
+
noise_previous_obs: true
|
| 121 |
+
upsampling_factor: null
|
| 122 |
+
inner_model:
|
| 123 |
+
_target_: models.diffusion.InnerModelConfig
|
| 124 |
+
img_channels: 3
|
| 125 |
+
num_steps_conditioning: 4
|
| 126 |
+
cond_channels: 2048
|
| 127 |
+
depths: [2, 2, 2, 2]
|
| 128 |
+
channels: [128, 256, 512, 1024]
|
| 129 |
+
attn_depths: [0, 0, 1, 1]
|
| 130 |
+
|
| 131 |
+
upsampler:
|
| 132 |
+
_target_: models.diffusion.DenoiserConfig
|
| 133 |
+
sigma_data: 0.5
|
| 134 |
+
sigma_offset_noise: 0.1
|
| 135 |
+
noise_previous_obs: false
|
| 136 |
+
upsampling_factor: 5
|
| 137 |
+
inner_model:
|
| 138 |
+
_target_: models.diffusion.InnerModelConfig
|
| 139 |
+
img_channels: 3
|
| 140 |
+
num_steps_conditioning: 1
|
| 141 |
+
cond_channels: 2048
|
| 142 |
+
depths: [2, 2, 2, 2]
|
| 143 |
+
channels: [64, 64, 128, 256]
|
| 144 |
+
attn_depths: [0, 0, 0, 1]
|
| 145 |
+
|
| 146 |
+
rew_end_model: null
|
| 147 |
+
actor_critic: null
|
| 148 |
+
""")
|
| 149 |
+
|
| 150 |
+
# Create env config
|
| 151 |
+
env_dir = config_dir / "env"
|
| 152 |
+
env_dir.mkdir(exist_ok=True)
|
| 153 |
+
|
| 154 |
+
env_config_path = env_dir / "csgo.yaml"
|
| 155 |
+
if not env_config_path.exists():
|
| 156 |
+
with open(env_config_path, 'w') as f:
|
| 157 |
+
f.write("""train:
|
| 158 |
+
id: csgo
|
| 159 |
+
size: [150, 600]
|
| 160 |
+
num_actions: 51
|
| 161 |
+
path_data_low_res: /tmp/dummy_data_low_res
|
| 162 |
+
path_data_full_res: /tmp/dummy_data_full_res
|
| 163 |
+
keymap: csgo
|
| 164 |
+
""")
|
| 165 |
+
|
| 166 |
+
# Create world model env config
|
| 167 |
+
wm_env_dir = config_dir / "world_model_env"
|
| 168 |
+
wm_env_dir.mkdir(exist_ok=True)
|
| 169 |
+
|
| 170 |
+
wm_config_path = wm_env_dir / "fast.yaml"
|
| 171 |
+
if not wm_config_path.exists():
|
| 172 |
+
with open(wm_config_path, 'w') as f:
|
| 173 |
+
f.write("""_target_: envs.WorldModelEnvConfig
|
| 174 |
+
horizon: 1000
|
| 175 |
+
num_batches_to_preload: 1
|
| 176 |
+
diffusion_sampler_next_obs:
|
| 177 |
+
_target_: models.diffusion.DiffusionSamplerConfig
|
| 178 |
+
num_steps_denoising: 10
|
| 179 |
+
sigma_min: 0.002
|
| 180 |
+
sigma_max: 5.0
|
| 181 |
+
rho: 7
|
| 182 |
+
order: 1
|
| 183 |
+
diffusion_sampler_upsampling:
|
| 184 |
+
_target_: models.diffusion.DiffusionSamplerConfig
|
| 185 |
+
num_steps_denoising: 5
|
| 186 |
+
sigma_min: 0.002
|
| 187 |
+
sigma_max: 5.0
|
| 188 |
+
rho: 7
|
| 189 |
+
order: 1
|
| 190 |
+
""")
|
| 191 |
+
|
| 192 |
+
# Create trainer config
|
| 193 |
+
trainer_config_path = config_dir / "trainer.yaml"
|
| 194 |
+
if not trainer_config_path.exists():
|
| 195 |
+
with open(trainer_config_path, 'w') as f:
|
| 196 |
+
f.write("""defaults:
|
| 197 |
+
- _self_
|
| 198 |
+
- env: csgo
|
| 199 |
+
- agent: csgo
|
| 200 |
+
- world_model_env: fast
|
| 201 |
+
|
| 202 |
+
static_dataset:
|
| 203 |
+
path: /tmp/dummy_data_low_res
|
| 204 |
+
ignore_sample_weights: True
|
| 205 |
+
""")
|
| 206 |
+
|
| 207 |
+
# Global config instance
|
| 208 |
+
web_config = WebConfig()
|
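Note (editor's sketch, not part of the commit): typical use of the WebConfig helper at application start-up; all methods shown are defined in config_web.py above.

from config_web import web_config

web_config.setup_environment_variables()    # hides CUDA when unavailable, extends PYTHONPATH with src/
web_config.create_default_configs()         # writes config/{agent,env,world_model_env,trainer} if missing
ckpt_path = web_config.get_checkpoint_path()  # may be checkpoints/model_not_found.pt if nothing was downloaded
spawn_dir = web_config.get_spawn_dir()
print(ckpt_path, spawn_dir)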
csgo/spawn/0/act.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11830620c54f47d0ee6a9f904e68516980f8cd5af488572bd6e9e4815e8be52d
size 332
csgo/spawn/0/full_res.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cff6e9d7871c6f3c622f964fabbc181befeca9ef5b5a8a3f4e6cce1af79e6a8f
size 1080128
csgo/spawn/0/info.json
ADDED
@@ -0,0 +1 @@
{"original_file_id": "4001-4200/hdf5_dm_july2021_4143.hdf5", "timestep_start": 540}
csgo/spawn/0/low_res.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d775f579f104caf9e195fa12bdf302e54c3c0f8938483ace0fa2cb75b694be1
size 43328
csgo/spawn/0/next_act.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:762d9db84444e12912a8e535d10b29783db59f6a7f97579c933496354ffb4bb6
size 10328
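Note (editor's sketch, not part of the commit): the five spawn files above are Git LFS pointers; after `git lfs pull` they can be read as plain numpy/json data. The exact shapes and dtypes depend on the LFS payload and are not stated in the commit, so the comments below only reflect the file names.

import json
import numpy as np

spawn = "csgo/spawn/0"
low_res = np.load(f"{spawn}/low_res.npy")     # spawn frames for the low-res denoiser
full_res = np.load(f"{spawn}/full_res.npy")   # same frames at upsampler resolution
act = np.load(f"{spawn}/act.npy")             # encoded actions leading into the spawn point
next_act = np.load(f"{spawn}/next_act.npy")
info = json.load(open(f"{spawn}/info.json"))  # original hdf5 file id and timestep_start
print(low_res.shape, full_res.shape, act.shape, info)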
packages.txt
ADDED
@@ -0,0 +1,3 @@
build-essential
curl
git
requirements.txt
ADDED
@@ -0,0 +1,32 @@
# Core ML dependencies
torch>=1.13.0
torchvision>=0.14.0
torchaudio>=0.13.0
numpy>=1.21.0

# Configuration management
hydra-core>=1.2.0
omegaconf>=2.2.0

# Web framework for deployment
fastapi>=0.68.0
uvicorn>=0.15.0
websockets>=10.0

# Image processing
opencv-python-headless>=4.5.0
Pillow>=8.0.0

# Hugging Face integration
huggingface-hub>=0.10.0

# Data handling
h5py>=3.7.0

# Optional: for better performance
# torch-audio # if needed for audio processing

# Development dependencies (uncomment for local development)
# pytest>=6.0.0
# black>=21.0.0
# isort>=5.0.0
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
src/__pycache__/agent.cpython-310.pyc
ADDED
|
Binary file (2.94 kB). View file
|
|
|
src/__pycache__/trainer.cpython-310.pyc
ADDED
|
Binary file (14.6 kB). View file
|
|
|
src/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (13.2 kB). View file
|
|
|
src/agent.py
ADDED
|
@@ -0,0 +1,74 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn

from envs import TorchEnv, WorldModelEnv
from models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
from models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
from models.rew_end_model import RewEndModel, RewEndModelConfig
from utils import extract_state_dict


@dataclass
class AgentConfig:
    denoiser: DenoiserConfig
    upsampler: Optional[DenoiserConfig]
    rew_end_model: Optional[RewEndModelConfig]
    actor_critic: Optional[ActorCriticConfig]
    num_actions: int

    def __post_init__(self) -> None:
        self.denoiser.inner_model.num_actions = self.num_actions
        if self.upsampler is not None:
            self.upsampler.inner_model.num_actions = self.num_actions
        if self.rew_end_model is not None:
            self.rew_end_model.num_actions = self.num_actions
        if self.actor_critic is not None:
            self.actor_critic.num_actions = self.num_actions


class Agent(nn.Module):
    def __init__(self, cfg: AgentConfig) -> None:
        super().__init__()
        self.denoiser = Denoiser(cfg.denoiser)
        self.upsampler = Denoiser(cfg.upsampler) if cfg.upsampler is not None else None
        self.rew_end_model = RewEndModel(cfg.rew_end_model) if cfg.rew_end_model is not None else None
        self.actor_critic = ActorCritic(cfg.actor_critic) if cfg.actor_critic is not None else None

    @property
    def device(self):
        return self.denoiser.device

    def setup_training(
        self,
        sigma_distribution_cfg: SigmaDistributionConfig,
        sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
        actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
        rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
    ) -> None:
        self.denoiser.setup_training(sigma_distribution_cfg)
        if self.upsampler is not None:
            self.upsampler.setup_training(sigma_distribution_cfg_upsampler)
        if self.actor_critic is not None:
            self.actor_critic.setup_training(rl_env, actor_critic_loss_cfg)

    def load(
        self,
        path_to_ckpt: Path,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
    ) -> None:
        sd = torch.load(Path(path_to_ckpt), map_location=self.device)
        if load_denoiser:
            self.denoiser.load_state_dict(extract_state_dict(sd, "denoiser"))
        if load_upsampler and self.upsampler is not None:
            self.upsampler.load_state_dict(extract_state_dict(sd, "upsampler"))
        if load_rew_end_model and self.rew_end_model is not None:
            self.rew_end_model.load_state_dict(extract_state_dict(sd, "rew_end_model"))
        if load_actor_critic and self.actor_critic is not None:
            self.actor_critic.load_state_dict(extract_state_dict(sd, "actor_critic"))
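Note (editor's sketch, not part of the commit): building an Agent from config/agent/csgo.yaml and loading the checkpoint name referenced by config_web.py. Assumes src/ is on PYTHONPATH and the checkpoint has been downloaded.

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

from agent import Agent

agent_cfg = instantiate(OmegaConf.load("config/agent/csgo.yaml"), num_actions=51)
agent = Agent(agent_cfg)
agent.load("agent_epoch_00003.pt")  # rew_end_model and actor_critic are null in this config,
                                    # so only the denoiser and upsampler weights are loaded
agent.eval()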
src/coroutines/__init__.py
ADDED
@@ -0,0 +1,11 @@
from functools import wraps


def coroutine(func):
    @wraps(func)
    def primer(*args, **kwargs):
        gen = func(*args, **kwargs)
        next(gen)
        return gen

    return primer
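Note (editor's sketch, not part of the commit): the coroutine decorator primes a generator so the first .send() goes straight to the first yield. A toy example:

from coroutines import coroutine

@coroutine
def running_sum():
    total = 0
    while True:
        x = yield total  # resumes here on every .send()
        total += x

acc = running_sum()              # already advanced to the first yield by the primer
print(acc.send(2), acc.send(3))  # 2 5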
src/coroutines/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (484 Bytes). View file
|
|
|
src/coroutines/__pycache__/collector.cpython-310.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
src/coroutines/__pycache__/env_loop.cpython-310.pyc
ADDED
|
Binary file (2.31 kB). View file
|
|
|
src/coroutines/collector.py
ADDED
|
@@ -0,0 +1,126 @@
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Generator, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from . import coroutine
|
| 10 |
+
from data import Episode, Dataset
|
| 11 |
+
from envs import TorchEnv
|
| 12 |
+
from .env_loop import make_env_loop
|
| 13 |
+
from utils import Logs
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@coroutine
|
| 17 |
+
def make_collector(
|
| 18 |
+
env: TorchEnv,
|
| 19 |
+
model: nn.Module,
|
| 20 |
+
dataset: Dataset,
|
| 21 |
+
epsilon: float = 0.0,
|
| 22 |
+
reset_every_collect: bool = False,
|
| 23 |
+
verbose: bool = True,
|
| 24 |
+
) -> Generator[Logs, int, None]:
|
| 25 |
+
num_envs = env.num_envs
|
| 26 |
+
|
| 27 |
+
env_loop, buffer, episode_ids, dead = (None,) * 4
|
| 28 |
+
num_steps, num_episodes, to_log, pbar = (None,) * 4
|
| 29 |
+
|
| 30 |
+
def setup_new_collect():
|
| 31 |
+
nonlocal num_steps, num_episodes, buffer, to_log, pbar
|
| 32 |
+
num_steps = 0
|
| 33 |
+
num_episodes = 0
|
| 34 |
+
buffer = defaultdict(list)
|
| 35 |
+
to_log = []
|
| 36 |
+
pbar = tqdm(
|
| 37 |
+
total=num_to_collect.total,
|
| 38 |
+
unit=num_to_collect.unit,
|
| 39 |
+
desc=f"Collect {dataset.name}",
|
| 40 |
+
disable=not verbose,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
def reset():
|
| 44 |
+
nonlocal env_loop, episode_ids, dead
|
| 45 |
+
env_loop = make_env_loop(env, model, epsilon)
|
| 46 |
+
episode_ids = defaultdict(lambda: None)
|
| 47 |
+
dead = [None] * num_envs
|
| 48 |
+
|
| 49 |
+
num_to_collect = yield
|
| 50 |
+
setup_new_collect()
|
| 51 |
+
reset()
|
| 52 |
+
|
| 53 |
+
while True:
|
| 54 |
+
with torch.no_grad():
|
| 55 |
+
all_obs, act, rew, end, trunc, *_, [infos] = env_loop.send(1)
|
| 56 |
+
|
| 57 |
+
num_steps += num_envs
|
| 58 |
+
pbar.update(num_envs if num_to_collect.steps is not None else 0)
|
| 59 |
+
|
| 60 |
+
for i, (o, a, r, e, t) in enumerate(zip(all_obs, act, rew, end, trunc)):
|
| 61 |
+
buffer[i].append((o, a, r, e, t))
|
| 62 |
+
dead[i] = (e + t).clip(max=1).item()
|
| 63 |
+
|
| 64 |
+
num_episodes += sum(dead)
|
| 65 |
+
|
| 66 |
+
can_stop = num_to_collect.can_stop(num_steps, num_episodes)
|
| 67 |
+
|
| 68 |
+
count_dead = 0
|
| 69 |
+
for i in range(num_envs):
|
| 70 |
+
# Store incomplete episodes only when reset_every_collect is set to False (train)
|
| 71 |
+
add_to_dataset = dead[i] or (can_stop and not reset_every_collect)
|
| 72 |
+
if add_to_dataset:
|
| 73 |
+
info = {"final_observation": infos["final_observation"][count_dead]} if dead[i] else {}
|
| 74 |
+
ep = Episode(*(torch.cat(x, dim=0) for x in zip(*buffer[i])), info).to("cpu")
|
| 75 |
+
if episode_ids[i] is not None:
|
| 76 |
+
ep = dataset.load_episode(episode_ids[i]) + ep
|
| 77 |
+
episode_ids[i] = dataset.add_episode(ep, episode_id=episode_ids[i])
|
| 78 |
+
|
| 79 |
+
if dead[i]:
|
| 80 |
+
to_log.append(
|
| 81 |
+
{
|
| 82 |
+
f"{dataset.name}/episode_id": episode_ids[i],
|
| 83 |
+
**ep.compute_metrics(),
|
| 84 |
+
}
|
| 85 |
+
)
|
| 86 |
+
buffer[i] = []
|
| 87 |
+
episode_ids[i] = None
|
| 88 |
+
pbar.update(1 if num_to_collect.episodes is not None else 0)
|
| 89 |
+
|
| 90 |
+
count_dead += dead[i]
|
| 91 |
+
|
| 92 |
+
if can_stop:
|
| 93 |
+
pbar.close()
|
| 94 |
+
metrics = {
|
| 95 |
+
"num_steps": dataset.num_steps,
|
| 96 |
+
"counts/rew_-1": dataset.counts_rew[0],
|
| 97 |
+
"counts/rew__0": dataset.counts_rew[1],
|
| 98 |
+
"counts/rew_+1": dataset.counts_rew[2],
|
| 99 |
+
"counts/end_0": dataset.counts_end[0],
|
| 100 |
+
"counts/end_1": dataset.counts_end[1],
|
| 101 |
+
}
|
| 102 |
+
to_log.append({f"{dataset.name}/{k}": v for k, v in metrics.items()})
|
| 103 |
+
num_to_collect = yield to_log
|
| 104 |
+
setup_new_collect()
|
| 105 |
+
if reset_every_collect:
|
| 106 |
+
reset()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass
|
| 110 |
+
class NumToCollect:
|
| 111 |
+
steps: Optional[int] = None
|
| 112 |
+
episodes: Optional[int] = None
|
| 113 |
+
|
| 114 |
+
def __post_init__(self) -> None:
|
| 115 |
+
assert (self.steps is None) != (self.episodes is None)
|
| 116 |
+
|
| 117 |
+
def can_stop(self, num_steps: int, num_episodes: int) -> bool:
|
| 118 |
+
return num_steps >= self.steps if self.steps is not None else num_episodes >= self.episodes
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def unit(self) -> str:
|
| 122 |
+
return "steps" if self.steps is not None else "eps"
|
| 123 |
+
|
| 124 |
+
@property
|
| 125 |
+
def total(self) -> int:
|
| 126 |
+
return self.steps if self.steps is not None else self.episodes
|
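Note (editor's sketch, not part of the commit): driving the collector coroutine. env, actor_critic and dataset are assumed to be built elsewhere (a TorchEnv, a model exposing predict_act_value, and a data.Dataset).

from coroutines.collector import make_collector, NumToCollect

collector = make_collector(env, actor_critic, dataset, epsilon=0.01)
logs = collector.send(NumToCollect(steps=1000))   # collect roughly 1000 env steps into the dataset
logs = collector.send(NumToCollect(episodes=5))   # or wait for 5 finished episodes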
src/coroutines/env_loop.py
ADDED
|
@@ -0,0 +1,74 @@
| 1 |
+
import random
|
| 2 |
+
from typing import Generator, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.distributions.categorical import Categorical
|
| 7 |
+
|
| 8 |
+
from . import coroutine
|
| 9 |
+
from envs import TorchEnv, WorldModelEnv
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@coroutine
|
| 13 |
+
def make_env_loop(
|
| 14 |
+
env: Union[TorchEnv, WorldModelEnv], model: nn.Module, epsilon: float = 0.0
|
| 15 |
+
) -> Generator[Tuple[torch.Tensor, ...], int, None]:
|
| 16 |
+
num_steps = yield
|
| 17 |
+
|
| 18 |
+
hx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
|
| 19 |
+
cx = torch.zeros(env.num_envs, model.lstm_dim, device=model.device)
|
| 20 |
+
|
| 21 |
+
seed = random.randint(0, 2**31 - 1)
|
| 22 |
+
obs, _ = env.reset(seed=[seed + i for i in range(env.num_envs)])
|
| 23 |
+
|
| 24 |
+
while True:
|
| 25 |
+
hx, cx = hx.detach(), cx.detach()
|
| 26 |
+
all_ = []
|
| 27 |
+
infos = []
|
| 28 |
+
n = 0
|
| 29 |
+
|
| 30 |
+
while n < num_steps:
|
| 31 |
+
logits_act, val, (hx, cx) = model.predict_act_value(obs, (hx, cx))
|
| 32 |
+
act = Categorical(logits=logits_act).sample()
|
| 33 |
+
|
| 34 |
+
if random.random() < epsilon:
|
| 35 |
+
act = torch.randint(low=0, high=env.num_actions, size=(obs.size(0),), device=obs.device)
|
| 36 |
+
|
| 37 |
+
next_obs, rew, end, trunc, info = env.step(act)
|
| 38 |
+
|
| 39 |
+
if n > 0:
|
| 40 |
+
val_bootstrap = val.detach().clone()
|
| 41 |
+
if dead.any():
|
| 42 |
+
val_bootstrap[dead] = val_final_obs
|
| 43 |
+
all_[-1][-1] = val_bootstrap
|
| 44 |
+
|
| 45 |
+
dead = torch.logical_or(end, trunc)
|
| 46 |
+
|
| 47 |
+
if dead.any():
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
_, val_final_obs, _ = model.predict_act_value(info["final_observation"], (hx[dead], cx[dead]))
|
| 50 |
+
reset_gate = 1 - dead.float().unsqueeze(1)
|
| 51 |
+
hx = hx * reset_gate
|
| 52 |
+
cx = cx * reset_gate
|
| 53 |
+
if "burnin_obs" in info:
|
| 54 |
+
burnin_obs = info["burnin_obs"]
|
| 55 |
+
for i in range(burnin_obs.size(1)):
|
| 56 |
+
_, _, (hx[dead], cx[dead]) = model.predict_act_value(burnin_obs[:, i], (hx[dead], cx[dead]))
|
| 57 |
+
|
| 58 |
+
all_.append([obs, act, rew, end, trunc, logits_act, val, None])
|
| 59 |
+
infos.append(info)
|
| 60 |
+
|
| 61 |
+
obs = next_obs
|
| 62 |
+
n += 1
|
| 63 |
+
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
_, val_bootstrap, _ = model.predict_act_value(next_obs, (hx, cx)) # do not update hx/cx
|
| 66 |
+
|
| 67 |
+
if dead.any():
|
| 68 |
+
val_bootstrap[dead] = val_final_obs
|
| 69 |
+
|
| 70 |
+
all_[-1][-1] = val_bootstrap
|
| 71 |
+
|
| 72 |
+
all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap = (torch.stack(x, dim=1) for x in zip(*all_))
|
| 73 |
+
|
| 74 |
+
num_steps = yield all_obs, act, rew, end, trunc, logits_act, val, val_bootstrap, infos
|
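Note (editor's sketch, not part of the commit): the env loop is also a primed coroutine; each .send(n) rolls the policy out n steps per environment and returns stacked tensors. env and actor_critic are assumed to exist, as above.

from coroutines.env_loop import make_env_loop

loop = make_env_loop(env, actor_critic, epsilon=0.0)
obs, act, rew, end, trunc, logits, val, val_bootstrap, infos = loop.send(32)
print(obs.shape)  # (num_envs, 32, ...)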
src/csgo/__init__.py
ADDED
|
File without changes
|
src/csgo/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (148 Bytes). View file
|
|
|
src/csgo/__pycache__/action_processing.cpython-310.pyc
ADDED
|
Binary file (5.76 kB). View file
|
|
|
src/csgo/__pycache__/keymap.cpython-310.pyc
ADDED
|
Binary file (622 Bytes). View file
|
|
|
src/csgo/__pycache__/web_action_processing.cpython-310.pyc
ADDED
|
Binary file (4.71 kB). View file
|
|
|
src/csgo/action_processing.py
ADDED
|
@@ -0,0 +1,230 @@
| 1 |
+
"""
|
| 2 |
+
Credits: some parts are taken and modified from the file `config.py` from https://github.com/TeaPearce/Counter-Strike_Behavioural_Cloning/
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Dict, List, Set, Tuple
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pygame
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from .keymap import CSGO_FORBIDDEN_COMBINATIONS, CSGO_KEYMAP
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class CSGOAction:
|
| 17 |
+
keys: List[int]
|
| 18 |
+
mouse_x: float
|
| 19 |
+
mouse_y: float
|
| 20 |
+
l_click: bool
|
| 21 |
+
r_click: bool
|
| 22 |
+
|
| 23 |
+
def __post_init__(self) -> None:
|
| 24 |
+
self.keys = filter_keys_pressed_forbidden(self.keys)
|
| 25 |
+
self.process_mouse()
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def key_names(self) -> List[str]:
|
| 29 |
+
return [pygame.key.name(key) for key in self.keys]
|
| 30 |
+
|
| 31 |
+
def process_mouse(self) -> None:
|
| 32 |
+
# Clip and match mouse to closest in list of possibles
|
| 33 |
+
x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
|
| 34 |
+
y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
|
| 35 |
+
self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda x_: abs(x_ - x))
|
| 36 |
+
self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda x_: abs(x_ - y))
|
| 37 |
+
|
| 38 |
+
# Use arrows to override mouse movements
|
| 39 |
+
for key in self.key_names:
|
| 40 |
+
if key == "left":
|
| 41 |
+
self.mouse_x = -60
|
| 42 |
+
elif key == "right":
|
| 43 |
+
self.mouse_x = +60
|
| 44 |
+
elif key == "up":
|
| 45 |
+
self.mouse_y = -50
|
| 46 |
+
elif key == "down":
|
| 47 |
+
self.mouse_y = +50
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def print_csgo_action(action: CSGOAction) -> Tuple[str]:
|
| 51 |
+
action_names = [CSGO_KEYMAP[k] for k in action.keys] if len(action.keys) > 0 else []
|
| 52 |
+
action_names = [x for x in action_names if not x.startswith("camera_")]
|
| 53 |
+
keys = " + ".join(action_names)
|
| 54 |
+
mouse = str((action.mouse_x, action.mouse_y)) * (action.mouse_x != 0 or action.mouse_y != 0)
|
| 55 |
+
clicks = "L" * action.l_click + " + " * (action.l_click and action.r_click) + "R" * action.r_click
|
| 56 |
+
return keys, mouse, clicks
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
MOUSE_X_POSSIBLES = [
|
| 60 |
+
-1000,
|
| 61 |
+
-500,
|
| 62 |
+
-300,
|
| 63 |
+
-200,
|
| 64 |
+
-100,
|
| 65 |
+
-60,
|
| 66 |
+
-30,
|
| 67 |
+
-20,
|
| 68 |
+
-10,
|
| 69 |
+
-4,
|
| 70 |
+
-2,
|
| 71 |
+
0,
|
| 72 |
+
2,
|
| 73 |
+
4,
|
| 74 |
+
10,
|
| 75 |
+
20,
|
| 76 |
+
30,
|
| 77 |
+
60,
|
| 78 |
+
100,
|
| 79 |
+
200,
|
| 80 |
+
300,
|
| 81 |
+
500,
|
| 82 |
+
1000,
|
| 83 |
+
]
|
| 84 |
+
|
| 85 |
+
MOUSE_Y_POSSIBLES = [
|
| 86 |
+
-200,
|
| 87 |
+
-100,
|
| 88 |
+
-50,
|
| 89 |
+
-20,
|
| 90 |
+
-10,
|
| 91 |
+
-4,
|
| 92 |
+
-2,
|
| 93 |
+
0,
|
| 94 |
+
2,
|
| 95 |
+
4,
|
| 96 |
+
10,
|
| 97 |
+
20,
|
| 98 |
+
50,
|
| 99 |
+
100,
|
| 100 |
+
200,
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
MOUSE_X_LIM = (MOUSE_X_POSSIBLES[0], MOUSE_X_POSSIBLES[-1])
|
| 104 |
+
MOUSE_Y_LIM = (MOUSE_Y_POSSIBLES[0], MOUSE_Y_POSSIBLES[-1])
|
| 105 |
+
N_KEYS = 11 # number of keyboard outputs, w,s,a,d,space,ctrl,shift,1,2,3,r
|
| 106 |
+
N_CLICKS = 2 # number of mouse buttons, left, right
|
| 107 |
+
N_MOUSE_X = len(MOUSE_X_POSSIBLES) # number of outputs on mouse x axis
|
| 108 |
+
N_MOUSE_Y = len(MOUSE_Y_POSSIBLES) # number of outputs on mouse y axis
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def encode_csgo_action(csgo_action: CSGOAction, device: torch.device) -> torch.Tensor:
|
| 112 |
+
|
| 113 |
+
# mouse_x = csgo_action.mouse_x
|
| 114 |
+
# mouse_y = csgo_action.mouse_y
|
| 115 |
+
|
| 116 |
+
keys_pressed_onehot = np.zeros(N_KEYS)
|
| 117 |
+
mouse_x_onehot = np.zeros(N_MOUSE_X)
|
| 118 |
+
mouse_y_onehot = np.zeros(N_MOUSE_Y)
|
| 119 |
+
l_click_onehot = np.zeros(1)
|
| 120 |
+
r_click_onehot = np.zeros(1)
|
| 121 |
+
|
| 122 |
+
for key in csgo_action.key_names:
|
| 123 |
+
if key == "w":
|
| 124 |
+
keys_pressed_onehot[0] = 1
|
| 125 |
+
elif key == "a":
|
| 126 |
+
keys_pressed_onehot[1] = 1
|
| 127 |
+
elif key == "s":
|
| 128 |
+
keys_pressed_onehot[2] = 1
|
| 129 |
+
elif key == "d":
|
| 130 |
+
keys_pressed_onehot[3] = 1
|
| 131 |
+
elif key == "space":
|
| 132 |
+
keys_pressed_onehot[4] = 1
|
| 133 |
+
elif key == "left ctrl":
|
| 134 |
+
keys_pressed_onehot[5] = 1
|
| 135 |
+
elif key == "left shift":
|
| 136 |
+
keys_pressed_onehot[6] = 1
|
| 137 |
+
elif key == "1":
|
| 138 |
+
keys_pressed_onehot[7] = 1
|
| 139 |
+
elif key == "2":
|
| 140 |
+
keys_pressed_onehot[8] = 1
|
| 141 |
+
elif key == "3":
|
| 142 |
+
keys_pressed_onehot[9] = 1
|
| 143 |
+
elif key == "r":
|
| 144 |
+
keys_pressed_onehot[10] = 1
|
| 145 |
+
|
| 146 |
+
l_click_onehot[0] = int(csgo_action.l_click)
|
| 147 |
+
r_click_onehot[0] = int(csgo_action.r_click)
|
| 148 |
+
|
| 149 |
+
mouse_x_onehot[MOUSE_X_POSSIBLES.index(csgo_action.mouse_x)] = 1
|
| 150 |
+
mouse_y_onehot[MOUSE_Y_POSSIBLES.index(csgo_action.mouse_y)] = 1
|
| 151 |
+
|
| 152 |
+
assert mouse_x_onehot.sum() == 1
|
| 153 |
+
assert mouse_y_onehot.sum() == 1
|
| 154 |
+
|
| 155 |
+
return torch.tensor(
|
| 156 |
+
np.concatenate((
|
| 157 |
+
keys_pressed_onehot,
|
| 158 |
+
l_click_onehot,
|
| 159 |
+
r_click_onehot,
|
| 160 |
+
mouse_x_onehot,
|
| 161 |
+
mouse_y_onehot,
|
| 162 |
+
)),
|
| 163 |
+
device=device,
|
| 164 |
+
dtype=torch.float32,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def decode_csgo_action(y_preds: torch.Tensor) -> CSGOAction:
|
| 169 |
+
y_preds = y_preds.squeeze()
|
| 170 |
+
keys_pred = y_preds[0:N_KEYS]
|
| 171 |
+
l_click_pred = y_preds[N_KEYS : N_KEYS + 1]
|
| 172 |
+
r_click_pred = y_preds[N_KEYS + 1 : N_KEYS + N_CLICKS]
|
| 173 |
+
mouse_x_pred = y_preds[N_KEYS + N_CLICKS : N_KEYS + N_CLICKS + N_MOUSE_X]
|
| 174 |
+
mouse_y_pred = y_preds[
|
| 175 |
+
N_KEYS + N_CLICKS + N_MOUSE_X : N_KEYS + N_CLICKS + N_MOUSE_X + N_MOUSE_Y
|
| 176 |
+
]
|
| 177 |
+
|
| 178 |
+
keys_pressed = []
|
| 179 |
+
keys_pressed_onehot = np.round(keys_pred)
|
| 180 |
+
if keys_pressed_onehot[0] == 1:
|
| 181 |
+
keys_pressed.append("w")
|
| 182 |
+
if keys_pressed_onehot[1] == 1:
|
| 183 |
+
keys_pressed.append("a")
|
| 184 |
+
if keys_pressed_onehot[2] == 1:
|
| 185 |
+
keys_pressed.append("s")
|
| 186 |
+
if keys_pressed_onehot[3] == 1:
|
| 187 |
+
keys_pressed.append("d")
|
| 188 |
+
if keys_pressed_onehot[4] == 1:
|
| 189 |
+
keys_pressed.append("space")
|
| 190 |
+
if keys_pressed_onehot[5] == 1:
|
| 191 |
+
keys_pressed.append("left ctrl")
|
| 192 |
+
if keys_pressed_onehot[6] == 1:
|
| 193 |
+
keys_pressed.append("left shift")
|
| 194 |
+
if keys_pressed_onehot[7] == 1:
|
| 195 |
+
keys_pressed.append("1")
|
| 196 |
+
if keys_pressed_onehot[8] == 1:
|
| 197 |
+
keys_pressed.append("2")
|
| 198 |
+
if keys_pressed_onehot[9] == 1:
|
| 199 |
+
keys_pressed.append("3")
|
| 200 |
+
if keys_pressed_onehot[10] == 1:
|
| 201 |
+
keys_pressed.append("r")
|
| 202 |
+
|
| 203 |
+
l_click = int(np.round(l_click_pred))
|
| 204 |
+
r_click = int(np.round(r_click_pred))
|
| 205 |
+
|
| 206 |
+
id = np.argmax(mouse_x_pred)
|
| 207 |
+
mouse_x = MOUSE_X_POSSIBLES[id]
|
| 208 |
+
id = np.argmax(mouse_y_pred)
|
| 209 |
+
mouse_y = MOUSE_Y_POSSIBLES[id]
|
| 210 |
+
|
| 211 |
+
keys_pressed = [pygame.key.key_code(x) for x in keys_pressed]
|
| 212 |
+
|
| 213 |
+
return CSGOAction(keys_pressed, mouse_x, mouse_y, bool(l_click), bool(r_click))
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def filter_keys_pressed_forbidden(keys_pressed: List[int], keymap: Dict[int, str] = CSGO_KEYMAP, forbidden_combinations: List[Set[str]] = CSGO_FORBIDDEN_COMBINATIONS) -> List[int]:
|
| 217 |
+
keys = set()
|
| 218 |
+
names = set()
|
| 219 |
+
for key in keys_pressed:
|
| 220 |
+
if key not in keymap:
|
| 221 |
+
continue
|
| 222 |
+
name = keymap[key]
|
| 223 |
+
keys.add(key)
|
| 224 |
+
names.add(name)
|
| 225 |
+
for forbidden in forbidden_combinations:
|
| 226 |
+
if forbidden.issubset(names):
|
| 227 |
+
keys.remove(key)
|
| 228 |
+
names.remove(name)
|
| 229 |
+
break
|
| 230 |
+
return list(filter(lambda key: key in keys, keys_pressed))
|
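Note (editor's sketch, not part of the commit): an encode/decode round trip. The 11 key bits + 2 click bits + 23 mouse-x bins + 15 mouse-y bins give the 51-dimensional action vector that matches num_actions: 51 in config/env/csgo.yaml.

import pygame
import torch

from csgo.action_processing import CSGOAction, encode_csgo_action, decode_csgo_action

pygame.init()  # pygame key codes/names are used internally
action = CSGOAction(keys=[pygame.K_w, pygame.K_LSHIFT], mouse_x=25, mouse_y=-3, l_click=True, r_click=False)
vec = encode_csgo_action(action, device=torch.device("cpu"))
print(vec.shape)                          # torch.Size([51])
print(decode_csgo_action(vec).key_names)  # ['w', 'left shift']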
src/csgo/keymap.py
ADDED
@@ -0,0 +1,33 @@
import pygame


CSGO_KEYMAP = {
    pygame.K_w: "up",
    pygame.K_d: "right",
    pygame.K_a: "left",
    pygame.K_s: "down",
    pygame.K_SPACE: "jump",
    pygame.K_LCTRL: "crouch",
    pygame.K_LSHIFT: "walk",
    pygame.K_1: "weapon1",
    pygame.K_2: "weapon2",
    pygame.K_3: "weapon3",
    pygame.K_r: "reload",

    # Override mouse movement with arrows
    pygame.K_UP: "camera_up",
    pygame.K_RIGHT: "camera_right",
    pygame.K_LEFT: "camera_left",
    pygame.K_DOWN: "camera_down",
}


CSGO_FORBIDDEN_COMBINATIONS = [
    {"up", "down"},
    {"left", "right"},
    {"weapon1", "weapon2"},
    {"weapon1", "weapon3"},
    {"weapon2", "weapon3"},
    {"camera_up", "camera_down"},
    {"camera_left", "camera_right"},
]
src/csgo/web_action_processing.py
ADDED
|
@@ -0,0 +1,167 @@
| 1 |
+
"""
|
| 2 |
+
Web-compatible action processing for CSGO actions
|
| 3 |
+
Converts web keyboard inputs to CSGO actions without pygame dependency
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Dict, List, Set, Tuple
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
# Web key code to CSGO action mapping
|
| 13 |
+
WEB_KEYMAP = {
|
| 14 |
+
'KeyW': "up",
|
| 15 |
+
'KeyD': "right",
|
| 16 |
+
'KeyA': "left",
|
| 17 |
+
'KeyS': "down",
|
| 18 |
+
'Space': "jump",
|
| 19 |
+
'ControlLeft': "crouch",
|
| 20 |
+
'ShiftLeft': "walk",
|
| 21 |
+
'Digit1': "weapon1",
|
| 22 |
+
'Digit2': "weapon2",
|
| 23 |
+
'Digit3': "weapon3",
|
| 24 |
+
'KeyR': "reload",
|
| 25 |
+
'ArrowUp': "camera_up",
|
| 26 |
+
'ArrowRight': "camera_right",
|
| 27 |
+
'ArrowLeft': "camera_left",
|
| 28 |
+
'ArrowDown': "camera_down",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
# Forbidden key combinations (same logic as original)
|
| 32 |
+
WEB_FORBIDDEN_COMBINATIONS = [
|
| 33 |
+
{"up", "down"},
|
| 34 |
+
{"left", "right"},
|
| 35 |
+
{"weapon1", "weapon2"},
|
| 36 |
+
{"weapon1", "weapon3"},
|
| 37 |
+
{"weapon2", "weapon3"},
|
| 38 |
+
{"camera_up", "camera_down"},
|
| 39 |
+
{"camera_left", "camera_right"},
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class WebCSGOAction:
|
| 44 |
+
"""Web-compatible CSGO action without pygame dependencies"""
|
| 45 |
+
key_names: List[str] # Use string names instead of pygame key codes
|
| 46 |
+
mouse_x: float
|
| 47 |
+
mouse_y: float
|
| 48 |
+
l_click: bool
|
| 49 |
+
r_click: bool
|
| 50 |
+
|
| 51 |
+
def __post_init__(self) -> None:
|
| 52 |
+
self.key_names = filter_web_keys_forbidden(self.key_names)
|
| 53 |
+
self.process_mouse()
|
| 54 |
+
|
| 55 |
+
def process_mouse(self) -> None:
|
| 56 |
+
"""Process mouse movement with discretization"""
|
| 57 |
+
# Import mouse constants
|
| 58 |
+
from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, MOUSE_X_LIM, MOUSE_Y_LIM
|
| 59 |
+
|
| 60 |
+
# Clip and match mouse to closest in list of possibles
|
| 61 |
+
x = np.clip(self.mouse_x, MOUSE_X_LIM[0], MOUSE_X_LIM[1])
|
| 62 |
+
y = np.clip(self.mouse_y, MOUSE_Y_LIM[0], MOUSE_Y_LIM[1])
|
| 63 |
+
self.mouse_x = min(MOUSE_X_POSSIBLES, key=lambda x_: abs(x_ - x))
|
| 64 |
+
self.mouse_y = min(MOUSE_Y_POSSIBLES, key=lambda x_: abs(x_ - y))
|
| 65 |
+
|
| 66 |
+
# Use arrow keys to override mouse movements
|
| 67 |
+
for key_name in self.key_names:
|
| 68 |
+
if key_name == "camera_left":
|
| 69 |
+
self.mouse_x = -60
|
| 70 |
+
elif key_name == "camera_right":
|
| 71 |
+
self.mouse_x = +60
|
| 72 |
+
elif key_name == "camera_up":
|
| 73 |
+
self.mouse_y = -50
|
| 74 |
+
elif key_name == "camera_down":
|
| 75 |
+
self.mouse_y = +50
|
| 76 |
+
|
| 77 |
+
def filter_web_keys_forbidden(key_names: List[str]) -> List[str]:
|
| 78 |
+
"""Filter out forbidden key combinations"""
|
| 79 |
+
names = set(key_names)
|
| 80 |
+
filtered_names = []
|
| 81 |
+
|
| 82 |
+
for key_name in key_names:
|
| 83 |
+
# Check if adding this key would create a forbidden combination
|
| 84 |
+
test_names = set(filtered_names + [key_name])
|
| 85 |
+
is_forbidden = False
|
| 86 |
+
|
| 87 |
+
for forbidden in WEB_FORBIDDEN_COMBINATIONS:
|
| 88 |
+
if forbidden.issubset(test_names):
|
| 89 |
+
is_forbidden = True
|
| 90 |
+
break
|
| 91 |
+
|
| 92 |
+
if not is_forbidden:
|
| 93 |
+
filtered_names.append(key_name)
|
| 94 |
+
|
| 95 |
+
return filtered_names
|
| 96 |
+
|
| 97 |
+
def web_keys_to_csgo_action_names(pressed_web_keys: Set[str]) -> List[str]:
|
| 98 |
+
"""Convert set of pressed web keys to CSGO action names"""
|
| 99 |
+
action_names = []
|
| 100 |
+
for web_key in pressed_web_keys:
|
| 101 |
+
if web_key in WEB_KEYMAP:
|
| 102 |
+
action_names.append(WEB_KEYMAP[web_key])
|
| 103 |
+
return action_names
|
| 104 |
+
|
| 105 |
+
def encode_web_csgo_action(web_action: WebCSGOAction, device: torch.device) -> torch.Tensor:
|
| 106 |
+
"""Encode web CSGO action to tensor format (compatible with original encoding)"""
|
| 107 |
+
from .action_processing import MOUSE_X_POSSIBLES, MOUSE_Y_POSSIBLES, N_KEYS, N_CLICKS, N_MOUSE_X, N_MOUSE_Y
|
| 108 |
+
|
| 109 |
+
keys_pressed_onehot = np.zeros(N_KEYS)
|
| 110 |
+
mouse_x_onehot = np.zeros(N_MOUSE_X)
|
| 111 |
+
mouse_y_onehot = np.zeros(N_MOUSE_Y)
|
| 112 |
+
l_click_onehot = np.zeros(1)
|
| 113 |
+
r_click_onehot = np.zeros(1)
|
| 114 |
+
|
| 115 |
+
# Map action names to one-hot encoding
|
| 116 |
+
for action_name in web_action.key_names:
|
| 117 |
+
if action_name == "up": # w key
|
| 118 |
+
keys_pressed_onehot[0] = 1
|
| 119 |
+
elif action_name == "left": # a key
|
| 120 |
+
keys_pressed_onehot[1] = 1
|
| 121 |
+
elif action_name == "down": # s key
|
| 122 |
+
keys_pressed_onehot[2] = 1
|
| 123 |
+
elif action_name == "right": # d key
|
| 124 |
+
keys_pressed_onehot[3] = 1
|
| 125 |
+
elif action_name == "jump": # space
|
| 126 |
+
keys_pressed_onehot[4] = 1
|
| 127 |
+
elif action_name == "crouch": # ctrl
|
| 128 |
+
keys_pressed_onehot[5] = 1
|
| 129 |
+
elif action_name == "walk": # shift
|
| 130 |
+
keys_pressed_onehot[6] = 1
|
| 131 |
+
elif action_name == "weapon1": # 1
|
| 132 |
+
keys_pressed_onehot[7] = 1
|
| 133 |
+
elif action_name == "weapon2": # 2
|
| 134 |
+
keys_pressed_onehot[8] = 1
|
| 135 |
+
elif action_name == "weapon3": # 3
|
| 136 |
+
keys_pressed_onehot[9] = 1
|
| 137 |
+
elif action_name == "reload": # r
|
| 138 |
+
keys_pressed_onehot[10] = 1
|
| 139 |
+
|
| 140 |
+
l_click_onehot[0] = int(web_action.l_click)
|
| 141 |
+
r_click_onehot[0] = int(web_action.r_click)
|
| 142 |
+
|
| 143 |
+
mouse_x_onehot[MOUSE_X_POSSIBLES.index(web_action.mouse_x)] = 1
|
| 144 |
+
mouse_y_onehot[MOUSE_Y_POSSIBLES.index(web_action.mouse_y)] = 1
|
| 145 |
+
|
| 146 |
+
assert mouse_x_onehot.sum() == 1
|
| 147 |
+
assert mouse_y_onehot.sum() == 1
|
| 148 |
+
|
| 149 |
+
return torch.tensor(
|
| 150 |
+
np.concatenate((
|
| 151 |
+
keys_pressed_onehot,
|
| 152 |
+
l_click_onehot,
|
| 153 |
+
r_click_onehot,
|
| 154 |
+
mouse_x_onehot,
|
| 155 |
+
mouse_y_onehot,
|
| 156 |
+
)),
|
| 157 |
+
device=device,
|
| 158 |
+
dtype=torch.float32,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
def print_web_csgo_action(action: WebCSGOAction) -> Tuple[str, str, str]:
|
| 162 |
+
"""Print web CSGO action in readable format"""
|
| 163 |
+
action_names = [name for name in action.key_names if not name.startswith("camera_")]
|
| 164 |
+
keys = " + ".join(action_names)
|
| 165 |
+
mouse = str((action.mouse_x, action.mouse_y)) * (action.mouse_x != 0 or action.mouse_y != 0)
|
| 166 |
+
clicks = "L" * action.l_click + " + " * (action.l_click and action.r_click) + "R" * action.r_click
|
| 167 |
+
return keys, mouse, clicks
|
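Note (editor's sketch, not part of the commit): turning a set of browser KeyboardEvent.code values into the same 51-dim action tensor. Set iteration order may vary; the mouse constants are imported from action_processing, so pygame must still be installed.

import torch

from csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names, encode_web_csgo_action

pressed = {"KeyW", "ShiftLeft", "ArrowRight"}   # codes from browser "keydown" events
names = web_keys_to_csgo_action_names(pressed)  # e.g. ["up", "walk", "camera_right"]
action = WebCSGOAction(key_names=names, mouse_x=0.0, mouse_y=0.0, l_click=False, r_click=False)
vec = encode_web_csgo_action(action, device=torch.device("cpu"))
print(action.mouse_x, vec.shape)                # camera_right overrides mouse_x to 60; torch.Size([51])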
src/data/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .batch import Batch
from .batch_sampler import BatchSampler
from .dataset import Dataset, CSGOHdf5Dataset
from .episode import Episode
from .segment import Segment, SegmentId
from .utils import collate_segments_to_batch, DatasetTraverser, make_segment
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (501 Bytes). View file
|
|
|
src/data/__pycache__/batch.cpython-310.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
src/data/__pycache__/batch_sampler.cpython-310.pyc
ADDED
|
Binary file (2.81 kB). View file
|
|
|
src/data/__pycache__/dataset.cpython-310.pyc
ADDED
|
Binary file (9.52 kB). View file
|
|
|
src/data/__pycache__/episode.cpython-310.pyc
ADDED
|
Binary file (3.76 kB). View file
|
|
|
src/data/__pycache__/segment.cpython-310.pyc
ADDED
|
Binary file (1.16 kB). View file
|
|
|
src/data/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
src/data/batch.py
ADDED
@@ -0,0 +1,25 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List

import torch

from .segment import SegmentId


@dataclass
class Batch:
    obs: torch.ByteTensor
    act: torch.LongTensor
    rew: torch.FloatTensor
    end: torch.LongTensor
    trunc: torch.LongTensor
    mask_padding: torch.BoolTensor
    info: List[Dict[str, Any]]
    segment_ids: List[SegmentId]

    def pin_memory(self) -> Batch:
        return Batch(**{k: v if k in ("segment_ids", "info") else v.pin_memory() for k, v in self.__dict__.items()})

    def to(self, device: torch.device) -> Batch:
        return Batch(**{k: v if k in ("segment_ids", "info") else v.to(device) for k, v in self.__dict__.items()})
src/data/batch_sampler.py
ADDED
|
@@ -0,0 +1,72 @@
| 1 |
+
from typing import Generator, List, Optional
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from .dataset import CSGOHdf5Dataset, Dataset
|
| 7 |
+
from .segment import SegmentId
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class BatchSampler(torch.utils.data.Sampler):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
dataset: Dataset,
|
| 14 |
+
rank: int,
|
| 15 |
+
world_size: int,
|
| 16 |
+
batch_size: int,
|
| 17 |
+
seq_length: int,
|
| 18 |
+
sample_weights: Optional[List[float]] = None,
|
| 19 |
+
can_sample_beyond_end: bool = False,
|
| 20 |
+
) -> None:
|
| 21 |
+
super().__init__(dataset)
|
| 22 |
+
assert isinstance(dataset, (Dataset, CSGOHdf5Dataset))
|
| 23 |
+
self.dataset = dataset
|
| 24 |
+
self.rank = rank
|
| 25 |
+
self.world_size = world_size
|
| 26 |
+
self.sample_weights = sample_weights
|
| 27 |
+
self.batch_size = batch_size
|
| 28 |
+
self.seq_length = seq_length
|
| 29 |
+
self.can_sample_beyond_end = can_sample_beyond_end
|
| 30 |
+
|
| 31 |
+
def __len__(self):
|
| 32 |
+
raise NotImplementedError
|
| 33 |
+
|
| 34 |
+
def __iter__(self) -> Generator[List[SegmentId], None, None]:
|
| 35 |
+
while True:
|
| 36 |
+
yield self.sample()
|
| 37 |
+
|
| 38 |
+
def sample(self) -> List[SegmentId]:
|
| 39 |
+
num_episodes = self.dataset.num_episodes
|
| 40 |
+
|
| 41 |
+
if (self.sample_weights is None) or num_episodes < len(self.sample_weights):
|
| 42 |
+
weights = self.dataset.lengths / self.dataset.num_steps
|
| 43 |
+
else:
|
| 44 |
+
weights = self.sample_weights
|
| 45 |
+
num_weights = len(self.sample_weights)
|
| 46 |
+
assert all([0 <= x <= 1 for x in weights]) and sum(weights) == 1
|
| 47 |
+
sizes = [
|
| 48 |
+
num_episodes // num_weights + (num_episodes % num_weights) * (i == num_weights - 1)
|
| 49 |
+
for i in range(num_weights)
|
| 50 |
+
]
|
| 51 |
+
weights = [w / s for (w, s) in zip(weights, sizes) for _ in range(s)]
|
| 52 |
+
|
| 53 |
+
episodes_partition = np.arange(self.rank, num_episodes, self.world_size)
|
| 54 |
+
weights = np.array(weights[self.rank::self.world_size])
|
| 55 |
+
max_eps = self.batch_size
|
| 56 |
+
episode_ids = np.random.choice(episodes_partition, size=max_eps, replace=True, p=weights / weights.sum())
|
| 57 |
+
episode_ids = episode_ids.repeat(self.batch_size // max_eps)
|
| 58 |
+
timesteps = np.random.randint(low=0, high=self.dataset.lengths[episode_ids])
|
| 59 |
+
|
| 60 |
+
# padding allowed, both before start and after end
|
| 61 |
+
if self.can_sample_beyond_end:
|
| 62 |
+
starts = timesteps - np.random.randint(0, self.seq_length, len(timesteps))
|
| 63 |
+
stops = starts + self.seq_length
|
| 64 |
+
|
| 65 |
+
# padding allowed only before start
|
| 66 |
+
else:
|
| 67 |
+
stops = np.minimum(
|
| 68 |
+
self.dataset.lengths[episode_ids], timesteps + 1 + np.random.randint(0, self.seq_length, len(timesteps))
|
| 69 |
+
)
|
| 70 |
+
starts = stops - self.seq_length
|
| 71 |
+
|
| 72 |
+
return [SegmentId(*x) for x in zip(episode_ids, starts, stops)]
|
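Note (editor's sketch, not part of the commit): the sampler yields lists of SegmentId, so it plugs into a DataLoader as a batch_sampler together with collate_segments_to_batch. The directory path and seq_length are placeholders; the directory is assumed to contain *.hdf5 episode files.

from torch.utils.data import DataLoader

from data import BatchSampler, CSGOHdf5Dataset, collate_segments_to_batch

dataset = CSGOHdf5Dataset("/tmp/dummy_data_low_res")
sampler = BatchSampler(dataset, rank=0, world_size=1, batch_size=32, seq_length=9)
loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_segments_to_batch)
batch = next(iter(loader))  # a Batch of 32 left-padded segments of length 9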
src/data/dataset.py
ADDED
|
@@ -0,0 +1,202 @@
| 1 |
+
from collections import Counter
|
| 2 |
+
import multiprocessing as mp
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import shutil
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
|
| 7 |
+
import h5py
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch.utils.data import Dataset as TorchDataset
|
| 12 |
+
|
| 13 |
+
from .episode import Episode
|
| 14 |
+
from .segment import Segment, SegmentId
|
| 15 |
+
from .utils import make_segment
|
| 16 |
+
from utils import StateDictMixin
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Dataset(StateDictMixin, TorchDataset):
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
directory: Path,
|
| 23 |
+
dataset_full_res: Optional[TorchDataset],
|
| 24 |
+
name: Optional[str] = None,
|
| 25 |
+
cache_in_ram: bool = False,
|
| 26 |
+
use_manager: bool = False,
|
| 27 |
+
save_on_disk: bool = True,
|
| 28 |
+
) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
|
| 31 |
+
# State
|
| 32 |
+
self.is_static = False
|
| 33 |
+
self.num_episodes = None
|
| 34 |
+
self.num_steps = None
|
| 35 |
+
self.start_idx = None
|
| 36 |
+
self.lengths = None
|
| 37 |
+
self.counter_rew = None
|
| 38 |
+
self.counter_end = None
|
| 39 |
+
|
| 40 |
+
self._directory = Path(directory).expanduser()
|
| 41 |
+
self._name = name if name is not None else self._directory.stem
|
| 42 |
+
self._cache_in_ram = cache_in_ram
|
| 43 |
+
self._save_on_disk = save_on_disk
|
| 44 |
+
self._default_path = self._directory / "info.pt"
|
| 45 |
+
self._cache = mp.Manager().dict() if use_manager else {}
|
| 46 |
+
self._reset()
|
| 47 |
+
|
| 48 |
+
self._dataset_full_res = dataset_full_res
|
| 49 |
+
|
| 50 |
+
def __len__(self) -> int:
|
| 51 |
+
return self.num_steps
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, segment_id: SegmentId) -> Segment:
|
| 54 |
+
episode = self.load_episode(segment_id.episode_id)
|
| 55 |
+
segment = make_segment(episode, segment_id, should_pad=True)
|
| 56 |
+
if self._dataset_full_res is not None:
|
| 57 |
+
segment_id_full_res = SegmentId(episode.info["original_file_id"], segment_id.start, segment_id.stop)
|
| 58 |
+
segment.info["full_res"] = self._dataset_full_res[segment_id_full_res].obs
|
| 59 |
+
elif "full_res" in segment.info:
|
| 60 |
+
segment.info["full_res"] = segment.info["full_res"][segment_id.start:segment_id.stop]
|
| 61 |
+
return segment
|
| 62 |
+
|
| 63 |
+
def __str__(self) -> str:
|
| 64 |
+
return f"{self.name}: {self.num_episodes} episodes, {self.num_steps} steps."
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def name(self) -> str:
|
| 68 |
+
return self._name
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def counts_rew(self) -> List[int]:
|
| 72 |
+
return [self.counter_rew[r] for r in [-1, 0, 1]]
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def counts_end(self) -> List[int]:
|
| 76 |
+
return [self.counter_end[e] for e in [0, 1]]
|
| 77 |
+
|
| 78 |
+
def _reset(self) -> None:
|
| 79 |
+
self.num_episodes = 0
|
| 80 |
+
self.num_steps = 0
|
| 81 |
+
self.start_idx = np.array([], dtype=np.int64)
|
| 82 |
+
self.lengths = np.array([], dtype=np.int64)
|
| 83 |
+
self.counter_rew = Counter()
|
| 84 |
+
self.counter_end = Counter()
|
| 85 |
+
self._cache.clear()
|
| 86 |
+
|
| 87 |
+
def clear(self) -> None:
|
| 88 |
+
self.assert_not_static()
|
| 89 |
+
if self._directory.is_dir():
|
| 90 |
+
shutil.rmtree(self._directory)
|
| 91 |
+
self._reset()
|
| 92 |
+
|
| 93 |
+
def load_episode(self, episode_id: int) -> Episode:
|
| 94 |
+
if self._cache_in_ram and episode_id in self._cache:
|
| 95 |
+
episode = self._cache[episode_id]
|
| 96 |
+
else:
|
| 97 |
+
episode = Episode.load(self._get_episode_path(episode_id))
|
| 98 |
+
if self._cache_in_ram:
|
| 99 |
+
self._cache[episode_id] = episode
|
| 100 |
+
        return episode

    def add_episode(self, episode: Episode, *, episode_id: Optional[int] = None) -> int:
        self.assert_not_static()
        episode = episode.to("cpu")

        if episode_id is None:
            episode_id = self.num_episodes
            self.start_idx = np.concatenate((self.start_idx, np.array([self.num_steps])))
            self.lengths = np.concatenate((self.lengths, np.array([len(episode)])))
            self.num_steps += len(episode)
            self.num_episodes += 1

        else:
            assert episode_id < self.num_episodes
            old_episode = self.load_episode(episode_id)
            incr_num_steps = len(episode) - len(old_episode)
            self.lengths[episode_id] = len(episode)
            self.start_idx[episode_id + 1 :] += incr_num_steps
            self.num_steps += incr_num_steps
            self.counter_rew.subtract(old_episode.rew.sign().tolist())
            self.counter_end.subtract(old_episode.end.tolist())

        self.counter_rew.update(episode.rew.sign().tolist())
        self.counter_end.update(episode.end.tolist())

        if self._save_on_disk:
            episode.save(self._get_episode_path(episode_id))

        if self._cache_in_ram:
            self._cache[episode_id] = episode

        return episode_id

    def _get_episode_path(self, episode_id: int) -> Path:
        n = 3  # number of hierarchies
        powers = np.arange(n)
        subfolders = np.floor((episode_id % 10 ** (1 + powers)) / 10**powers) * 10**powers
        subfolders = [int(x) for x in subfolders[::-1]]
        subfolders = "/".join([f"{x:0{n - i}d}" for i, x in enumerate(subfolders)])
        return self._directory / subfolders / f"{episode_id}.pt"

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        super().load_state_dict(state_dict)
        self._cache.clear()

    def assert_not_static(self) -> None:
        assert not self.is_static, "Trying to modify a static dataset."

    def save_to_default_path(self) -> None:
        self._default_path.parent.mkdir(exist_ok=True, parents=True)
        torch.save(self.state_dict(), self._default_path)

    def load_from_default_path(self) -> None:
        if self._default_path.is_file():
            self.load_state_dict(torch.load(self._default_path))


class CSGOHdf5Dataset(StateDictMixin, TorchDataset):
    def __init__(self, directory: Path) -> None:
        super().__init__()
        filenames = sorted(Path(directory).rglob("*.hdf5"), key=lambda x: int(x.stem.split("_")[-1]))
        self._filenames = {f"{x.parent.name}/{x.name}": x for x in filenames}
        self._length_one_episode = 1000
        self.num_episodes = len(self._filenames)
        self.num_steps = self._length_one_episode * self.num_episodes
        self.lengths = np.array([self._length_one_episode] * self.num_episodes, dtype=np.int64)

    def __len__(self) -> int:
        return self.num_steps

    def save_to_default_path(self) -> None:
        pass

    def __getitem__(self, segment_id: SegmentId) -> Segment:
        assert segment_id.start < self._length_one_episode and segment_id.stop > 0 and segment_id.start < segment_id.stop
        pad_len_right = max(0, segment_id.stop - self._length_one_episode)
        pad_len_left = max(0, -segment_id.start)

        start = max(0, segment_id.start)
        stop = min(self._length_one_episode, segment_id.stop)
        mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()

        with h5py.File(self._filenames[segment_id.episode_id], "r") as f:
            obs = torch.stack([torch.tensor(f[f"frame_{i}_x"][:]).flip(2).permute(2, 0, 1).div(255).mul(2).sub(1) for i in range(start, stop)])
            act = torch.tensor(np.array([f[f"frame_{i}_y"][:] for i in range(start, stop)]))
            states = torch.stack([torch.tensor(f[f"frame_{i}_observation"][:]) for i in range(start, stop)])
            ego_state = torch.stack([torch.tensor(f[f"frame_{i}_ego_state"][:]) for i in range(start, stop)])

        def pad(x):
            right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
            return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right

        obs = pad(obs)
        act = pad(act)
        rew = torch.zeros(obs.size(0))
        end = torch.zeros(obs.size(0), dtype=torch.uint8)
        trunc = torch.zeros(obs.size(0), dtype=torch.uint8)
        return Segment(obs, act, rew, end, trunc, mask_padding, states=states, ego_state=ego_state, info={}, id=SegmentId(segment_id.episode_id, start, stop))

    def load_episode(self, episode_id: int) -> Episode:  # used by DatasetTraverser
        s = self[SegmentId(episode_id, 0, self._length_one_episode)]
        # Episode has no defaults for states/ego_state, so forward them from the segment as well
        return Episode(s.obs, s.act, s.rew, s.end, s.trunc, s.info, s.states, s.ego_state)
src/data/episode.py
ADDED
@@ -0,0 +1,64 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

import torch


@dataclass
class Episode:
    obs: torch.FloatTensor
    act: torch.LongTensor
    rew: torch.FloatTensor
    end: torch.ByteTensor
    trunc: torch.ByteTensor
    info: Dict[str, Any]
    states: torch.FloatTensor
    ego_state: torch.FloatTensor

    def __len__(self) -> int:
        return self.obs.size(0)

    def __add__(self, other: Episode) -> Episode:
        assert self.dead.sum() == 0
        d = {k: torch.cat((v, other.__dict__[k]), dim=0) for k, v in self.__dict__.items() if k != "info"}
        return Episode(**d, info=merge_info(self.info, other.info))

    def to(self, device) -> Episode:
        return Episode(**{k: v.to(device) if k != "info" else v for k, v in self.__dict__.items()})

    @property
    def dead(self) -> torch.ByteTensor:
        return (self.end + self.trunc).clip(max=1)

    def compute_metrics(self) -> Dict[str, Any]:
        return {"length": len(self), "return": self.rew.sum().item()}

    @classmethod
    def load(cls, path: Path, map_location: Optional[torch.device] = None) -> Episode:
        return cls(
            **{
                k: v.div(255).mul(2).sub(1) if k == "obs" else v
                for k, v in torch.load(Path(path), map_location=map_location).items()
            }
        )

    def save(self, path: Path) -> None:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        d = {k: v.add(1).div(2).mul(255).byte() if k == "obs" else v for k, v in self.__dict__.items()}
        torch.save(d, path.with_suffix(".tmp"))
        path.with_suffix(".tmp").rename(path)


def merge_info(info_a, info_b):
    keys_a = set(info_a)
    keys_b = set(info_b)
    intersection = keys_a & keys_b
    info = {
        **{k: info_a[k] for k in keys_a if k not in intersection},
        **{k: info_b[k] for k in keys_b if k not in intersection},
        **{k: torch.cat((info_a[k], info_b[k]), dim=0) for k in intersection},
    }
    return info
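A round-trip sketch for Episode.save and Episode.load: save() maps obs from [-1, 1] floats to uint8 and load() maps them back. The tensor shapes below are placeholders, not the real CSGO dimensions.

import torch

from src.data.episode import Episode

ep = Episode(
    obs=torch.rand(10, 3, 64, 64) * 2 - 1,      # frames in [-1, 1]; 64x64 is a placeholder size
    act=torch.zeros(10, 51),                     # placeholder action vectors
    rew=torch.zeros(10),
    end=torch.zeros(10, dtype=torch.uint8),
    trunc=torch.zeros(10, dtype=torch.uint8),
    info={},
    states=torch.zeros(10, 4),                   # placeholder game-state features
    ego_state=torch.zeros(10, 3),                # placeholder ego-state features
)
ep.save("cache/episode_0.pt")                    # written via a .tmp file then renamed; frames stored as uint8
restored = Episode.load("cache/episode_0.pt")    # frames converted back to [-1, 1]
assert len(restored) == len(ep)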
src/data/segment.py
ADDED
@@ -0,0 +1,30 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Union

import torch


@dataclass
class SegmentId:
    episode_id: Union[int, str]
    start: int
    stop: int


@dataclass
class Segment:
    obs: torch.FloatTensor
    act: torch.LongTensor
    rew: torch.FloatTensor
    end: torch.ByteTensor
    trunc: torch.ByteTensor
    mask_padding: torch.BoolTensor
    states: torch.FloatTensor
    ego_state: torch.FloatTensor
    info: Dict[str, Any]
    id: SegmentId

    @property
    def effective_size(self):
        return self.mask_padding.sum().item()
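A short sketch of how a SegmentId addresses a slice of an episode and how mask_padding feeds effective_size; every tensor below is a placeholder.

import torch

from src.data.segment import Segment, SegmentId

sid = SegmentId(episode_id=0, start=-2, stop=8)            # 2 padded steps on the left, 8 real steps
mask = torch.cat((torch.zeros(2), torch.ones(8))).bool()   # False marks padding, True marks real data
seg = Segment(
    obs=torch.zeros(10, 3, 64, 64),
    act=torch.zeros(10, 51),
    rew=torch.zeros(10),
    end=torch.zeros(10, dtype=torch.uint8),
    trunc=torch.zeros(10, dtype=torch.uint8),
    mask_padding=mask,
    states=torch.zeros(10, 4),
    ego_state=torch.zeros(10, 3),
    info={},
    id=sid,
)
print(seg.effective_size)  # 8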
src/data/utils.py
ADDED
@@ -0,0 +1,89 @@
import math
from typing import Generator, List

import torch
import torch.nn.functional as F

from .batch import Batch
from .episode import Episode
from .segment import Segment, SegmentId


def collate_segments_to_batch(segments: List[Segment]) -> Batch:
    attrs = ("obs", "act", "rew", "end", "trunc", "mask_padding")
    stack = (torch.stack([getattr(s, x) for s in segments]) for x in attrs)
    return Batch(*stack, [s.info for s in segments], [s.id for s in segments])


def make_segment(episode: Episode, segment_id: SegmentId, should_pad: bool = True) -> Segment:
    assert segment_id.start < len(episode) and segment_id.stop > 0 and segment_id.start < segment_id.stop
    pad_len_right = max(0, segment_id.stop - len(episode))
    pad_len_left = max(0, -segment_id.start)
    assert pad_len_right == pad_len_left == 0 or should_pad

    def pad(x):
        right = F.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [pad_len_right]) if pad_len_right > 0 else x
        return F.pad(right, [0 for _ in range(2 * x.ndim - 2)] + [pad_len_left, 0]) if pad_len_left > 0 else right

    start = max(0, segment_id.start)
    stop = min(len(episode), segment_id.stop)
    mask_padding = torch.cat((torch.zeros(pad_len_left), torch.ones(stop - start), torch.zeros(pad_len_right))).bool()

    return Segment(
        pad(episode.obs[start:stop]),
        pad(episode.act[start:stop]),
        pad(episode.rew[start:stop]),
        pad(episode.end[start:stop]),
        pad(episode.trunc[start:stop]),
        mask_padding,
        pad(episode.states[start:stop]),
        pad(episode.ego_state[start:stop]),
        info=episode.info,
        id=SegmentId(segment_id.episode_id, start, stop),
    )


class DatasetTraverser:
    def __init__(self, dataset, batch_num_samples: int, chunk_size: int) -> None:
        self.dataset = dataset
        self.batch_num_samples = batch_num_samples
        self.chunk_size = chunk_size

    def __len__(self):
        return math.ceil(
            sum(
                [
                    math.ceil(self.dataset.lengths[episode_id] / self.chunk_size)
                    - int(self.dataset.lengths[episode_id] % self.chunk_size == 1)
                    for episode_id in range(self.dataset.num_episodes)
                ]
            )
            / self.batch_num_samples
        )

    def __iter__(self) -> Generator[Batch, None, None]:
        chunks = []
        for episode_id in range(self.dataset.num_episodes):
            episode = self.dataset.load_episode(episode_id)
            segments = []
            for i in range(math.ceil(len(episode) / self.chunk_size)):
                start = i * self.chunk_size
                stop = (i + 1) * self.chunk_size
                segment = make_segment(
                    episode,
                    SegmentId(episode_id, start, stop),
                    should_pad=True,
                )
                segment_id_full_res = SegmentId(episode.info["original_file_id"], start, stop)
                segment.info["full_res"] = self.dataset._dataset_full_res[segment_id_full_res].obs
                chunks.append(segment)
            if chunks[-1].effective_size < 2:
                chunks.pop()

            while len(chunks) >= self.batch_num_samples:
                yield collate_segments_to_batch(chunks[: self.batch_num_samples])
                chunks = chunks[self.batch_num_samples :]

        if len(chunks) > 0:
            yield collate_segments_to_batch(chunks)
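A padding sketch for make_segment with the same placeholder tensors as the Episode sketch above: requesting steps [-3, 12) from a 10-step episode pads 3 steps on the left and 2 on the right.

import torch

from src.data.episode import Episode
from src.data.segment import SegmentId
from src.data.utils import make_segment

episode = Episode(
    obs=torch.zeros(10, 3, 64, 64),
    act=torch.zeros(10, 51),
    rew=torch.zeros(10),
    end=torch.zeros(10, dtype=torch.uint8),
    trunc=torch.zeros(10, dtype=torch.uint8),
    info={},
    states=torch.zeros(10, 4),
    ego_state=torch.zeros(10, 3),
)
seg = make_segment(episode, SegmentId(0, -3, 12), should_pad=True)
print(seg.obs.shape[0])    # 15 = 3 left padding + 10 real steps + 2 right padding
print(seg.effective_size)  # 10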
src/envs/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .env import make_atari_env, TorchEnv
from .world_model_env import WorldModelEnv, WorldModelEnvConfig
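These re-exports let downstream modules import the environment wrappers from the package root, e.g. (usage sketch):

from src.envs import TorchEnv, WorldModelEnv, WorldModelEnvConfig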