musictimer committed on
Commit
41e58ab
·
1 Parent(s): 02c6351

Fix initial bugs

Browse files
Files changed (1) hide show
  1. app.py +69 -17
app.py CHANGED
@@ -167,14 +167,24 @@ class WebGameEngine:
167
  self.loading_status = "Loading model weights into agent..."
168
  logger.info("State dict loaded, applying to agent...")
169
 
170
- # Load state dict into agent, but skip actor_critic if not present
171
  has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
172
- logger.info(f"Model has actor_critic weights: {has_actor_critic}")
 
 
 
 
 
173
  agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
174
 
175
  # Track if actor_critic was actually loaded with trained weights
176
  self.actor_critic_loaded = has_actor_critic
177
 
 
 
 
 
 
178
  self.download_progress = 100
179
  self.loading_status = "Model loaded successfully!"
180
  logger.info("All model weights loaded successfully!")
@@ -297,35 +307,61 @@ class WebGameEngine:
297
  logger.info(f"Actor-critic device: {agent.actor_critic.device}")
298
  # Force AI control for web demo
299
  self.play_env.is_human_player = False
300
- logger.info("WebPlayEnv set to AI control mode")
301
  elif agent.actor_critic is not None and not self.actor_critic_loaded:
302
- logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
 
 
303
  self.play_env.is_human_player = True
 
304
  logger.info("WebPlayEnv set to human control mode (no trained weights)")
305
  else:
306
- logger.warning("No actor-critic model found - AI inference will not work!")
307
  self.play_env.is_human_player = True
308
  logger.info("WebPlayEnv set to human control mode (fallback)")
309
 
310
- # Enable torch.compile by default like play.py does (can disable with DISABLE_TORCH_COMPILE=1)
311
- import os, pwd
312
  try:
313
  pwd.getpwuid(os.getuid())
314
  except KeyError:
315
  os.environ["USER"] = "huggingface"
316
 
317
- os.environ["DISABLE_TORCH_COMPILE"] = "0"
318
- if device.type == "cuda" and os.getenv("DISABLE_TORCH_COMPILE", "0") != "1":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  logger.info("Compiling models for faster inference (like play.py --compile)...")
320
  try:
321
  wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
322
  if wm_env.upsample_next_obs is not None:
323
  wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
324
- logger.info("Model compilation enabled successfully!")
325
  except Exception as e:
326
- logger.warning(f"Model compilation failed: {e}")
 
327
  else:
328
- logger.info("Model compilation disabled. Set DISABLE_TORCH_COMPILE=0 to enable.")
 
329
 
330
  # Reset environment
331
  self.obs, _ = self.play_env.reset()
@@ -714,17 +750,33 @@ class WebGameEngine:
714
  start = self.time_module.time()
715
 
716
  # Use FP16 autocast for faster inference (like play.py can do with modern GPUs)
717
- from torch.cuda.amp import autocast
718
- with autocast(dtype=torch.float16, enabled=torch.cuda.is_available()):
 
719
  res = self.play_env.step_from_web_input(**web_state)
720
 
721
  infer_t = self.time_module.time() - start
722
  await self._out_queue.put((*res, infer_t))
723
  except Exception as e:
724
  logger.error(f"Inference worker error: {e}")
725
- # Put a dummy result to avoid hanging
726
- dummy_obs = self.obs if self.obs is not None else torch.zeros(3, 150, 600)
727
- await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
 
729
  # Global game engine instance
730
  game_engine = WebGameEngine()
 
167
  self.loading_status = "Loading model weights into agent..."
168
  logger.info("State dict loaded, applying to agent...")
169
 
170
+ # Check what components are in the state dict
171
  has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
172
+ has_denoiser = any(k.startswith('denoiser.') for k in state_dict.keys())
173
+ has_upsampler = any(k.startswith('upsampler.') for k in state_dict.keys())
174
+
175
+ logger.info(f"Model components found - actor_critic: {has_actor_critic}, denoiser: {has_denoiser}, upsampler: {has_upsampler}")
176
+
177
+ # Load state dict into agent
178
  agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
179
 
180
  # Track if actor_critic was actually loaded with trained weights
181
  self.actor_critic_loaded = has_actor_critic
182
 
183
+ # For HF Spaces demo, if no actor_critic, we can still show the world model
184
+ if not has_actor_critic:
185
+ logger.warning("No actor_critic weights found - world model will work but AI won't play")
186
+ logger.info("Users can still interact and see the world model predictions")
187
+
188
  self.download_progress = 100
189
  self.loading_status = "Model loaded successfully!"
