from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn

from .envs import TorchEnv, WorldModelEnv
from .models.actor_critic import ActorCritic, ActorCriticConfig, ActorCriticLossConfig
from .models.diffusion import Denoiser, DenoiserConfig, SigmaDistributionConfig
from .models.rew_end_model import RewEndModel, RewEndModelConfig
from .utils import extract_state_dict


@dataclass
class AgentConfig:
    """Configuration bundle for an :class:`Agent`.

    Only ``denoiser`` is mandatory; the other three sub-configs may be ``None``.
    After construction, ``num_actions`` is copied into every sub-config that is
    present (see ``__post_init__``).
    """

    # Config of the main denoiser (always required).
    denoiser: DenoiserConfig
    # Config of the optional upsampler, which is built as a second Denoiser.
    upsampler: Optional[DenoiserConfig]
    # Config of the optional reward/end model.
    rew_end_model: Optional[RewEndModelConfig]
    # Config of the optional actor-critic.
    actor_critic: Optional[ActorCriticConfig]
    # Value written into every sub-config's ``num_actions`` field.
    num_actions: int

    def __post_init__(self) -> None:
        """Copy ``num_actions`` into each sub-config that is not ``None``."""
        n = self.num_actions

        # Denoiser-style configs keep ``num_actions`` on their ``inner_model``.
        self.denoiser.inner_model.num_actions = n
        if self.upsampler is not None:
            self.upsampler.inner_model.num_actions = n

        # The remaining configs expose ``num_actions`` directly.
        for sub_cfg in (self.rew_end_model, self.actor_critic):
            if sub_cfg is not None:
                sub_cfg.num_actions = n


class Agent(nn.Module):
    """Container module grouping a denoiser with up to three optional sub-models.

    The sub-models are ``upsampler`` (a second ``Denoiser``), ``rew_end_model``
    and ``actor_critic``; each attribute is ``None`` when the matching entry of
    the config is ``None``.
    """

    def __init__(self, cfg: AgentConfig) -> None:
        """Build every sub-model for which ``cfg`` provides a configuration."""
        super().__init__()
        self.denoiser = Denoiser(cfg.denoiser)
        self.upsampler = None if cfg.upsampler is None else Denoiser(cfg.upsampler)
        self.rew_end_model = None if cfg.rew_end_model is None else RewEndModel(cfg.rew_end_model)
        self.actor_critic = None if cfg.actor_critic is None else ActorCritic(cfg.actor_critic)

    @property
    def device(self):
        """Device reported by the denoiser (``self.denoiser.device``)."""
        return self.denoiser.device

    def setup_training(
        self,
        sigma_distribution_cfg: SigmaDistributionConfig,
        sigma_distribution_cfg_upsampler: Optional[SigmaDistributionConfig],
        actor_critic_loss_cfg: Optional[ActorCriticLossConfig],
        rl_env: Optional[Union[TorchEnv, WorldModelEnv]],
    ) -> None:
        """Forward training settings to the sub-models that exist.

        Args:
            sigma_distribution_cfg: Passed to ``denoiser.setup_training``.
            sigma_distribution_cfg_upsampler: Passed to
                ``upsampler.setup_training`` when an upsampler exists.
            actor_critic_loss_cfg: Passed, together with ``rl_env``, to
                ``actor_critic.setup_training`` when an actor-critic exists.
            rl_env: Environment handed to the actor-critic.

        No ``None`` check is performed on the optional arguments: they are
        forwarded as given whenever the corresponding sub-model is present.
        """
        self.denoiser.setup_training(sigma_distribution_cfg)

        upsampler = self.upsampler
        if upsampler is not None:
            upsampler.setup_training(sigma_distribution_cfg_upsampler)

        actor_critic = self.actor_critic
        if actor_critic is not None:
            actor_critic.setup_training(rl_env, actor_critic_loss_cfg)

    def load(
        self,
        path_to_ckpt: Path,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
    ) -> None:
        """Read a checkpoint file and load the selected sub-models from it.

        Args:
            path_to_ckpt: Location of the checkpoint; wrapped in ``Path``.
            load_denoiser: Whether to load the denoiser weights.
            load_upsampler: Whether to load the upsampler weights.
            load_rew_end_model: Whether to load the reward/end model weights.
            load_actor_critic: Whether to load the actor-critic weights.
        """
        ckpt_path = Path(path_to_ckpt)
        # NOTE(review): torch.load unpickles the file, so only checkpoints from
        # trusted sources should be passed here.
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.load_state_dict(
            checkpoint,
            load_denoiser,
            load_upsampler,
            load_rew_end_model,
            load_actor_critic,
        )

    def load_state_dict(
        self,
        state_dict: dict,
        load_denoiser: bool = True,
        load_upsampler: bool = True,
        load_rew_end_model: bool = True,
        load_actor_critic: bool = True,
    ) -> None:
        """Load the selected sub-models from an in-memory state dict.

        For each requested sub-model that exists, the matching part of
        ``state_dict`` is obtained with ``extract_state_dict(state_dict, name)``
        (presumably the entries stored under that name; confirm in ``utils``)
        and passed to that sub-model's own ``load_state_dict``.

        NOTE(review): this shadows ``nn.Module.load_state_dict`` with a
        different signature (no ``strict`` argument, returns ``None``), so
        generic code that calls ``load_state_dict(sd, strict=False)`` on an
        ``Agent`` will not work. Confirm this is intended.

        Args:
            state_dict: Combined state dict holding all sub-models.
            load_denoiser: Whether to load the denoiser weights.
            load_upsampler: Whether to load the upsampler weights.
            load_rew_end_model: Whether to load the reward/end model weights.
            load_actor_critic: Whether to load the actor-critic weights.
        """
        # The denoiser always exists, so only its flag is consulted.
        if load_denoiser:
            self.denoiser.load_state_dict(extract_state_dict(state_dict, "denoiser"))

        # Optional sub-models: skipped when not requested or not instantiated.
        optional_parts = (
            ("upsampler", load_upsampler),
            ("rew_end_model", load_rew_end_model),
            ("actor_critic", load_actor_critic),
        )
        for prefix, requested in optional_parts:
            if not requested:
                continue
            submodule = getattr(self, prefix)
            if submodule is not None:
                submodule.load_state_dict(extract_state_dict(state_dict, prefix))