from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from torch import Tensor, nn
from torch.nn.functional import fold, unfold

from prx_layers import (
    EmbedND,
    LastLayer,
    PRXBlock,
    MLPEmbedder,
    convert_ln_to_dyt,  # assumed to live in prx_layers; used below when use_dyn_tanh is set
    get_image_ids,
    timestep_embedding,
)


@dataclass
class PRXParams:
    in_channels: int
    patch_size: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    axes_dim: list[int]
    theta: int
    use_image_guidance: bool = False
    use_dyn_tanh: bool = False
    image_guidance_hidden_size: int = 1280

    time_factor: float = 1000.0
    time_max_period: int = 10_000

    conditioning_block_ids: list[int] | None = None
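
    # Note (illustrative; PRX.__init__ below enforces this): sum(axes_dim) must equal the
    # per-head dimension hidden_size // num_heads, e.g. for PRXSmallConfig:
    #   1792 // 28 == 64 == 32 + 32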


PRXTinyConfig = PRXParams(
    in_channels=4,
    patch_size=2,
    context_in_dim=512,
    hidden_size=2304,
    mlp_ratio=3.5,
    num_heads=32,
    depth=3,
    axes_dim=[64, 64],
    theta=10_000,
)

PRXSmallConfig = PRXParams(
    in_channels=16,
    patch_size=2,
    context_in_dim=2304,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=16,
    axes_dim=[32, 32],
    theta=10_000,
)

PRXDCAESmallConfig = PRXParams(
    in_channels=32,
    patch_size=1,
    context_in_dim=2304,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=16,
    axes_dim=[32, 32],
    theta=10_000,
)
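
# PRX (defined below) accepts any of these configuration forms; a minimal sketch:
#   model = PRX(PRXSmallConfig)                       # a PRXParams instance
#   model = PRX(vars(PRXSmallConfig))                 # a plain dict with the same fields
#   model = PRX(in_channels=16, patch_size=2, context_in_dim=2304, hidden_size=1792,
#               mlp_ratio=3.5, num_heads=28, depth=16, axes_dim=[32, 32], theta=10_000)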


def img2seq(img: Tensor, patch_size: int) -> Tensor:
    """
    Flatten an image into a sequence of patches
    b c (h ph) (w pw) -> b (h w) (c ph pw)
    """
    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


def seq2img(seq: Tensor, patch_size: int, shape: tuple[int, ...] | Tensor) -> Tensor:
    """
    Revert img2seq
    b (h w) (c ph pw) -> b c (h ph) (w pw)
    """
    if isinstance(shape, tuple):
        shape = shape[-2:]
    elif isinstance(shape, torch.Tensor):
        shape = (int(shape[0]), int(shape[1]))
    else:
        raise NotImplementedError(f"shape type {type(shape)} not supported")
    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
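
# Shape round trip (illustrative): with patch size p, img2seq maps (B, C, H, W) to
# (B, (H // p) * (W // p), C * p * p); seq2img inverts it given the original spatial shape.
#   x = torch.randn(2, 16, 64, 64)
#   seq = img2seq(x, patch_size=2)                    # (2, 1024, 64)
#   assert torch.allclose(seq2img(seq, patch_size=2, shape=x.shape), x)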


class PRX(nn.Module):
    """
    PRX: a patch-based diffusion transformer. Image latents are flattened into patch tokens,
    projected to hidden_size, and processed by a stack of PRXBlocks conditioned on a timestep
    embedding and on text embeddings via cross-attention; the final layer projects tokens back
    to patch values.
    """

    transformer_block_class = PRXBlock

    def __init__(self, params: PRXParams | Dict[str, Any] | None = None, **kwargs: Any):
        super().__init__()

        if params is None:
            params = kwargs

        if isinstance(params, dict):
            params_dict = {k: v for k, v in params.items() if not k.startswith("_")}
            params = PRXParams(**params_dict)
        elif not isinstance(params, PRXParams):
            raise TypeError("params must be either PRXParams, a dictionary, or keyword arguments")

        self.params = params

        self.in_channels = params.in_channels
        self.patch_size = params.patch_size
        self.use_image_guidance = params.use_image_guidance
        self.image_guidance_hidden_size = params.image_guidance_hidden_size

        # Each output token predicts one full patch of in_channels * patch_size**2 values.
        self.out_channels = self.in_channels * self.patch_size**2

        self.time_factor = params.time_factor
        self.time_max_period = params.time_max_period

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")

        pe_dim = params.hidden_size // params.num_heads

        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"axes_dim {params.axes_dim} must sum to the per-head positional dim {pe_dim}")

        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        # Blocks listed in conditioning_block_ids use transformer_block_class (overridable in
        # subclasses); all other blocks fall back to the base PRXBlock.
        conditioning_block_ids: list[int] = params.conditioning_block_ids or list(range(params.depth))

        def block_class(idx: int) -> type[PRXBlock]:
            return self.transformer_block_class if idx in conditioning_block_ids else PRXBlock

        self.blocks = nn.ModuleList(
            [
                block_class(i)(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    use_image_guidance=self.use_image_guidance,
                    image_guidance_hidden_size=params.image_guidance_hidden_size,
                )
                for i in range(params.depth)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

        if params.use_dyn_tanh:
            print("Replacing all LayerNorms with DynTanh")
            self.blocks = convert_ln_to_dyt(self.blocks)
            self.final_layer = convert_ln_to_dyt(self.final_layer)

    def process_inputs(self, image_latent: Tensor, txt: Tensor, **_: Any) -> tuple[Tensor, Tensor, Tensor]:
        """Compute the timestep-independent inputs: projected text embeddings, the image patch
        sequence, and its positional embeddings."""
        txt = self.txt_in(txt)
        img = img2seq(image_latent, self.patch_size)
        bs, _, h, w = image_latent.shape
        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
        pe = self.pe_embedder(img_ids)
        return img, txt, pe

    def compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
        return self.time_in(
            timestep_embedding(t=timestep, dim=256, max_period=self.time_max_period, time_factor=self.time_factor).to(
                dtype
            )
        )
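
    # Illustrative shapes, assuming timestep_embedding returns a (B, 256) sinusoidal embedding
    # (dim=256 here matches time_in's in_dim):
    #   timestep: (B,) -> timestep_embedding: (B, 256) -> time_in: (B, hidden_size)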

    def forward_transformers(
        self,
        image_latent: Tensor,
        cross_attn_conditioning: Tensor,
        timestep: Optional[Tensor] = None,
        time_embedding: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        **block_kwargs: Any,
    ) -> Tensor:
        img = self.img_in(image_latent)

        if time_embedding is not None:
            vec = time_embedding
        else:
            if timestep is None:
                raise ValueError("Please provide either a timestep or a time_embedding")
            vec = self.compute_timestep_embedding(timestep, dtype=img.dtype)

        for block in self.blocks:
            img = block(img=img, txt=cross_attn_conditioning, vec=vec, attention_mask=attention_mask, **block_kwargs)

        img = self.final_layer(img, vec)
        return img
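
    # Sketch of per-step reuse (illustrative names): process_inputs is timestep-independent,
    # so its outputs can be computed once and reused across sampling steps:
    #   img_seq, txt, pe = model.process_inputs(latent, text_embeds)
    #   for t in schedule:
    #       pred_seq = model.forward_transformers(img_seq, txt, timestep=t, pe=pe)
    #   out = seq2img(pred_seq, model.patch_size, latent.shape)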

    def forward(
        self,
        image_latent: Tensor,
        timestep: Tensor,
        cross_attn_conditioning: Tensor,
        micro_conditioning: None | Tensor = None,
        cross_attn_mask: None | Tensor = None,
        image_conditioning: None | Tensor = None,
        image_guidance_scale: None | float | Tensor = None,
        guidance: None = None,
        apply_token_drop: bool = False,
    ) -> Tensor:
        img_seq, txt, pe = self.process_inputs(image_latent, cross_attn_conditioning)
        img_seq = self.forward_transformers(
            img_seq,
            txt,
            timestep,
            pe=pe,
            image_conditioning=image_conditioning,
            image_guidance_scale=image_guidance_scale,
            attention_mask=cross_attn_mask,
        )
        return seq2img(img_seq, self.patch_size, image_latent.shape)


if __name__ == "__main__":
    DEVICE = torch.device("cuda")
    DTYPE = torch.bfloat16

    BS = 2
    LATENT_C = 16
    FEATURE_H, FEATURE_W = 512 // 8, 512 // 8
    PROMPT_L = 120
    config = PRXSmallConfig

    denoiser = PRX(config)
    total_params = sum(p.numel() for p in denoiser.parameters())
    print(f"Total number of parameters: {total_params / 1e9:.3f}B")
    denoiser = denoiser.to(DEVICE, DTYPE)

    out = denoiser(
        image_latent=torch.randn(BS, LATENT_C, FEATURE_H, FEATURE_W, device=DEVICE, dtype=DTYPE),
        timestep=torch.zeros(BS, device=DEVICE, dtype=DTYPE),
        cross_attn_conditioning=torch.zeros(BS, PROMPT_L, 2304, device=DEVICE, dtype=DTYPE),
        micro_conditioning=None,
        cross_attn_mask=torch.ones(BS, PROMPT_L, device=DEVICE, dtype=DTYPE),
    )
    loss = out.sum()
    loss.backward()
    print("ok")

    checkpoint_path = "../diffusers_ok/old_and_checkpoints/computer_vision_checkpoints/denoiser_sft_weights.pth"

    print(f"Loading checkpoint from: {checkpoint_path}")
    state_dict = torch.load(checkpoint_path)
    load_result = denoiser.load_state_dict(state_dict, strict=True)
    print(f"Load result: {load_result}")