190
  logger.info("All model weights loaded successfully!")
 
307
  logger.info(f"Actor-critic device: {agent.actor_critic.device}")
308
  # Force AI control for web demo
309
  self.play_env.is_human_player = False
310
+ logger.info("WebPlayEnv set to AI control mode - ready for inference!")
311
  elif agent.actor_critic is not None and not self.actor_critic_loaded:
312
+ logger.warning("⚠️ Actor-critic model exists but has no trained weights!")
313
+ logger.info("🎮 Demo will work in world-model mode (human input + world simulation)")
314
+ # Still allow human input to drive the world model
315
  self.play_env.is_human_player = True
316
+ self.play_env.human_input_override = True # Always use human input
317
  logger.info("WebPlayEnv set to human control mode (no trained weights)")
318
  else:
319
+ logger.warning("No actor-critic model found - AI inference will not work!")
320
  self.play_env.is_human_player = True
321
  logger.info("WebPlayEnv set to human control mode (fallback)")
322
 
323
+ # Set up cache directories for HF Spaces compatibility
324
+ import os, pwd, tempfile
325
  try:
326
  pwd.getpwuid(os.getuid())
327
  except KeyError:
328
  os.environ["USER"] = "huggingface"
329
 
330
+ # Set writable cache directories for HF Spaces
331
+ cache_dir = tempfile.gettempdir()
332
+ os.environ.setdefault("TRITON_CACHE_DIR", os.path.join(cache_dir, "triton"))
333
+ os.environ.setdefault("TORCH_COMPILE_DEBUG_DIR", os.path.join(cache_dir, "torch_compile"))
334
+
335
+ # Create cache directories
336
+ for cache_var in ["TRITON_CACHE_DIR", "TORCH_COMPILE_DEBUG_DIR"]:
337
+ cache_path = os.environ[cache_var]
338
+ os.makedirs(cache_path, exist_ok=True)
339
+
340
+ # Enable torch.compile with proper error handling for HF Spaces
341
+ # Check if we're on HF Spaces (common indicators)
342
+ is_hf_spaces = any([
343
+ 'space_id' in os.environ,
344
+ 'huggingface' in os.environ.get('USER', '').lower(),
345
+ '/app' in os.getcwd()
346
+ ])
347
+
348
+ compile_enabled = (device.type == "cuda" and
349
+ os.getenv("DISABLE_TORCH_COMPILE", "0") != "1" and
350
+ not is_hf_spaces) # Disable by default on HF Spaces due to permission issues
351
+
352
+ if compile_enabled:
353
  logger.info("Compiling models for faster inference (like play.py --compile)...")
354
  try:
355
  wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
356
  if wm_env.upsample_next_obs is not None:
357
  wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
358
+ logger.info("Model compilation enabled successfully!")
359
  except Exception as e:
360
+ logger.warning(f"⚠️ Model compilation failed: {e}")
361
+ logger.info("Continuing without model compilation...")
362
  else:
363
+ reason = "HF Spaces detected" if is_hf_spaces else "disabled by env var"
364
+ logger.info(f"Model compilation disabled ({reason}). Models will run uncompiled.")
365
 
366
  # Reset environment
367
  self.obs, _ = self.play_env.reset()
 
750
  start = self.time_module.time()
751
 
752
  # Use FP16 autocast for faster inference (like play.py can do with modern GPUs)
753
+ # Use newer autocast API to avoid deprecation warning
754
+ import torch
755
+ with torch.amp.autocast('cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
756
  res = self.play_env.step_from_web_input(**web_state)
757
 
758
  infer_t = self.time_module.time() - start
759
  await self._out_queue.put((*res, infer_t))
760
  except Exception as e:
761
  logger.error(f"Inference worker error: {e}")
762
+ import traceback
763
+ logger.error(f"Full traceback: {traceback.format_exc()}")
764
+
765
+ # Create a proper dummy result with correct tensor properties
766
+ try:
767
+ if self.obs is not None and hasattr(self.obs, 'shape') and hasattr(self.obs, 'device'):
768
+ dummy_obs = self.obs.clone()
769
+ else:
770
+ # Fallback to a standard tensor on the right device
771
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
772
+ dummy_obs = torch.zeros(1, 3, 150, 600, device=device)
773
+
774
+ await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
775
+ except Exception as e2:
776
+ logger.error(f"Error creating dummy result: {e2}")
777
+ # Last resort - create CPU tensor
778
+ dummy_obs = torch.zeros(1, 3, 150, 600)
779
+ await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
780
 
781
  # Global game engine instance
782
  game_engine = WebGameEngine()