Spaces:
Sleeping
Sleeping
Commit
ยท
1d96a61
1
Parent(s):
b8159f9
Fix initial bugs
Browse files
app.py
CHANGED
|
@@ -99,6 +99,7 @@ class WebGameEngine:
|
|
| 99 |
self.models_ready = False # Track if models are loaded
|
| 100 |
self.download_progress = 0 # Track download progress (0-100)
|
| 101 |
self.loading_status = "Initializing..." # Loading status message
|
|
|
|
| 102 |
import time
|
| 103 |
self.time_module = time
|
| 104 |
|
|
@@ -129,6 +130,9 @@ class WebGameEngine:
|
|
| 129 |
logger.info(f"Model has actor_critic weights: {has_actor_critic}")
|
| 130 |
agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
| 132 |
self.download_progress = 100
|
| 133 |
self.loading_status = "Model loaded successfully!"
|
| 134 |
logger.info("All model weights loaded successfully!")
|
|
@@ -220,6 +224,8 @@ class WebGameEngine:
|
|
| 220 |
self.loading_status = "Loading local checkpoint..."
|
| 221 |
agent.load(checkpoint_path)
|
| 222 |
logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
|
|
|
|
|
|
|
| 223 |
else:
|
| 224 |
logger.error(f"No local checkpoint found at: {checkpoint_path}")
|
| 225 |
raise FileNotFoundError("No model checkpoint available (local or remote)")
|
|
@@ -227,6 +233,7 @@ class WebGameEngine:
|
|
| 227 |
except Exception as e:
|
| 228 |
logger.error(f"Failed to load any checkpoint: {e}")
|
| 229 |
self._init_dummy_mode()
|
|
|
|
| 230 |
return True
|
| 231 |
|
| 232 |
# Initialize world model environment
|
|
@@ -242,12 +249,16 @@ class WebGameEngine:
|
|
| 242 |
self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
|
| 243 |
|
| 244 |
# Verify actor-critic is loaded and ready for inference
|
| 245 |
-
if agent.actor_critic is not None:
|
| 246 |
logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
|
| 247 |
logger.info(f"Actor-critic device: {agent.actor_critic.device}")
|
| 248 |
# Force AI control for web demo
|
| 249 |
self.play_env.is_human_player = False
|
| 250 |
logger.info("WebPlayEnv set to AI control mode")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
else:
|
| 252 |
logger.warning("No actor-critic model found - AI inference will not work!")
|
| 253 |
self.play_env.is_human_player = True
|
|
@@ -281,6 +292,7 @@ class WebGameEngine:
|
|
| 281 |
except Exception as e:
|
| 282 |
logger.error(f"Failed to initialize world model environment: {e}")
|
| 283 |
self._init_dummy_mode()
|
|
|
|
| 284 |
self.models_ready = True
|
| 285 |
self.loading_status = "Using dummy mode"
|
| 286 |
return True
|
|
@@ -290,6 +302,7 @@ class WebGameEngine:
|
|
| 290 |
import traceback
|
| 291 |
traceback.print_exc()
|
| 292 |
self._init_dummy_mode()
|
|
|
|
| 293 |
self.models_ready = True
|
| 294 |
self.loading_status = "Error - using dummy mode"
|
| 295 |
return True
|
|
@@ -573,7 +586,7 @@ async def get_homepage():
|
|
| 573 |
<!DOCTYPE html>
|
| 574 |
<html>
|
| 575 |
<head>
|
| 576 |
-
<title>
|
| 577 |
<style>
|
| 578 |
body {
|
| 579 |
margin: 0;
|
|
@@ -629,7 +642,7 @@ async def get_homepage():
|
|
| 629 |
</style>
|
| 630 |
</head>
|
| 631 |
<body>
|
| 632 |
-
<h1>๐ฎ
|
| 633 |
<p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
|
| 634 |
<p id="loadingIndicator" style="color: #ffff00; display: none;">๐ Starting AI inference... This may take 5-15 seconds on first run.</p>
|
| 635 |
|
|
|
|
| 99 |
self.models_ready = False # Track if models are loaded
|
| 100 |
self.download_progress = 0 # Track download progress (0-100)
|
| 101 |
self.loading_status = "Initializing..." # Loading status message
|
| 102 |
+
self.actor_critic_loaded = False # Track if actor_critic was loaded with trained weights
|
| 103 |
import time
|
| 104 |
self.time_module = time
|
| 105 |
|
|
|
|
| 130 |
logger.info(f"Model has actor_critic weights: {has_actor_critic}")
|
| 131 |
agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
|
| 132 |
|
| 133 |
+
# Track if actor_critic was actually loaded with trained weights
|
| 134 |
+
self.actor_critic_loaded = has_actor_critic
|
| 135 |
+
|
| 136 |
self.download_progress = 100
|
| 137 |
self.loading_status = "Model loaded successfully!"
|
| 138 |
logger.info("All model weights loaded successfully!")
|
|
|
|
| 224 |
self.loading_status = "Loading local checkpoint..."
|
| 225 |
agent.load(checkpoint_path)
|
| 226 |
logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
|
| 227 |
+
# Assume local checkpoint has actor_critic weights (may need verification)
|
| 228 |
+
self.actor_critic_loaded = True
|
| 229 |
else:
|
| 230 |
logger.error(f"No local checkpoint found at: {checkpoint_path}")
|
| 231 |
raise FileNotFoundError("No model checkpoint available (local or remote)")
|
|
|
|
| 233 |
except Exception as e:
|
| 234 |
logger.error(f"Failed to load any checkpoint: {e}")
|
| 235 |
self._init_dummy_mode()
|
| 236 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 237 |
return True
|
| 238 |
|
| 239 |
# Initialize world model environment
|
|
|
|
| 249 |
self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
|
| 250 |
|
| 251 |
# Verify actor-critic is loaded and ready for inference
|
| 252 |
+
if agent.actor_critic is not None and self.actor_critic_loaded:
|
| 253 |
logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
|
| 254 |
logger.info(f"Actor-critic device: {agent.actor_critic.device}")
|
| 255 |
# Force AI control for web demo
|
| 256 |
self.play_env.is_human_player = False
|
| 257 |
logger.info("WebPlayEnv set to AI control mode")
|
| 258 |
+
elif agent.actor_critic is not None and not self.actor_critic_loaded:
|
| 259 |
+
logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
|
| 260 |
+
self.play_env.is_human_player = True
|
| 261 |
+
logger.info("WebPlayEnv set to human control mode (no trained weights)")
|
| 262 |
else:
|
| 263 |
logger.warning("No actor-critic model found - AI inference will not work!")
|
| 264 |
self.play_env.is_human_player = True
|
|
|
|
| 292 |
except Exception as e:
|
| 293 |
logger.error(f"Failed to initialize world model environment: {e}")
|
| 294 |
self._init_dummy_mode()
|
| 295 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 296 |
self.models_ready = True
|
| 297 |
self.loading_status = "Using dummy mode"
|
| 298 |
return True
|
|
|
|
| 302 |
import traceback
|
| 303 |
traceback.print_exc()
|
| 304 |
self._init_dummy_mode()
|
| 305 |
+
self.actor_critic_loaded = False # No actor_critic in dummy mode
|
| 306 |
self.models_ready = True
|
| 307 |
self.loading_status = "Error - using dummy mode"
|
| 308 |
return True
|
|
|
|
| 586 |
<!DOCTYPE html>
|
| 587 |
<html>
|
| 588 |
<head>
|
| 589 |
+
<title>Physics-informed BEV World Model</title>
|
| 590 |
<style>
|
| 591 |
body {
|
| 592 |
margin: 0;
|
|
|
|
| 642 |
</style>
|
| 643 |
</head>
|
| 644 |
<body>
|
| 645 |
+
<h1>๐ฎ Physics-informed BEV World Model</h1>
|
| 646 |
<p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
|
| 647 |
<p id="loadingIndicator" style="color: #ffff00; display: none;">๐ Starting AI inference... This may take 5-15 seconds on first run.</p>
|
| 648 |
|
src/game/__pycache__/dataset_env.cpython-310.pyc
CHANGED
|
Binary files a/src/game/__pycache__/dataset_env.cpython-310.pyc and b/src/game/__pycache__/dataset_env.cpython-310.pyc differ
|
|
|
src/game/__pycache__/web_play_env.cpython-310.pyc
CHANGED
|
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
|
|
|
src/game/web_play_env.py
CHANGED
|
@@ -32,9 +32,13 @@ class WebPlayEnv(PlayEnv):
|
|
| 32 |
self.is_human_player = False # AI controls the actions
|
| 33 |
self.human_input_override = False # Can be set to True to allow human input
|
| 34 |
|
| 35 |
-
# Initialize LSTM hidden states for actor-critic
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def switch_controller(self) -> None:
|
| 40 |
"""Switch between AI and human control"""
|
|
@@ -97,9 +101,11 @@ class WebPlayEnv(PlayEnv):
|
|
| 97 |
if obs.device != self.agent.device:
|
| 98 |
obs = obs.to(self.agent.device, non_blocking=True)
|
| 99 |
|
| 100 |
-
# Detach hidden states to prevent gradient tracking
|
| 101 |
-
self.hx
|
| 102 |
-
|
|
|
|
|
|
|
| 103 |
|
| 104 |
# Resize observation to match actor-critic expected encoder/LSTM input
|
| 105 |
# Count how many MaxPool2d layers are in the encoder to infer downsampling factor
|
|
@@ -145,10 +151,12 @@ class WebPlayEnv(PlayEnv):
|
|
| 145 |
self.obs = next_obs
|
| 146 |
self.t += 1
|
| 147 |
|
| 148 |
-
# Reset hidden states on episode end
|
| 149 |
if end.any() or trunc.any():
|
| 150 |
-
self.hx
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# Return the step results
|
| 154 |
return next_obs, rew, end, trunc, env_info
|
|
|
|
| 32 |
self.is_human_player = False # AI controls the actions
|
| 33 |
self.human_input_override = False # Can be set to True to allow human input
|
| 34 |
|
| 35 |
+
# Initialize LSTM hidden states for actor-critic (only if actor_critic exists)
|
| 36 |
+
if agent.actor_critic is not None:
|
| 37 |
+
self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
|
| 38 |
+
self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
|
| 39 |
+
else:
|
| 40 |
+
self.hx = None
|
| 41 |
+
self.cx = None
|
| 42 |
|
| 43 |
def switch_controller(self) -> None:
|
| 44 |
"""Switch between AI and human control"""
|
|
|
|
| 101 |
if obs.device != self.agent.device:
|
| 102 |
obs = obs.to(self.agent.device, non_blocking=True)
|
| 103 |
|
| 104 |
+
# Detach hidden states to prevent gradient tracking (only if they exist)
|
| 105 |
+
if self.hx is not None:
|
| 106 |
+
self.hx = self.hx.detach()
|
| 107 |
+
if self.cx is not None:
|
| 108 |
+
self.cx = self.cx.detach()
|
| 109 |
|
| 110 |
# Resize observation to match actor-critic expected encoder/LSTM input
|
| 111 |
# Count how many MaxPool2d layers are in the encoder to infer downsampling factor
|
|
|
|
| 151 |
self.obs = next_obs
|
| 152 |
self.t += 1
|
| 153 |
|
| 154 |
+
# Reset hidden states on episode end (only if they exist)
|
| 155 |
if end.any() or trunc.any():
|
| 156 |
+
if self.hx is not None:
|
| 157 |
+
self.hx.zero_()
|
| 158 |
+
if self.cx is not None:
|
| 159 |
+
self.cx.zero_()
|
| 160 |
|
| 161 |
# Return the step results
|
| 162 |
return next_obs, rew, end, trunc, env_info
|