Spaces:
Sleeping
Sleeping
"""
Configuration helper for web deployment.

Resolves deployment paths (configuration directory, model checkpoint file,
spawn-data directory), creates placeholder config and spawn files when they
are missing, and prepares environment variables (CUDA visibility,
PYTHONPATH) for the deployment process.
"""
import logging
import os
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)
class WebConfig:
    """Resolve paths and bootstrap placeholder files for a web deployment.

    Every location is derived from ``base_path``, which defaults to the
    current working directory captured when the instance is constructed.
    """

    # Default YAML configs written by create_default_configs() when the
    # corresponding file does not exist yet. Existing files are never
    # overwritten.
    # NOTE(review): the nesting below (e.g. ``inner_model`` under
    # ``denoiser``, ``id``/``size``/``num_actions`` under ``train``) was
    # restored from the key structure; confirm against the consuming code.
    _AGENT_CONFIG_YAML = """\
_target_: agent.AgentConfig
denoiser:
  _target_: models.diffusion.DenoiserConfig
  sigma_data: 0.5
  sigma_offset_noise: 0.1
  noise_previous_obs: true
  upsampling_factor: null
  inner_model:
    _target_: models.diffusion.InnerModelConfig
    img_channels: 3
    num_steps_conditioning: 4
    cond_channels: 2048
    depths: [2, 2, 2, 2]
    channels: [128, 256, 512, 1024]
    attn_depths: [0, 0, 1, 1]
upsampler:
  _target_: models.diffusion.DenoiserConfig
  sigma_data: 0.5
  sigma_offset_noise: 0.1
  noise_previous_obs: false
  upsampling_factor: 5
  inner_model:
    _target_: models.diffusion.InnerModelConfig
    img_channels: 3
    num_steps_conditioning: 1
    cond_channels: 2048
    depths: [2, 2, 2, 2]
    channels: [64, 64, 128, 256]
    attn_depths: [0, 0, 0, 1]
rew_end_model: null
actor_critic: null
"""

    _ENV_CONFIG_YAML = """\
train:
  id: csgo
  size: [150, 600]
  num_actions: 51
path_data_low_res: /tmp/dummy_data_low_res
path_data_full_res: /tmp/dummy_data_full_res
keymap: csgo
"""

    _WORLD_MODEL_ENV_CONFIG_YAML = """\
_target_: envs.WorldModelEnvConfig
horizon: 1000
num_batches_to_preload: 1
diffusion_sampler_next_obs:
  _target_: models.diffusion.DiffusionSamplerConfig
  num_steps_denoising: 10
  sigma_min: 0.002
  sigma_max: 5.0
  rho: 7
  order: 1
diffusion_sampler_upsampling:
  _target_: models.diffusion.DiffusionSamplerConfig
  num_steps_denoising: 5
  sigma_min: 0.002
  sigma_max: 5.0
  rho: 7
  order: 1
"""

    _TRAINER_CONFIG_YAML = """\
defaults:
  - _self_
  - env: csgo
  - agent: csgo
  - world_model_env: fast
static_dataset:
  path: /tmp/dummy_data_low_res
  ignore_sample_weights: True
"""

    def __init__(self, base_path: Optional[Path] = None):
        """Create a config rooted at *base_path* (str or Path); ``None`` means cwd."""
        if base_path is None:
            base_path = Path.cwd()
        self.base_path = Path(base_path)

    def get_config_path(self) -> Path:
        """Locate the configuration directory, creating it if necessary.

        Candidates are checked in order:
        ``<base>/config``, ``<base>/src/../config``, and a ``config``
        directory next to this module. The first candidate that is an
        existing directory is returned.

        Returns:
            The found directory, resolved to an absolute path. If no candidate
            exists, ``<base>/config`` is created (including any missing
            parents) and returned, also resolved.
        """
        candidates = (
            self.base_path / "config",
            self.base_path / "src" / ".." / "config",
            Path(__file__).parent / "config",
        )
        for path in candidates:
            # is_dir() rather than exists(): a plain file named "config"
            # cannot serve as the configuration directory.
            if path.is_dir():
                return path.resolve()

        config_path = self.base_path / "config"
        # parents=True so a not-yet-existing base_path does not raise.
        config_path.mkdir(parents=True, exist_ok=True)
        return config_path.resolve()

    def get_checkpoint_path(self) -> Path:
        """Return the first existing checkpoint among the known locations.

        Candidates, in order: ``<base>/agent_epoch_00003.pt``,
        ``<base>/checkpoints/agent_epoch_00003.pt``,
        ``<base>/checkpoints/latest.pt``.

        Returns:
            The first candidate that exists. If none does, a warning is logged
            and the placeholder path ``<base>/checkpoints/model_not_found.pt``
            is returned. It is NOT created, so callers should check
            ``.exists()`` before loading from it.
        """
        candidates = (
            self.base_path / "agent_epoch_00003.pt",
            self.base_path / "checkpoints" / "agent_epoch_00003.pt",
            self.base_path / "checkpoints" / "latest.pt",
        )
        for ckpt_path in candidates:
            if ckpt_path.exists():
                logger.info("Found checkpoint: %s", ckpt_path)
                return ckpt_path

        logger.warning("No checkpoint found - you may need to download models")
        return self.base_path / "checkpoints" / "model_not_found.pt"

    def get_spawn_dir(self) -> Path:
        """Return ``<base>/csgo/spawn``, populating ``spawn/0`` with placeholders.

        Ensures ``<base>/csgo/spawn/0`` exists and contains ``act.npy``,
        ``full_res.npy``, ``low_res.npy`` and ``next_act.npy`` (each a float
        zero array of shape (1, 10)) plus ``info.json`` (``{"dummy": true}``).
        Files that already exist are left untouched. The placeholders only
        give the directory its expected layout; they are not real spawn data.
        """
        spawn_dir = self.base_path / "csgo" / "spawn"
        spawn_subdir = spawn_dir / "0"
        spawn_subdir.mkdir(parents=True, exist_ok=True)

        placeholder_names = ("act.npy", "full_res.npy", "info.json", "low_res.npy", "next_act.npy")
        for filename in placeholder_names:
            file_path = spawn_subdir / filename
            if file_path.exists():
                continue
            if file_path.suffix == ".npy":
                # Imported lazily so the module does not require numpy unless
                # placeholder arrays actually have to be written.
                import numpy as np

                np.save(file_path, np.zeros((1, 10)))
            elif file_path.suffix == ".json":
                import json

                file_path.write_text(json.dumps({"dummy": True}), encoding="utf-8")
        return spawn_dir

    def setup_environment_variables(self) -> None:
        """Prepare process environment variables for deployment.

        * Sets ``CUDA_VISIBLE_DEVICES=""`` when :meth:`has_cuda` is False.
        * Prepends ``<base>/src`` to ``PYTHONPATH`` unless it is already one
          of its entries.

        Note: ``PYTHONPATH`` is read at interpreter start-up, so this change
        affects child processes started afterwards, not ``sys.path`` of the
        running interpreter.
        """
        if not self.has_cuda():
            os.environ["CUDA_VISIBLE_DEVICES"] = ""

        src_path = str(self.base_path / "src")
        current_path = os.environ.get("PYTHONPATH", "")
        # Compare whole entries instead of substrings ("/app/src" must not
        # count as present just because "/app/src2" is), and use the
        # platform separator (":" on POSIX, ";" on Windows).
        if src_path not in current_path.split(os.pathsep):
            os.environ["PYTHONPATH"] = (
                f"{src_path}{os.pathsep}{current_path}" if current_path else src_path
            )

    def has_cuda(self) -> bool:
        """Return True if PyTorch is importable and reports a CUDA device, else False."""
        try:
            import torch
        except ImportError:
            return False
        return torch.cuda.is_available()

    @staticmethod
    def _write_if_missing(path: Path, content: str) -> None:
        """Write *content* (UTF-8) to *path*, creating parent dirs; never overwrite."""
        if path.exists():
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(content, encoding="utf-8")

    def create_default_configs(self) -> None:
        """Create default YAML configuration files that do not exist yet.

        Files written under the directory from :meth:`get_config_path`:
        ``agent/csgo.yaml``, ``env/csgo.yaml``, ``world_model_env/fast.yaml``
        and ``trainer.yaml``. Existing files are left untouched.
        """
        config_dir = self.get_config_path()
        self._write_if_missing(config_dir / "agent" / "csgo.yaml", self._AGENT_CONFIG_YAML)
        self._write_if_missing(config_dir / "env" / "csgo.yaml", self._ENV_CONFIG_YAML)
        self._write_if_missing(
            config_dir / "world_model_env" / "fast.yaml", self._WORLD_MODEL_ENV_CONFIG_YAML
        )
        self._write_if_missing(config_dir / "trainer.yaml", self._TRAINER_CONFIG_YAML)
# Shared module-level instance; base_path=None makes it root itself at the
# current working directory as of the moment this module is imported.
web_config: WebConfig = WebConfig(base_path=None)