musictimer commited on
Commit
93dbff3
·
1 Parent(s): a29f249
Files changed (1) hide show
  1. app.py +55 -60
app.py CHANGED
@@ -29,6 +29,7 @@ from src.agent import Agent
29
  from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
30
  from src.envs import WorldModelEnv
31
  from src.game.web_play_env import WebPlayEnv
 
32
  from config_web import web_config
33
 
34
  # Configure logging
@@ -87,39 +88,56 @@ class WebGameEngine:
87
  import time
88
  self.time_module = time
89
 
90
- async def _download_model_async(self, url, filepath):
91
- """Download model asynchronously with progress tracking"""
92
  import asyncio
93
  import concurrent.futures
94
- import urllib.request
95
- import os
96
 
97
- def download_with_progress():
98
- """Download function that runs in thread pool"""
99
- def progress_hook(block_num, block_size, total_size):
100
- if total_size > 0:
101
- progress = min(100, (block_num * block_size * 100) / total_size)
102
- new_progress = int(progress)
103
-
104
- # Update progress more frequently for smooth progress bar
105
- if new_progress != self.download_progress:
106
- self.download_progress = new_progress
107
- self.loading_status = f"Downloading AI model ({self.download_progress}%)"
108
-
109
- # Log every 5% instead of 10% for better feedback
110
- if self.download_progress % 5 == 0:
111
- logger.info(f"Download progress: {self.download_progress}%")
112
-
113
- urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
114
- self.download_progress = 100
115
- self.loading_status = "Download complete! Loading model..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
- # Run download in thread pool to avoid blocking
118
  loop = asyncio.get_event_loop()
119
  with concurrent.futures.ThreadPoolExecutor() as executor:
120
- await loop.run_in_executor(executor, download_with_progress)
121
 
122
- logger.info("Model download completed!")
123
 
124
  async def initialize_models(self):
125
  """Initialize the AI models and environment"""
@@ -159,43 +177,20 @@ class WebGameEngine:
159
 
160
  # Try to load checkpoint (remote first, then local, then dummy mode)
161
  try:
162
- # First try to download from Hugging Face Hub using direct URL
163
- try:
164
- import torch.hub
165
- import os
166
- logger.info("Downloading model from Hugging Face Hub...")
167
-
168
- # Direct download URL (change 'blob' to 'resolve' for direct download)
169
- model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
170
-
171
- # Download to cache directory
172
- cache_dir = "./cache"
173
- os.makedirs(cache_dir, exist_ok=True)
174
- model_cache_path = os.path.join(cache_dir, "agent_epoch_00003.pt")
175
-
176
- # Download if not cached
177
- if not os.path.exists(model_cache_path):
178
- logger.info(f"Downloading 1.53GB model to {model_cache_path}...")
179
- self.loading_status = "Downloading AI model from Hugging Face Hub..."
180
-
181
- # Download with progress tracking in a separate thread
182
- await self._download_model_async(model_url, model_cache_path)
183
- else:
184
- logger.info(f"Using cached model from {model_cache_path}")
185
- self.loading_status = "Loading cached model..."
186
-
187
- # Use the agent's load method which expects a file path
188
- self.loading_status = "Loading model weights..."
189
- agent.load(model_cache_path)
190
- logger.info(f"Successfully loaded checkpoint from HF Hub")
191
-
192
- except Exception as hub_error:
193
- logger.warning(f"Failed to download from HF Hub: {hub_error}")
194
-
195
  # Fallback to local checkpoint if available
 
196
  checkpoint_path = web_config.get_checkpoint_path()
197
  if checkpoint_path.exists():
198
- logger.info(f"Falling back to local checkpoint: {checkpoint_path}")
 
199
  agent.load(checkpoint_path)
200
  logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
201
  else:
 
29
  from src.csgo.web_action_processing import WebCSGOAction, web_keys_to_csgo_action_names
30
  from src.envs import WorldModelEnv
31
  from src.game.web_play_env import WebPlayEnv
32
+ from src.utils import extract_state_dict
33
  from config_web import web_config
34
 
35
  # Configure logging
 
88
  import time
89
  self.time_module = time
90
 
91
+ async def _load_model_from_url_async(self, agent, device):
92
+ """Load model from URL using torch.hub (HF Spaces compatible)"""
93
  import asyncio
94
  import concurrent.futures
 
 
95
 
96
+ def load_model_weights():
97
+ """Load model weights in thread pool to avoid blocking"""
98
+ try:
99
+ # Use torch.hub.load_state_dict_from_url which is HF Spaces compatible
100
+ model_url = "https://huggingface.co/Etadingrui/diamond-1B/resolve/main/agent_epoch_00003.pt"
101
+ logger.info(f"Loading model from {model_url} using torch.hub...")
102
+
103
+ # Update progress
104
+ self.download_progress = 10
105
+ self.loading_status = "Downloading model with torch.hub..."
106
+
107
+ # Load state dict directly from URL (handles caching automatically)
108
+ state_dict = torch.hub.load_state_dict_from_url(
109
+ model_url,
110
+ map_location=device,
111
+ progress=True # Show download progress
112
+ )
113
+
114
+ self.download_progress = 80
115
+ self.loading_status = "Loading model weights..."
116
+
117
+ # Load each component of the agent using extract_state_dict (same as agent.load method)
118
+ if any(k.startswith("denoiser") for k in state_dict.keys()):
119
+ agent.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))
120
+ if any(k.startswith("upsampler") for k in state_dict.keys()) and agent.upsampler is not None:
121
+ agent.upsampler.load_state_dict(extract_state_dict(state_dict, "upsampler"))
122
+ if any(k.startswith("rew_end_model") for k in state_dict.keys()) and agent.rew_end_model is not None:
123
+ agent.rew_end_model.load_state_dict(extract_state_dict(state_dict, "rew_end_model"))
124
+ if any(k.startswith("actor_critic") for k in state_dict.keys()) and agent.actor_critic is not None:
125
+ agent.actor_critic.load_state_dict(extract_state_dict(state_dict, "actor_critic"))
126
+
127
+ self.download_progress = 100
128
+ self.loading_status = "Model loaded successfully!"
129
+ return True
130
+
131
+ except Exception as e:
132
+ logger.error(f"Failed to load model from URL: {e}")
133
+ return False
134
 
135
+ # Run in thread pool to avoid blocking
136
  loop = asyncio.get_event_loop()
137
  with concurrent.futures.ThreadPoolExecutor() as executor:
138
+ success = await loop.run_in_executor(executor, load_model_weights)
139
 
140
+ return success
141
 
142
  async def initialize_models(self):
143
  """Initialize the AI models and environment"""
 
177
 
178
  # Try to load checkpoint (remote first, then local, then dummy mode)
179
  try:
180
+ # First try to load from Hugging Face Hub using torch.hub
181
+ logger.info("Loading model from Hugging Face Hub with torch.hub...")
182
+
183
+ success = await self._load_model_from_url_async(agent, device)
184
+
185
+ if success:
186
+ logger.info("Successfully loaded checkpoint from HF Hub")
187
+ else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  # Fallback to local checkpoint if available
189
+ logger.warning("Failed to load from HF Hub, trying local checkpoint...")
190
  checkpoint_path = web_config.get_checkpoint_path()
191
  if checkpoint_path.exists():
192
+ logger.info(f"Loading local checkpoint: {checkpoint_path}")
193
+ self.loading_status = "Loading local checkpoint..."
194
  agent.load(checkpoint_path)
195
  logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
196
  else: