from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold

from prx_layers import (
    EmbedND,  # spellchecker:disable-line
    LastLayer,
    PRXBlock,
    MLPEmbedder,
    convert_ln_to_dyt,  # assumed to be exported by prx_layers; needed for the optional DynTanh swap below
    get_image_ids,
    timestep_embedding,
)


@dataclass
class PRXParams:
    in_channels: int
    patch_size: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    axes_dim: list[int]
    theta: int
    use_image_guidance: bool = False
    use_dyn_tanh: bool = False
    image_guidance_hidden_size: int = 1280
    # Time embedding parameters
    time_factor: float = 1000.0
    time_max_period: int = 10_000
    conditioning_block_ids: list[int] | None = None


PRXTinyConfig = PRXParams(
    in_channels=4,
    patch_size=2,
    context_in_dim=512,
    hidden_size=2304,
    mlp_ratio=3.5,
    num_heads=32,
    depth=3,
    axes_dim=[64, 64],
    theta=10_000,
)

PRXSmallConfig = PRXParams(  # 1.24B - 159 ms
    in_channels=16,
    patch_size=2,
    context_in_dim=2304,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=16,
    axes_dim=[32, 32],
    theta=10_000,
)

PRXDCAESmallConfig = PRXParams(  # 1.24B - 159 ms
    in_channels=32,
    patch_size=1,
    context_in_dim=2304,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=16,
    axes_dim=[32, 32],
    theta=10_000,
)


def img2seq(img: Tensor, patch_size: int) -> Tensor:
    """
    Flatten an image into a sequence of patches
    b c (h ph) (w pw) -> b (h w) (c ph pw)
    """
    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


def seq2img(seq: Tensor, patch_size: int, shape: tuple[int, ...] | Tensor) -> Tensor:
    """
    Revert img2seq
    b (h w) (c ph pw) -> b c (h ph) (w pw)
    """
    if isinstance(shape, tuple):
        shape = shape[-2:]
    elif isinstance(shape, torch.Tensor):
        shape = (int(shape[0]), int(shape[1]))
    else:
        raise NotImplementedError(f"shape type {type(shape)} not supported")
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
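
# Illustrative round trip (the shapes below are arbitrary example values, not requirements of the
# model): for inputs whose spatial dimensions are divisible by patch_size, seq2img exactly inverts
# img2seq.
#
#   x = torch.randn(2, 16, 64, 64)           # (b, c, h * ph, w * pw)
#   seq = img2seq(x, patch_size=2)           # (2, 1024, 64) == (b, h * w, c * ph * pw)
#   assert torch.equal(seq2img(seq, patch_size=2, shape=x.shape), x)
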

class PRX(nn.Module):
    """
    PRX transformer denoiser.
    """

    transformer_block_class = PRXBlock

    def __init__(self, params: PRXParams | Dict[str, Any] | None = None, **kwargs: Any):
        super().__init__()
        if params is None:
            # Case when loaded from bucket: model_class(**parameters)
            params = kwargs
        if isinstance(params, dict):
            # Remove metadata keys
            params_dict = {k: v for k, v in params.items() if not k.startswith("_")}
            # Create PRXParams from the cleaned dictionary
            params = PRXParams(**params_dict)
        elif not isinstance(params, PRXParams):
            raise TypeError("params must be either PRXParams, a dictionary, or keyword arguments")

        self.params = params
        # self.max_img_seq_len = params.max_img_seq_len
        self.in_channels = params.in_channels
        self.patch_size = params.patch_size
        self.use_image_guidance = params.use_image_guidance
        self.image_guidance_hidden_size = params.image_guidance_hidden_size
        self.out_channels = self.in_channels * self.patch_size**2
        self.time_factor = params.time_factor
        self.time_max_period = params.time_max_period

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"axes_dim {params.axes_dim} must sum to the positional embedding dim {pe_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(  # spellchecker:disable-line
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        conditioning_block_ids: list[int] = params.conditioning_block_ids or list(
            range(params.depth)
        )  # Use the conditioning block class for every layer when conditioning_block_ids is not set

        def block_class(idx: int) -> type[PRXBlock]:
            return self.transformer_block_class if idx in conditioning_block_ids else PRXBlock

        self.blocks = nn.ModuleList(
            [
                block_class(i)(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    use_image_guidance=self.use_image_guidance,
                    image_guidance_hidden_size=params.image_guidance_hidden_size,
                )
                for i in range(params.depth)
            ]
        )
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        if params.use_dyn_tanh:
            # Replace all LayerNorms with DynTanh
            print("Replacing all LayerNorms with DynTanh")
            self.blocks = convert_ln_to_dyt(self.blocks)
            self.final_layer = convert_ln_to_dyt(self.final_layer)

    def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
        """Timestep-independent preprocessing: project the text features, patchify the latent, build positional embeddings."""
        txt = self.txt_in(txt)
        img = img2seq(image_latent, self.patch_size)
        bs, _, h, w = image_latent.shape
        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
        pe = self.pe_embedder(img_ids)  # [bs, 1, seq_length, 64, 2, 2]
        return img, txt, pe

    def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
        return self.time_in(
            timestep_embedding(
                t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor
            ).to(dtype)
        )
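
    # Illustrative sketch (variable names below are assumptions, not part of this module): the
    # timestep embedding can be computed once with compute_timestep_embedding and reused across
    # several forward_transformers calls at the same timestep, e.g. the conditional and
    # unconditional branches of classifier-free guidance:
    #
    #   img_seq, txt_cond, pe = model.process_inputs(image_latent, cond_text_features)
    #   _, txt_uncond, _ = model.process_inputs(image_latent, uncond_text_features)
    #   vec = model.compute_timestep_embedding(timestep, dtype=img_seq.dtype)
    #   out_cond = model.forward_transformers(img_seq, txt_cond, time_embedding=vec, pe=pe)
    #   out_uncond = model.forward_transformers(img_seq, txt_uncond, time_embedding=vec, pe=pe)
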
    def forward_transformers(
        self,
        image_latent: Tensor,
        cross_attn_conditioning: Tensor,
        timestep: Optional[Tensor] = None,
        time_embedding: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        **block_kwargs: Any,
    ) -> Tensor:
        img = self.img_in(image_latent)

        if time_embedding is not None:
            # In this case, the provided timestep is already embedded.
            vec = time_embedding
        else:
            if timestep is None:
                raise ValueError("Please provide either a timestep or a time_embedding")
            vec = self.compute_timestep_embedding(timestep, dtype=img.dtype)

        for block in self.blocks:
            img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs)

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    def forward(
        self,
        image_latent: Tensor,
        timestep: Tensor,
        cross_attn_conditioning: Tensor,  # TODO: rename to text_embedding everywhere
        micro_conditioning: None | Tensor,  # unused in this forward pass
        cross_attn_mask: None | Tensor = None,
        image_conditioning: None | Tensor = None,
        image_guidance_scale: None | float | Tensor = None,
        guidance: None = None,  # unused here but required by the LatentDiffusion interface to be Flux compatible
        apply_token_drop: bool = False,  # unused here but required by the LatentDiffusion interface to be Flux compatible
    ) -> Tensor:
        img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning)
        img_seq = self.forward_transformers(
            img_seq,
            txt,
            timestep,
            pe=pe,
            image_conditioning=image_conditioning,
            image_guidance_scale=image_guidance_scale,
            attention_mask=cross_attn_mask,
        )
        return seq2img(img_seq, self.patch_size, image_latent.shape)


if __name__ == "__main__":
    DEVICE = torch.device("cuda")
    DTYPE = torch.bfloat16
    BS = 2
    LATENT_C = 16
    FEATURE_H, FEATURE_W = 512 // 8, 512 // 8
    PROMPT_L = 120

    config = PRXSmallConfig
    denoiser = PRX(config)
    total_params = sum(p.numel() for p in denoiser.parameters())
    print(f"Total number of parameters: {total_params / 1e9:.3f}B")
    denoiser = denoiser.to(DEVICE, DTYPE)

    out = denoiser(
        image_latent=torch.randn(BS, LATENT_C, FEATURE_H, FEATURE_W, device=DEVICE, dtype=DTYPE),
        timestep=torch.zeros(BS, device=DEVICE, dtype=DTYPE),
        cross_attn_conditioning=torch.zeros(BS, PROMPT_L, 2304, device=DEVICE, dtype=DTYPE),  # T5 text encoding
        micro_conditioning=None,
        cross_attn_mask=torch.ones(BS, PROMPT_L, device=DEVICE, dtype=DTYPE),
    )
    loss = out.sum()
    loss.backward()
    print("ok")

    checkpoint_path = "../diffusers_ok/old_and_checkpoints/computer_vision_checkpoints/denoiser_sft_weights.pth"

    # Check that the checkpoint loads into the model
    print(f"Loading checkpoint from: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path)
    load_result = denoiser.load_state_dict(state_dict, strict=True)
    print(f"Load result: {load_result}")
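
# Illustrative sketch (MyConditionedBlock is a hypothetical PRXBlock subclass, not defined here):
# to use a custom block class only on selected layers, subclass PRX, override
# transformer_block_class, and restrict it via conditioning_block_ids; the remaining layers fall
# back to the plain PRXBlock.
#
#   from dataclasses import replace
#
#   class MyPRX(PRX):
#       transformer_block_class = MyConditionedBlock
#
#   params = replace(PRXSmallConfig, conditioning_block_ids=[0, 4, 8, 12])
#   model = MyPRX(params)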