Commit 93dbff3 · Parent: a29f249 · "Fix bug 4"

app.py (CHANGED)
@@ -29,6 +29,7 @@ from src.agent import Agent
 from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
 from src.envs import WorldModelEnv
 from src.game.web_play_env import WebPlayEnv
+from src.utils import extract_state_dict
 from config_web import web_config
 
 # Configure logging
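The only change in this hunk is the new `from src.utils import extract_state_dict` import. The second hunk below calls it as `extract_state_dict(state_dict, "denoiser")` and feeds the result to `agent.denoiser.load_state_dict`, which implies prefix-stripping semantics. A hypothetical sketch under that assumption (not the actual src.utils implementation):

```python
from collections import OrderedDict


def extract_state_dict_sketch(state_dict, module_name):
    """Keep keys that start with `module_name` and strip that prefix so the
    result can be passed to the sub-module's load_state_dict()."""
    return OrderedDict(
        (key.split(".", 1)[1], value)
        for key, value in state_dict.items()
        if key.startswith(module_name)
    )
```

For a checkpoint key such as `denoiser.inner_model.conv.weight`, this sketch would return `inner_model.conv.weight`, which is the shape of key that `agent.denoiser.load_state_dict` expects.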
@@ -87,39 +88,56 @@ class WebGameEngine:
         import time
         self.time_module = time
 
-    async def
-        """
+    async def _load_model_from_url_async(self, agent, device):
+        """Load model from URL using torch.hub (HF Spaces compatible)"""
         import asyncio
         import concurrent.futures
-        import urllib.request
-        import os
 
-        def
-            """
+        def load_model_weights():
+            """Load model weights in thread pool to avoid blocking"""
+            try:
+                # Use torch.hub.load_state_dict_from_url which is HF Spaces compatible
+                model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
+                logger.info(f"Loading model from {model_url} using torch.hub...")
+
+                # Update progress
+                self.download_progress = 10
+                self.loading_status = "Downloading model with torch.hub..."
+
+                # Load state dict directly from URL (handles caching automatically)
+                state_dict = torch.hub.load_state_dict_from_url(
+                    model_url,
+                    map_location=device,
+                    progress=True  # Show download progress
+                )
+
+                self.download_progress = 80
+                self.loading_status = "Loading model weights..."
+
+                # Load each component of the agent using extract_state_dict (same as agent.load method)
+                if any(k.startswith("denoiser") for k in state_dict.keys()):
+                    agent.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))
+                if any(k.startswith("upsampler") for k in state_dict.keys()) and agent.upsampler is not None:
+                    agent.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"))
+                if any(k.startswith("rew_end_model") for k in state_dict.keys()) and agent.rew_end_model is not None:
+                    agent.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"))
+                if any(k.startswith("actor_critic") for k in state_dict.keys()) and agent.actor_critic is not None:
+                    agent.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"))
+
+                self.download_progress = 100
+                self.loading_status = "Model loaded successfully!"
+                return True
+
+            except Exception as e:
+                logger.error(f"Failed to load model from URL: {e}")
+                return False
 
-        # Run
+        # Run in thread pool to avoid blocking
         loop = asyncio.get_event_loop()
         with concurrent.futures.ThreadPoolExecutor() as executor:
-            await loop.run_in_executor(executor,
+            success = await loop.run_in_executor(executor, load_model_weights)
 
+        return success
 
     async def initialize_models(self):
         """Initialize the AI models and environment"""
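The new `_load_model_from_url_async` method wraps a blocking `torch.hub` download in a thread pool so the asyncio event loop serving the web UI stays responsive. A minimal, self-contained sketch of that pattern (the helper names `fetch_state_dict` and `load_weights_async` are illustrative, not part of app.py; only the checkpoint URL is taken from the diff):

```python
import asyncio
import concurrent.futures

import torch


def fetch_state_dict(url, device="cpu"):
    """Blocking download; torch.hub caches the file locally after the first call."""
    return torch.hub.load_state_dict_from_url(url, map_location=device, progress=True)


async def load_weights_async(url, device="cpu"):
    """Run the blocking download on a worker thread so the event loop keeps serving requests."""
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        return await loop.run_in_executor(executor, lambda: fetch_state_dict(url, device))


# Example usage with the checkpoint URL from the diff:
# state_dict = asyncio.run(load_weights_async(
#     "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"))
```

By default `torch.hub.load_state_dict_from_url` caches the checkpoint under `~/.cache/torch/hub/checkpoints`, so a warm restart can skip the download if that directory survives.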
@@ -159,43 +177,20 @@ class WebGameEngine:
 
         # Try to load checkpoint (remote first, then local, then dummy mode)
         try:
-            # First try to
-            # Download to cache directory
-            cache_dir = "./cache"
-            os.makedirs(cache_dir, exist_ok=True)
-            model_cache_path = os.path.join(cache_dir, "agent_epoch_00003.pt")
-
-            # Download if not cached
-            if not os.path.exists(model_cache_path):
-                logger.info(f"Downloading 1.53GB model to {model_cache_path}...")
-                self.loading_status = "Downloading AI model from Hugging Face Hub..."
-
-                # Download with progress tracking in a separate thread
-                await self._download_model_async(model_url, model_cache_path)
-            else:
-                logger.info(f"Using cached model from {model_cache_path}")
-                self.loading_status = "Loading cached model..."
-
-            # Use the agent's load method which expects a file path
-            self.loading_status = "Loading model weights..."
-            agent.load(model_cache_path)
-            logger.info(f"Successfully loaded checkpoint from HF Hub")
-
-        except Exception as hub_error:
-            logger.warning(f"Failed to download from HF Hub: {hub_error}")
-
+            # First try to load from Hugging Face Hub using torch.hub
+            logger.info("Loading model from Hugging Face Hub with torch.hub...")
+
+            success = await self._load_model_from_url_async(agent, device)
+
+            if success:
+                logger.info("Successfully loaded checkpoint from HF Hub")
+            else:
             # Fallback to local checkpoint if available
+                logger.warning("Failed to load from HF Hub, trying local checkpoint...")
             checkpoint_path = web_config.get_checkpoint_path()
             if checkpoint_path.exists():
-                logger.info(f"
+                logger.info(f"Loading local checkpoint: {checkpoint_path}")
+                self.loading_status = "Loading local checkpoint..."
                 agent.load(checkpoint_path)
                 logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
             else:
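Because the fallback logic in this hunk is spread across interleaved added and removed lines, here is a condensed, illustrative view of the resulting control flow in `initialize_models` (remote checkpoint first, then local checkpoint; dummy mode lives outside this hunk). `engine`, `agent`, `device`, and `web_config` stand in for the real objects in app.py:

```python
import logging

logger = logging.getLogger(__name__)


async def load_checkpoint(engine, agent, device, web_config):
    """Illustrative summary of the new loading order; not the literal app.py code."""
    success = await engine._load_model_from_url_async(agent, device)
    if success:
        logger.info("Successfully loaded checkpoint from HF Hub")
        return
    logger.warning("Failed to load from HF Hub, trying local checkpoint...")
    checkpoint_path = web_config.get_checkpoint_path()
    if checkpoint_path.exists():
        agent.load(checkpoint_path)
        logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
    # else: fall through to dummy mode (handled later in app.py)
```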