from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold
from prx_layers import (
    EmbedND,  # spellchecker:disable-line
    LastLayer,
    PRXBlock,
    MLPEmbedder,
    get_image_ids,
    timestep_embedding,
    convert_ln_to_dyt,  # assumed to ship with the other PRX layers; needed for use_dyn_tanh below
)
@dataclass
class PRXParams:
in_channels: int
patch_size: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
axes_dim: list[int]
theta: int
use_image_guidance: bool = False
use_dyn_tanh: bool = False
image_guidance_hidden_size: int = 1280
# Time embedding parameters
time_factor: float = 1000.0
time_max_period: int = 10_000
conditioning_block_ids: list[int] | None = None
PRXTinyConfig = PRXParams(
in_channels=4,
patch_size=2,
context_in_dim=512,
hidden_size=2304,
mlp_ratio=3.5,
num_heads=32,
depth=3,
axes_dim=[64, 64],
theta=10_000,
)
PRXSmallConfig = PRXParams( # 1.24B - 159 ms
in_channels=16,
patch_size=2,
context_in_dim=2304,
hidden_size=1792,
mlp_ratio=3.5,
num_heads=28,
depth=16,
axes_dim=[32, 32],
theta=10_000,
)
PRXDCAESmallConfig = PRXParams( # 1.24B - 159 ms
in_channels=32,
patch_size=1,
context_in_dim=2304,
hidden_size=1792,
mlp_ratio=3.5,
num_heads=28,
depth=16,
axes_dim=[32, 32],
theta=10_000,
)
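
# Note on the config fields above: the positional embedding (EmbedND) splits each attention
# head's channels across the listed axes, so hidden_size // num_heads must equal sum(axes_dim)
# (this is checked in PRX.__init__). Worked example for PRXSmallConfig:
#   head_dim = 1792 // 28 = 64 = 32 + 32 = sum(axes_dim)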
def img2seq(img: Tensor, patch_size: int) -> Tensor:
"""
Flatten an image into a sequence of patches
b c (h ph) (w pw) -> b (h w) (c ph pw)
"""
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
def seq2img(seq: Tensor, patch_size: int, shape: tuple[int, ...] | Tensor) -> Tensor:
    """
    Revert img2seq
    b (h w) (c ph pw) -> b c (h ph) (w pw)

    `shape` is either the original image shape (its last two entries are used as H, W)
    or a tensor holding (H, W).
    """
if isinstance(shape, tuple):
shape = shape[-2:]
elif isinstance(shape, torch.Tensor):
shape = (int(shape[0]), int(shape[1]))
else:
raise NotImplementedError(f"shape type {type(shape)} not supported")
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
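
# Shape sketch for the two helpers above (illustrative numbers): with patch_size=2,
#   x = torch.randn(2, 16, 64, 64)                 # (B, C, H, W)
#   s = img2seq(x, 2)                              # (2, 32 * 32, 16 * 2 * 2) = (2, 1024, 64)
#   y = seq2img(s, 2, x.shape)                     # (2, 16, 64, 64)
# With a non-overlapping kernel (stride == kernel_size), fold exactly inverts unfold, so y equals x.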
class PRX(nn.Module):
"""
PRX
"""
transformer_block_class = PRXBlock
def __init__(self, params: PRXParams | Dict[str, Any] | None = None, **kwargs: Any):
super().__init__()
if params is None:
# Case when loaded from bucket: model_class(**parameters)
params = kwargs
if isinstance(params, dict):
# Remove metadata keys
params_dict = {k: v for k, v in params.items() if not k.startswith("_")}
# Create PRXParams from the cleaned dictionary
params = PRXParams(**params_dict)
elif not isinstance(params, PRXParams):
raise TypeError("params must be either PRXParams, a dictionary, or keyword arguments")
self.params = params
# self.max_img_seq_len = params.max_img_seq_len
self.in_channels = params.in_channels
self.patch_size = params.patch_size
self.use_image_guidance = params.use_image_guidance
self.image_guidance_hidden_size = params.image_guidance_hidden_size
self.out_channels = self.in_channels * self.patch_size**2
self.time_factor = params.time_factor
self.time_max_period = params.time_max_period
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND( # spellchecker:disable-line
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
        # If conditioning_block_ids is not given, every block uses the conditioning block class
        # (see the subclassing sketch after the class definition).
        conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth))
def block_class(idx: int) -> PRXBlock:
return self.transformer_block_class if idx in conditioning_block_ids else PRXBlock
self.blocks = nn.ModuleList(
[
block_class(i)(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
use_image_guidance=self.use_image_guidance,
image_guidance_hidden_size=params.image_guidance_hidden_size,
)
for i in range(params.depth)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
if params.use_dyn_tanh:
            # Replace all LayerNorms with DynTanh
            print("Replacing all LayerNorms with DynTanh")
self.blocks = convert_ln_to_dyt(self.blocks)
self.final_layer = convert_ln_to_dyt(self.final_layer)
def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
"Timestep independent stuff"
txt = self.txt_in(txt)
img = img2seq(image_latent, self.patch_size)
bs, _, h, w = image_latent.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
pe = self.pe_embedder(img_ids) # [bs, 1, seq_length, 64, 2, 2]
return img, txt, pe
def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
return self.time_in(
timestep_embedding(t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor).to(
dtype
)
)
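
    # Usage sketch (illustrative variable names): the embedding returned above can be
    # precomputed and handed to forward_transformers via `time_embedding` instead of a
    # raw `timestep`, e.g.:
    #   t_emb = model.compute_timestep_embedding(t, dtype=latent.dtype)
    #   out = model.forward_transformers(img_seq, txt_seq, time_embedding=t_emb, pe=pe)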
def forward_transformers(
self,
image_latent: Tensor,
cross_attn_conditioning: Tensor,
timestep: Optional[Tensor] = None,
time_embedding: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None,
**block_kwargs: Any,
) -> Tensor:
img = self.img_in(image_latent)
if time_embedding is not None:
# In that case, the provided timestep is already embedded.
vec = time_embedding
else:
            if timestep is None:
                raise ValueError("Please provide either a timestep or a time_embedding")
vec = self.compute_timestep_embedding(timestep, dtype=img.dtype)
for block in self.blocks:
img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs)
        img = self.final_layer(img, vec)  # (N, T, out_channels), where out_channels = in_channels * patch_size**2
return img
def forward(
self,
image_latent: Tensor,
timestep: Tensor,
        cross_attn_conditioning: Tensor,  # TODO: rename to text_embedding everywhere
        micro_conditioning: Tensor,  # unused here, kept for interface compatibility
cross_attn_mask: None | Tensor = None,
image_conditioning: None | Tensor = None,
image_guidance_scale: None | float | Tensor = None,
guidance: None = None, # unused here but required by the LatentDiffusion interface to be Flux compatible
apply_token_drop: bool = False, # unused here but required by the LatentDiffusion interface to be Flux compatible
) -> Tensor:
img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning)
img_seq = self.forward_transformers(
img_seq,
txt,
timestep,
pe=pe,
image_conditioning=image_conditioning,
image_guidance_scale=image_guidance_scale,
attention_mask=cross_attn_mask,
)
return seq2img(img_seq, self.patch_size, image_latent.shape)
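
# Subclassing sketch (illustrative; the names below are hypothetical): a subclass can swap
# in a specialized block for selected layers by overriding `transformer_block_class`; layers
# not listed in `conditioning_block_ids` fall back to the plain PRXBlock. The custom block
# must accept the same constructor arguments as PRXBlock.
#
#   class MyPRX(PRX):
#       transformer_block_class = MyConditioningBlock  # hypothetical PRXBlock-compatible class
#
#   model = MyPRX(PRXParams(..., conditioning_block_ids=[0, 4, 8]))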
if __name__ == "__main__":
DEVICE = torch.device("cuda")
DTYPE = torch.bfloat16
BS = 2
LATENT_C = 16
FEATURE_H, FEATURE_W = 512 // 8, 512 // 8
PROMPT_L = 120
config = PRXSmallConfig
denoiser = PRX(config)
total_params = sum(p.numel() for p in denoiser.parameters())
print(f"Total number of parameters : {total_params / 1e9: .3f}B")
denoiser = denoiser.to(DEVICE, DTYPE)
out = denoiser(
image_latent=torch.randn(BS, LATENT_C, FEATURE_H, FEATURE_W, device=DEVICE, dtype=DTYPE),
timestep=torch.zeros(BS, device=DEVICE, dtype=DTYPE),
cross_attn_conditioning=torch.zeros(BS, PROMPT_L, 2304, device=DEVICE, dtype=DTYPE), # T5 text encoding
micro_conditioning=None,
cross_attn_mask=torch.ones(BS, PROMPT_L, device=DEVICE, dtype=DTYPE),
)
loss = out.sum()
loss.backward()
print("ok")
checkpoint_path = "../diffusers_ok/old_and_checkpoints/computer_vision_checkpoints/denoiser_sft_weights.pth"
# check loading checkpoint
print(f"Loading checkpoint from: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path)
    load_result = denoiser.load_state_dict(state_dict, strict=True)  # reuse the state dict loaded above
    print(f"Load result: {load_result}")