musictimer committed
Commit 1d96a61 · 1 Parent(s): b8159f9

Fix initial bugs
app.py CHANGED
@@ -99,6 +99,7 @@ class WebGameEngine:
         self.models_ready = False  # Track if models are loaded
         self.download_progress = 0  # Track download progress (0-100)
         self.loading_status = "Initializing..."  # Loading status message
+        self.actor_critic_loaded = False  # Track if actor_critic was loaded with trained weights
         import time
         self.time_module = time
 
@@ -129,6 +130,9 @@ class WebGameEngine:
             logger.info(f"Model has actor_critic weights: {has_actor_critic}")
             agent.load_state_dict(state_dict, load_actor_critic=has_actor_critic)
 
+            # Track if actor_critic was actually loaded with trained weights
+            self.actor_critic_loaded = has_actor_critic
+
             self.download_progress = 100
             self.loading_status = "Model loaded successfully!"
             logger.info("All model weights loaded successfully!")
@@ -220,6 +224,8 @@ class WebGameEngine:
                 self.loading_status = "Loading local checkpoint..."
                 agent.load(checkpoint_path)
                 logger.info(f"Successfully loaded local checkpoint: {checkpoint_path}")
+                # Assume local checkpoint has actor_critic weights (may need verification)
+                self.actor_critic_loaded = True
             else:
                 logger.error(f"No local checkpoint found at: {checkpoint_path}")
                 raise FileNotFoundError("No model checkpoint available (local or remote)")
@@ -227,6 +233,7 @@ class WebGameEngine:
         except Exception as e:
             logger.error(f"Failed to load any checkpoint: {e}")
             self._init_dummy_mode()
+            self.actor_critic_loaded = False  # No actor_critic in dummy mode
             return True
 
         # Initialize world model environment
@@ -242,12 +249,16 @@ class WebGameEngine:
             self.play_env = WebPlayEnv(agent, wm_env, False, False, False)
 
             # Verify actor-critic is loaded and ready for inference
-            if agent.actor_critic is not None:
+            if agent.actor_critic is not None and self.actor_critic_loaded:
                 logger.info(f"Actor-critic model loaded with {agent.actor_critic.lstm_dim} LSTM dimensions")
                 logger.info(f"Actor-critic device: {agent.actor_critic.device}")
                 # Force AI control for web demo
                 self.play_env.is_human_player = False
                 logger.info("WebPlayEnv set to AI control mode")
+            elif agent.actor_critic is not None and not self.actor_critic_loaded:
+                logger.warning("Actor-critic model exists but has no trained weights - using dummy mode!")
+                self.play_env.is_human_player = True
+                logger.info("WebPlayEnv set to human control mode (no trained weights)")
             else:
                 logger.warning("No actor-critic model found - AI inference will not work!")
                 self.play_env.is_human_player = True
@@ -281,6 +292,7 @@ class WebGameEngine:
         except Exception as e:
             logger.error(f"Failed to initialize world model environment: {e}")
             self._init_dummy_mode()
+            self.actor_critic_loaded = False  # No actor_critic in dummy mode
             self.models_ready = True
             self.loading_status = "Using dummy mode"
             return True
@@ -290,6 +302,7 @@ class WebGameEngine:
             import traceback
             traceback.print_exc()
             self._init_dummy_mode()
+            self.actor_critic_loaded = False  # No actor_critic in dummy mode
             self.models_ready = True
             self.loading_status = "Error - using dummy mode"
             return True
@@ -573,7 +586,7 @@ async def get_homepage():
     <!DOCTYPE html>
     <html>
     <head>
-        <title>Diamond CSGO AI Player</title>
+        <title>Physics-informed BEV World Model</title>
         <style>
             body {
                 margin: 0;
@@ -629,7 +642,7 @@ async def get_homepage():
         </style>
     </head>
     <body>
-        <h1>🎮 Diamond CSGO AI Player</h1>
+        <h1>🎮 Physics-informed BEV World Model</h1>
         <p><strong>Click the game canvas to start playing!</strong> Use ESC to pause, Enter to reset environment.</p>
         <p id="loadingIndicator" style="color: #ffff00; display: none;">🚀 Starting AI inference... This may take 5-15 seconds on first run.</p>
 
src/game/__pycache__/dataset_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/dataset_env.cpython-310.pyc and b/src/game/__pycache__/dataset_env.cpython-310.pyc differ
 
src/game/__pycache__/web_play_env.cpython-310.pyc CHANGED
Binary files a/src/game/__pycache__/web_play_env.cpython-310.pyc and b/src/game/__pycache__/web_play_env.cpython-310.pyc differ
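Note on the app.py changes above: AI control is now gated on a new actor_critic_loaded flag, which mirrors the has_actor_critic value computed when the downloaded state dict is inspected. That inspection is not part of this hunk; the snippet below is only a minimal sketch of how such a check could look (the helper name and the "actor_critic." key prefix are assumptions, not code from this repository).

import torch

def state_dict_has_actor_critic(state_dict: dict) -> bool:
    # Treat the checkpoint as containing trained actor-critic weights
    # if any parameter key belongs to the actor_critic submodule.
    return any(key.startswith("actor_critic.") for key in state_dict)

# Hypothetical usage before calling agent.load_state_dict(...):
# state_dict = torch.load("checkpoint.pt", map_location="cpu")
# has_actor_critic = state_dict_has_actor_critic(state_dict)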
 
src/game/web_play_env.py CHANGED
@@ -32,9 +32,13 @@ class WebPlayEnv(PlayEnv):
         self.is_human_player = False  # AI controls the actions
         self.human_input_override = False  # Can be set to True to allow human input
 
-        # Initialize LSTM hidden states for actor-critic
-        self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
-        self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
+        # Initialize LSTM hidden states for actor-critic (only if actor_critic exists)
+        if agent.actor_critic is not None:
+            self.hx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
+            self.cx = torch.zeros(1, agent.actor_critic.lstm_dim, device=agent.device)
+        else:
+            self.hx = None
+            self.cx = None
 
     def switch_controller(self) -> None:
         """Switch between AI and human control"""
@@ -97,9 +101,11 @@ class WebPlayEnv(PlayEnv):
         if obs.device != self.agent.device:
             obs = obs.to(self.agent.device, non_blocking=True)
 
-        # Detach hidden states to prevent gradient tracking
-        self.hx = self.hx.detach()
-        self.cx = self.cx.detach()
+        # Detach hidden states to prevent gradient tracking (only if they exist)
+        if self.hx is not None:
+            self.hx = self.hx.detach()
+        if self.cx is not None:
+            self.cx = self.cx.detach()
 
         # Resize observation to match actor-critic expected encoder/LSTM input
         # Count how many MaxPool2d layers are in the encoder to infer downsampling factor
@@ -145,10 +151,12 @@ class WebPlayEnv(PlayEnv):
         self.obs = next_obs
         self.t += 1
 
-        # Reset hidden states on episode end
+        # Reset hidden states on episode end (only if they exist)
         if end.any() or trunc.any():
-            self.hx.zero_()
-            self.cx.zero_()
+            if self.hx is not None:
+                self.hx.zero_()
+            if self.cx is not None:
+                self.cx.zero_()
 
         # Return the step results
         return next_obs, rew, end, trunc, env_info
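Note on the web_play_env.py hunks above: the step logic also mentions resizing the observation by counting MaxPool2d layers in the encoder to infer the downsampling factor. That code is not shown in this diff; the following is only a rough sketch of the idea, assuming the encoder is an ordinary torch.nn module tree in which every pooling layer halves the spatial resolution (the helper names and the lstm_hw parameter are hypothetical).

import torch
import torch.nn as nn

def infer_downsample_factor(encoder: nn.Module) -> int:
    # Each MaxPool2d with stride 2 halves the spatial resolution, so the
    # overall downsampling factor is 2 ** (number of pooling layers).
    n_pools = sum(1 for m in encoder.modules() if isinstance(m, nn.MaxPool2d))
    return 2 ** n_pools

def resize_obs_for_encoder(obs: torch.Tensor, encoder: nn.Module, lstm_hw: int = 8) -> torch.Tensor:
    # Hypothetical helper: pick an input size that the encoder will reduce
    # to the spatial resolution the actor-critic LSTM input expects.
    factor = infer_downsample_factor(encoder)
    target = lstm_hw * factor
    return torch.nn.functional.interpolate(
        obs, size=(target, target), mode="bilinear", align_corners=False
    )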