Commit 41e58ab · Parent(s): 02c6351
Fix initial bugs
app.py CHANGED
@@ -167,14 +167,24 @@ class WebGameEngine:
         self.loading_status = "Loading model weights into agent..."
         logger.info("State dict loaded, applying to agent...")

-        #
+        # Check what components are in the state dict
         has_actor_critic = any(k.startswith('actor_critic.') for k in state_dict.keys())
-
+        has_denoiser = any(k.startswith('denoiser.') for k in state_dict.keys())
+        has_upsampler = any(k.startswith('upsampler.') for k in state_dict.keys())
+
+        logger.info(f"Model components found - actor_critic: {has_actor_critic}, denoiser: {has_denoiser}, upsampler: {has_upsampler}")
+
+        # Load state dict into agent
         agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)

         # Track if actor_critic was actually loaded with trained weights
         self.actor_critic_loaded = has_actor_critic

+        # For HF Spaces demo, if no actor_critic, we can still show the world model
+        if not has_actor_critic:
+            logger.warning("No actor_critic weights found - world model will work but AI won't play")
+            logger.info("Users can still interact and see the world model predictions")
+
         self.download_progress = 100
         self.loading_status = "Model loaded successfully!"
         logger.info("All model weights loaded successfully!")
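For reference, the prefix check in this hunk generalizes to any number of checkpoint components. A minimal standalone sketch of the same idea (the helper name and checkpoint path are illustrative, not part of the commit):

import torch

def detect_components(state_dict, prefixes=("actor_critic.", "denoiser.", "upsampler.")):
    # Map each component name to whether any checkpoint key starts with its prefix.
    return {p.rstrip("."): any(k.startswith(p) for k in state_dict) for p in prefixes}

# Usage (path is illustrative):
# state_dict = torch.load("model.pt", map_location="cpu")
# components = detect_components(state_dict)
# agent.load_state_dict(state_dict, load_actor_critic=components["actor_critic"])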
@@ -297,35 +307,61 @@ class WebGameEngine:
             logger.info(f"Actor-critic device: {agent.actor_critic.device}")
             # Force AI control for web demo
             self.play_env.is_human_player = False
-            logger.info("WebPlayEnv set to AI control mode")
+            logger.info("✅ WebPlayEnv set to AI control mode - ready for inference!")
         elif agent.actor_critic is not None and not self.actor_critic_loaded:
-            logger.warning("Actor-critic model exists but has no trained weights!")
+            logger.warning("⚠️ Actor-critic model exists but has no trained weights!")
+            logger.info("🎮 Demo will work in world-model mode (human input + world simulation)")
+            # Still allow human input to drive the world model
             self.play_env.is_human_player = True
+            self.play_env.human_input_override = True  # Always use human input
             logger.info("WebPlayEnv set to human control mode (no trained weights)")
         else:
-            logger.warning("No actor-critic model found - AI inference will not work!")
+            logger.warning("❌ No actor-critic model found - AI inference will not work!")
             self.play_env.is_human_player = True
             logger.info("WebPlayEnv set to human control mode (fallback)")

-        #
-        import os, pwd
+        # Set up cache directories for HF Spaces compatibility
+        import os, pwd, tempfile
         try:
             pwd.getpwuid(os.getuid())
         except KeyError:
             os.environ["USER"] = "huggingface"

-
-
+        # Set writable cache directories for HF Spaces
+        cache_dir = tempfile.gettempdir()
+        os.environ.setdefault("TRITON_CACHE_DIR", os.path.join(cache_dir, "triton"))
+        os.environ.setdefault("TORCH_COMPILE_DEBUG_DIR", os.path.join(cache_dir, "torch_compile"))
+
+        # Create cache directories
+        for cache_var in ["TRITON_CACHE_DIR", "TORCH_COMPILE_DEBUG_DIR"]:
+            cache_path = os.environ[cache_var]
+            os.makedirs(cache_path, exist_ok=True)
+
+        # Enable torch.compile with proper error handling for HF Spaces
+        # Check if we're on HF Spaces (common indicators)
+        is_hf_spaces = any([
+            'space_id' in os.environ,
+            'huggingface' in os.environ.get('USER', '').lower(),
+            '/app' in os.getcwd()
+        ])
+
+        compile_enabled = (device.type == "cuda" and
+                           os.getenv("DISABLE_TORCH_COMPILE", "0") != "1" and
+                           not is_hf_spaces)  # Disable by default on HF Spaces due to permission issues
+
+        if compile_enabled:
             logger.info("Compiling models for faster inference (like play.py --compile)...")
             try:
                 wm_env.predict_next_obs = torch.compile(wm_env.predict_next_obs, mode="reduce-overhead")
                 if wm_env.upsample_next_obs is not None:
                     wm_env.upsample_next_obs = torch.compile(wm_env.upsample_next_obs, mode="reduce-overhead")
-                logger.info("Model compilation enabled successfully!")
+                logger.info("✅ Model compilation enabled successfully!")
             except Exception as e:
-                logger.warning(f"Model compilation failed: {e}")
+                logger.warning(f"⚠️ Model compilation failed: {e}")
+                logger.info("Continuing without model compilation...")
         else:
-
+            reason = "HF Spaces detected" if is_hf_spaces else "disabled by env var"
+            logger.info(f"Model compilation disabled ({reason}). Models will run uncompiled.")

         # Reset environment
         self.obs, _ = self.play_env.reset()
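For reference, the compile gating introduced here follows a common fallback pattern: attempt torch.compile only when a CUDA device is available and compilation is not explicitly disabled, and keep the eager function otherwise. A minimal sketch of that pattern in isolation (the helper name and env-var handling mirror the hunk but are illustrative):

import os
import torch

def maybe_compile(fn, enabled=None):
    # Decide whether to compile: CUDA present and not disabled via env var.
    if enabled is None:
        enabled = torch.cuda.is_available() and os.getenv("DISABLE_TORCH_COMPILE", "0") != "1"
    if not enabled:
        return fn
    try:
        # reduce-overhead mode targets small, latency-sensitive calls like per-step inference.
        return torch.compile(fn, mode="reduce-overhead")
    except Exception:
        # Compilation can fail in restricted environments (e.g. read-only cache dirs); fall back to eager.
        return fn

# Usage (wm_env is the world-model environment from the surrounding code):
# wm_env.predict_next_obs = maybe_compile(wm_env.predict_next_obs)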
@@ -714,17 +750,33 @@ class WebGameEngine:
                 start = self.time_module.time()

                 # Use FP16 autocast for faster inference (like play.py can do with modern GPUs)
-
-
+                # Use newer autocast API to avoid deprecation warning
+                import torch
+                with torch.amp.autocast('cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
                     res = self.play_env.step_from_web_input(**web_state)

                 infer_t = self.time_module.time() - start
                 await self._out_queue.put((*res, infer_t))
             except Exception as e:
                 logger.error(f"Inference worker error: {e}")
-
-
-
+                import traceback
+                logger.error(f"Full traceback: {traceback.format_exc()}")
+
+                # Create a proper dummy result with correct tensor properties
+                try:
+                    if self.obs is not None and hasattr(self.obs, 'shape') and hasattr(self.obs, 'device'):
+                        dummy_obs = self.obs.clone()
+                    else:
+                        # Fallback to a standard tensor on the right device
+                        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+                        dummy_obs = torch.zeros(1, 3, 150, 600, device=device)
+
+                    await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))
+                except Exception as e2:
+                    logger.error(f"Error creating dummy result: {e2}")
+                    # Last resort - create CPU tensor
+                    dummy_obs = torch.zeros(1, 3, 150, 600)
+                    await self._out_queue.put((dummy_obs, 0.0, False, False, {"error": str(e)}, 0.0))

 # Global game engine instance
 game_engine = WebGameEngine()
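For reference, the autocast change swaps the deprecated torch.cuda.amp.autocast(...) form for the device-qualified torch.amp.autocast('cuda', ...) API; with enabled=False it becomes a no-op, so the same line is safe on CPU-only machines. A minimal sketch of the pattern (the model and input below are placeholders):

import torch

model = torch.nn.Linear(16, 4)
x = torch.randn(2, 16)
if torch.cuda.is_available():
    model, x = model.cuda(), x.cuda()

# Device-qualified autocast; runs in float16 on CUDA and is disabled (no-op) on CPU.
with torch.amp.autocast('cuda', dtype=torch.float16, enabled=torch.cuda.is_available()):
    y = model(x)

print(y.dtype)  # torch.float16 on CUDA, torch.float32 otherwise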
|