DeepBeepMeep committed
Commit 7d5369f · 1 Parent(s): 48deef5

Fixed PyTorch compilation

README.md CHANGED
@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
 
21
  ## 🔥 Latest News!!
 
22
 * Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command-line option *--preload x* to preload x MB of the main diffusion model in VRAM if you find there is too much unused VRAM and want to (slightly) accelerate the generation process.
23
 If you upgrade, you will need to run 'pip install -r requirements.txt' again.
24
 * Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling for VAE encoding and decoding. No more VRAM peaks at the beginning and at the end.
 
19
 
20
 
21
  ## 🔥 Latest News!!
22
+ * Mar 03, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; generation is now really 20% faster when compilation is activated.
23
 * Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command-line option *--preload x* to preload x MB of the main diffusion model in VRAM if you find there is too much unused VRAM and want to (slightly) accelerate the generation process.
24
 If you upgrade, you will need to run 'pip install -r requirements.txt' again.
25
 * Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling for VAE encoding and decoding. No more VRAM peaks at the beginning and at the end.
gradio_server.py CHANGED
@@ -853,7 +853,7 @@ def generate_video(
853
  if use_image2video:
854
  samples = wan_model.generate(
855
  prompt,
856
- image_to_continue[video_no-1],
857
  frame_num=(video_length // 4)* 4 + 1,
858
  max_area=MAX_AREA_CONFIGS[resolution],
859
  shift=flow_shift,
 
853
  if use_image2video:
854
  samples = wan_model.generate(
855
  prompt,
856
+ image_to_continue[ (video_no-1) % len(image_to_continue)],
857
  frame_num=(video_length // 4)* 4 + 1,
858
  max_area=MAX_AREA_CONFIGS[resolution],
859
  shift=flow_shift,
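The new indexing above wraps around the supplied reference images, so requesting more videos than images no longer runs off the end of the list; extra videos simply reuse the images in round-robin order. A minimal sketch of the behaviour with placeholder file names (not the actual gradio_server.py objects):

```python
# Two reference images, five requested videos: the modulo wraps around
# instead of raising an IndexError at video 3.
image_to_continue = ["img_a.png", "img_b.png"]
for video_no in range(1, 6):
    img = image_to_continue[(video_no - 1) % len(image_to_continue)]
    print(f"video {video_no} starts from {img}")
# -> img_a, img_b, img_a, img_b, img_a
```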
wan/image2video.py CHANGED
@@ -24,7 +24,7 @@ from .modules.vae import WanVAE
24
  from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
  get_sampling_sigmas, retrieve_timesteps)
26
  from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
-
28
 
29
  class WanI2V:
30
 
@@ -290,7 +290,7 @@ class WanI2V:
290
  # sample videos
291
  latent = noise
292
 
293
- freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
294
 
295
  arg_c = {
296
  'context': [context[0]],
@@ -318,6 +318,7 @@ class WanI2V:
318
  callback(-1, None)
319
 
320
  for i, t in enumerate(tqdm(timesteps)):
 
321
  latent_model_input = [latent.to(self.device)]
322
  timestep = [t]
323
 
 
24
  from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
25
  get_sampling_sigmas, retrieve_timesteps)
26
  from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
27
+ from wan.modules.posemb_layers import get_rotary_pos_embed
28
 
29
  class WanI2V:
30
 
 
290
  # sample videos
291
  latent = noise
292
 
293
+ freqs = get_rotary_pos_embed(frame_num, h, w, enable_RIFLEx= enable_RIFLEx )
294
 
295
  arg_c = {
296
  'context': [context[0]],
 
318
  callback(-1, None)
319
 
320
  for i, t in enumerate(tqdm(timesteps)):
321
+ offload.set_step_no_for_lora(i)
322
  latent_model_input = [latent.to(self.device)]
323
  timestep = [t]
324
 
wan/modules/model.py CHANGED
@@ -67,17 +67,15 @@ def rope_params_riflex(max_seq_len, dim, theta=10000, L_test=30, k=6):
67
  inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
68
 
69
  freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
70
- freqs = torch.polar(torch.ones_like(freqs), freqs)
71
  return freqs
72
 
73
- def rope_params(max_seq_len, dim, theta=10000):
74
- assert dim % 2 == 0
75
- freqs = torch.outer(
76
- torch.arange(max_seq_len),
77
- 1.0 / torch.pow(theta,
78
- torch.arange(0, dim, 2).to(torch.float32).div(dim)))
79
- freqs = torch.polar(torch.ones_like(freqs), freqs)
80
- return freqs
81
 
82
 
83
  def rope_apply_(x, grid_sizes, freqs):
@@ -209,6 +207,7 @@ class WanLayerNorm(nn.LayerNorm):
209
  return x
210
  # return super().forward(x).type_as(x)
211
 
 
212
 
213
  class WanSelfAttention(nn.Module):
214
 
@@ -257,8 +256,11 @@ class WanSelfAttention(nn.Module):
257
  k = k.view(b, s, n, d)
258
  v = self.v(x).view(b, s, n, d)
259
  del x
260
- rope_apply_(q, grid_sizes, freqs)
261
- rope_apply_(k, grid_sizes, freqs)
 
 
 
262
  qkv_list = [q,k,v]
263
  del q,k,v
264
  x = pay_attention(
@@ -652,20 +654,18 @@ class WanModel(ModelMixin, ConfigMixin):
652
  # ],dim=1)
653
 
654
 
655
- def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None):
656
  dim = self.dim
657
  num_heads = self.num_heads
658
  d = dim // num_heads
659
  assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
660
 
661
 
662
- freqs = torch.cat([
663
- rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)), #44
664
- rope_params(1024, 2 * (d // 6)), #42
665
- rope_params(1024, 2 * (d // 6)) #42
666
- ],dim=1)
667
 
668
- return freqs
669
 
670
 
671
  def forward(
@@ -706,7 +706,7 @@ class WanModel(ModelMixin, ConfigMixin):
706
  assert clip_fea is not None and y is not None
707
  # params
708
  device = self.patch_embedding.weight.device
709
- if freqs.device != device:
710
  freqs = freqs.to(device)
711
 
712
  if y is not None:
 
67
  inv_theta_pow[k-1] = 0.9 * 2 * torch.pi / L_test
68
 
69
  freqs = torch.outer(torch.arange(max_seq_len), inv_theta_pow)
70
+ if True:
71
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
72
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
73
+ return (freqs_cos, freqs_sin)
74
+ else:
75
+ freqs = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
76
  return freqs
77
 
78
+
79
 
80
 
81
  def rope_apply_(x, grid_sizes, freqs):
 
207
  return x
208
  # return super().forward(x).type_as(x)
209
 
210
+ from wan.modules.posemb_layers import apply_rotary_emb
211
 
212
  class WanSelfAttention(nn.Module):
213
 
 
256
  k = k.view(b, s, n, d)
257
  v = self.v(x).view(b, s, n, d)
258
  del x
259
+ # rope_apply_(q, grid_sizes, freqs)
260
+ # rope_apply_(k, grid_sizes, freqs)
261
+ qklist = [q,k]
262
+ del q,k
263
+ q,k = apply_rotary_emb(qklist, freqs, head_first=False)
264
  qkv_list = [q,k,v]
265
  del q,k,v
266
  x = pay_attention(
 
654
  # ],dim=1)
655
 
656
 
657
+ def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
658
  dim = self.dim
659
  num_heads = self.num_heads
660
  d = dim // num_heads
661
  assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
662
 
663
 
664
+ c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
665
+ c2, s2 = rope_params(1024, 2 * (d // 6)) #42
666
+ c3, s3 = rope_params(1024, 2 * (d // 6)) #42
 
 
667
 
668
+ return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
669
 
670
 
671
  def forward(
 
706
  assert clip_fea is not None and y is not None
707
  # params
708
  device = self.patch_embedding.weight.device
709
+ if torch.is_tensor(freqs) and freqs.device != device:
710
  freqs = freqs.to(device)
711
 
712
  if y is not None:
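The model.py edits above swap the complex-valued RoPE tables (torch.polar, used by the old rope_params/rope_apply_ path) for real cos/sin tables consumed by apply_rotary_emb. Complex tensor ops have historically been awkward for torch.compile, which is presumably why this is the core of the "Fixed pytorch compilation" change. A self-contained sketch, with illustrative shapes rather than the repository's exact code, showing that the two formulations agree:

```python
import torch

# Illustrative shapes only; the repository applies this per attention head.
def rotate_half(x):
    real, imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-imag, real], dim=-1).flatten(-2)

S, D = 8, 4
pos = torch.arange(S, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
freqs = torch.outer(pos, inv_freq)                       # [S, D/2]

x = torch.randn(S, D)
# Old path: complex rotation via torch.polar / view_as_complex.
x_c = torch.view_as_complex(x.reshape(S, -1, 2))
ref = torch.view_as_real(x_c * torch.polar(torch.ones_like(freqs), freqs)).flatten(-2)
# New path: real cos/sin tables plus rotate_half, as in posemb_layers.apply_rotary_emb.
cos = freqs.cos().repeat_interleave(2, dim=1)
sin = freqs.sin().repeat_interleave(2, dim=1)
out = x * cos + rotate_half(x) * sin
print(torch.allclose(ref, out, atol=1e-5))               # True
```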
wan/modules/posemb_layers.py ADDED
@@ -0,0 +1,474 @@
1
+ import torch
2
+ from typing import Union, Tuple, List, Optional
3
+ import numpy as np
4
+
5
+
6
+ ###### Thanks to the RIFLEx project (https://github.com/thu-ml/RIFLEx/) for this alternative pos embed for long videos
7
+ #
8
+ def get_1d_rotary_pos_embed_riflex(
9
+ dim: int,
10
+ pos: Union[np.ndarray, int],
11
+ theta: float = 10000.0,
12
+ use_real=False,
13
+ k: Optional[int] = None,
14
+ L_test: Optional[int] = None,
15
+ ):
16
+ """
17
+ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
18
+
19
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
20
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
21
+ data type.
22
+
23
+ Args:
24
+ dim (`int`): Dimension of the frequency tensor.
25
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
26
+ theta (`float`, *optional*, defaults to 10000.0):
27
+ Scaling factor for frequency computation. Defaults to 10000.0.
28
+ use_real (`bool`, *optional*):
29
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
30
+ k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
31
+ L_test (`int`, *optional*, defaults to None): the number of frames for inference
32
+ Returns:
33
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
34
+ """
35
+ assert dim % 2 == 0
36
+
37
+ if isinstance(pos, int):
38
+ pos = torch.arange(pos)
39
+ if isinstance(pos, np.ndarray):
40
+ pos = torch.from_numpy(pos) # type: ignore # [S]
41
+
42
+ freqs = 1.0 / (
43
+ theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
44
+ ) # [D/2]
45
+
46
+ # === Riflex modification start ===
47
+ # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
48
+ # Empirical observations show that a few videos may exhibit repetition in the tail frames.
49
+ # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
50
+ if k is not None:
51
+ freqs[k-1] = 0.9 * 2 * torch.pi / L_test
52
+ # === Riflex modification end ===
53
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
54
+
55
+ if use_real:
56
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
57
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
58
+ return freqs_cos, freqs_sin
59
+ else:
60
+ # lumina
61
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
62
+ return freqs_cis
63
+
64
+ def identify_k( b: float, d: int, N: int):
65
+ """
66
+ This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
67
+
68
+ Args:
69
+ b (`float`): The base frequency for RoPE.
70
+ d (`int`): Dimension of the frequency tensor
71
+ N (`int`): the first observed repetition frame in latent space
72
+ Returns:
73
+ k (`int`): the index of intrinsic frequency component
74
+ N_k (`int`): the period of intrinsic frequency component in latent space
75
+ Example:
76
+ In HunyuanVideo, b=256 and d=16, the repetition occurs approximately 8s (N=48 in latent space).
77
+ k, N_k = identify_k(b=256, d=16, N=48)
78
+ In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
79
+ """
80
+
81
+ # Compute the period of each frequency in RoPE according to Eq.(4)
82
+ periods = []
83
+ for j in range(1, d // 2 + 1):
84
+ theta_j = 1.0 / (b ** (2 * (j - 1) / d))
85
+ N_j = round(2 * torch.pi / theta_j)
86
+ periods.append(N_j)
87
+
88
+ # Identify the intrinsic frequency whose period is closest to N (see Eq.(7))
89
+ diffs = [abs(N_j - N) for N_j in periods]
90
+ k = diffs.index(min(diffs)) + 1
91
+ N_k = periods[k-1]
92
+ return k, N_k
93
+
94
+ def _to_tuple(x, dim=2):
95
+ if isinstance(x, int):
96
+ return (x,) * dim
97
+ elif len(x) == dim:
98
+ return x
99
+ else:
100
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
101
+
102
+
103
+ def get_meshgrid_nd(start, *args, dim=2):
104
+ """
105
+ Get n-D meshgrid with start, stop and num.
106
+
107
+ Args:
108
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
109
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
110
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
111
+ n-tuples.
112
+ *args: See above.
113
+ dim (int): Dimension of the meshgrid. Defaults to 2.
114
+
115
+ Returns:
116
+ grid (np.ndarray): [dim, ...]
117
+ """
118
+ if len(args) == 0:
119
+ # start is grid_size
120
+ num = _to_tuple(start, dim=dim)
121
+ start = (0,) * dim
122
+ stop = num
123
+ elif len(args) == 1:
124
+ # start is start, args[0] is stop, step is 1
125
+ start = _to_tuple(start, dim=dim)
126
+ stop = _to_tuple(args[0], dim=dim)
127
+ num = [stop[i] - start[i] for i in range(dim)]
128
+ elif len(args) == 2:
129
+ # start is start, args[0] is stop, args[1] is num
130
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
131
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
132
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
133
+ else:
134
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
135
+
136
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
137
+ axis_grid = []
138
+ for i in range(dim):
139
+ a, b, n = start[i], stop[i], num[i]
140
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
141
+ axis_grid.append(g)
142
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
143
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
144
+
145
+ return grid
146
+
147
+
148
+ #################################################################################
149
+ # Rotary Positional Embedding Functions #
150
+ #################################################################################
151
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
152
+
153
+
154
+ def reshape_for_broadcast(
155
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
156
+ x: torch.Tensor,
157
+ head_first=False,
158
+ ):
159
+ """
160
+ Reshape frequency tensor for broadcasting it with another tensor.
161
+
162
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
163
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
164
+
165
+ Notes:
166
+ When using FlashMHAModified, head_first should be False.
167
+ When using Attention, head_first should be True.
168
+
169
+ Args:
170
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
171
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
172
+ head_first (bool): head dimension first (except batch dim) or not.
173
+
174
+ Returns:
175
+ torch.Tensor: Reshaped frequency tensor.
176
+
177
+ Raises:
178
+ AssertionError: If the frequency tensor doesn't match the expected shape.
179
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
180
+ """
181
+ ndim = x.ndim
182
+ assert 0 <= 1 < ndim
183
+
184
+ if isinstance(freqs_cis, tuple):
185
+ # freqs_cis: (cos, sin) in real space
186
+ if head_first:
187
+ assert freqs_cis[0].shape == (
188
+ x.shape[-2],
189
+ x.shape[-1],
190
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
191
+ shape = [
192
+ d if i == ndim - 2 or i == ndim - 1 else 1
193
+ for i, d in enumerate(x.shape)
194
+ ]
195
+ else:
196
+ assert freqs_cis[0].shape == (
197
+ x.shape[1],
198
+ x.shape[-1],
199
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
200
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
201
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
202
+ else:
203
+ # freqs_cis: values in complex space
204
+ if head_first:
205
+ assert freqs_cis.shape == (
206
+ x.shape[-2],
207
+ x.shape[-1],
208
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
209
+ shape = [
210
+ d if i == ndim - 2 or i == ndim - 1 else 1
211
+ for i, d in enumerate(x.shape)
212
+ ]
213
+ else:
214
+ assert freqs_cis.shape == (
215
+ x.shape[1],
216
+ x.shape[-1],
217
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
218
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
219
+ return freqs_cis.view(*shape)
220
+
221
+
222
+ def rotate_half(x):
223
+ x_real, x_imag = (
224
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
225
+ ) # [B, S, H, D//2]
226
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
227
+
228
+
229
+ def apply_rotary_emb( qklist,
230
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
231
+ head_first: bool = False,
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ """
234
+ Apply rotary embeddings to input tensors using the given frequency tensor.
235
+
236
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
237
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
238
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
239
+ returned as real tensors.
240
+
241
+ Args:
242
+ qklist (List[torch.Tensor]): query and key tensors to apply rotary embeddings, each [B, S, H, D].
243
+ The list is cleared in place so the caller's query/key references can be freed early.
244
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
245
+ head_first (bool): head dimension first (except batch dim) or not.
246
+
247
+ Returns:
248
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
249
+
250
+ """
251
+ xq, xk = qklist
252
+ qklist.clear()
253
+ xk_out = None
254
+ if isinstance(freqs_cis, tuple):
255
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
256
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
257
+ # real * cos - imag * sin
258
+ # imag * cos + real * sin
259
+ xq_dtype = xq.dtype
260
+ xq_out = xq.to(torch.float)
261
+ xq = None
262
+ xq_rot = rotate_half(xq_out)
263
+ xq_out *= cos
264
+ xq_rot *= sin
265
+ xq_out += xq_rot
266
+ del xq_rot
267
+ xq_out = xq_out.to(xq_dtype)
268
+
269
+ xk_out = xk.to(torch.float)
270
+ xk = None
271
+ xk_rot = rotate_half(xk_out)
272
+ xk_out *= cos
273
+ xk_rot *= sin
274
+ xk_out += xk_rot
275
+ del xk_rot
276
+ xk_out = xk_out.to(xq_dtype)
277
+ else:
278
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
279
+ xq_ = torch.view_as_complex(
280
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
281
+ ) # [B, S, H, D//2]
282
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
283
+ xq.device
284
+ ) # [S, D//2] --> [1, S, 1, D//2]
285
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
286
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
287
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
288
+ xk_ = torch.view_as_complex(
289
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
290
+ ) # [B, S, H, D//2]
291
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
292
+
293
+ return xq_out, xk_out
294
+
295
+
296
+
297
+
298
+
299
+ def get_nd_rotary_pos_embed(
300
+ rope_dim_list,
301
+ start,
302
+ *args,
303
+ theta=10000.0,
304
+ use_real=False,
305
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
306
+ interpolation_factor: Union[float, List[float]] = 1.0,
307
+ k = 6,
308
+ L_test = 66,
309
+ enable_riflex = True
310
+ ):
311
+ """
312
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
313
+
314
+ Args:
315
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
316
+ sum(rope_dim_list) should equal to head_dim of attention layer.
317
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
318
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
319
+ *args: See above.
320
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
321
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
322
+ Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide a real
323
+ part and an imaginary part separately.
324
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
325
+
326
+ Returns:
327
+ pos_embed (torch.Tensor): [HW, D/2]
328
+ """
329
+
330
+ grid = get_meshgrid_nd(
331
+ start, *args, dim=len(rope_dim_list)
332
+ ) # [3, W, H, D] / [2, W, H]
333
+
334
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
335
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
336
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
337
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
338
+ assert len(theta_rescale_factor) == len(
339
+ rope_dim_list
340
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
341
+
342
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
343
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
344
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
345
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
346
+ assert len(interpolation_factor) == len(
347
+ rope_dim_list
348
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
349
+
350
+ # use 1/ndim of dimensions to encode grid_axis
351
+ embs = []
352
+ for i in range(len(rope_dim_list)):
353
+ # emb = get_1d_rotary_pos_embed(
354
+ # rope_dim_list[i],
355
+ # grid[i].reshape(-1),
356
+ # theta,
357
+ # use_real=use_real,
358
+ # theta_rescale_factor=theta_rescale_factor[i],
359
+ # interpolation_factor=interpolation_factor[i],
360
+ # ) # 2 x [WHD, rope_dim_list[i]]
361
+
362
+
363
+ # === RIFLEx modification start ===
364
+ # apply RIFLEx for time dimension
365
+ if i == 0 and enable_riflex:
366
+ emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
367
+ # === RIFLEx modification end ===
368
+ else:
369
+ emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i],interpolation_factor=interpolation_factor[i],)
370
+ embs.append(emb)
371
+
372
+ if use_real:
373
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
374
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
375
+ return cos, sin
376
+ else:
377
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
378
+ return emb
379
+
380
+
381
+ def get_1d_rotary_pos_embed(
382
+ dim: int,
383
+ pos: Union[torch.FloatTensor, int],
384
+ theta: float = 10000.0,
385
+ use_real: bool = False,
386
+ theta_rescale_factor: float = 1.0,
387
+ interpolation_factor: float = 1.0,
388
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
389
+ """
390
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
391
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
392
+
393
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
394
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
395
+ The returned tensor contains complex values in complex64 data type.
396
+
397
+ Args:
398
+ dim (int): Dimension of the frequency tensor.
399
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
400
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
401
+ use_real (bool, optional): If True, return real part and imaginary part separately.
402
+ Otherwise, return complex numbers.
403
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
404
+
405
+ Returns:
406
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
407
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
408
+ """
409
+ if isinstance(pos, int):
410
+ pos = torch.arange(pos).float()
411
+
412
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
413
+ # has some connection to NTK literature
414
+ if theta_rescale_factor != 1.0:
415
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
416
+
417
+ freqs = 1.0 / (
418
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
419
+ ) # [D/2]
420
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
421
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
422
+ if use_real:
423
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
424
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
425
+ return freqs_cos, freqs_sin
426
+ else:
427
+ freqs_cis = torch.polar(
428
+ torch.ones_like(freqs), freqs
429
+ ) # complex64 # [S, D/2]
430
+ return freqs_cis
431
+
432
+ def get_rotary_pos_embed(video_length, height, width, enable_RIFLEx = False):
433
+ target_ndim = 3
434
+ ndim = 5 - 2
435
+
436
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
437
+ patch_size = [1, 2, 2]
438
+ if isinstance(patch_size, int):
439
+ assert all(s % patch_size == 0 for s in latents_size), (
440
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
441
+ f"but got {latents_size}."
442
+ )
443
+ rope_sizes = [s // patch_size for s in latents_size]
444
+ elif isinstance(patch_size, list):
445
+ assert all(
446
+ s % patch_size[idx] == 0
447
+ for idx, s in enumerate(latents_size)
448
+ ), (
449
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), "
450
+ f"but got {latents_size}."
451
+ )
452
+ rope_sizes = [
453
+ s // patch_size[idx] for idx, s in enumerate(latents_size)
454
+ ]
455
+
456
+ if len(rope_sizes) != target_ndim:
457
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
458
+ head_dim = 128
459
+ rope_dim_list = [44, 42, 42]
460
+ if rope_dim_list is None:
461
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
462
+ assert (
463
+ sum(rope_dim_list) == head_dim
464
+ ), "sum(rope_dim_list) should equal to head_dim of attention layer"
465
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
466
+ rope_dim_list,
467
+ rope_sizes,
468
+ theta=10000,
469
+ use_real=True,
470
+ theta_rescale_factor=1,
471
+ L_test = (video_length - 1) // 4 + 1,
472
+ enable_riflex = enable_RIFLEx
473
+ )
474
+ return (freqs_cos, freqs_sin)
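For orientation, a hedged usage sketch of the helper both pipelines now call. The 81-frame, 480x832 settings below are only illustrative; with the 8x VAE stride and the [1, 2, 2] patch size hard-coded above, the returned tables have one row per latent token and 128 columns split 44/42/42 across time, height and width:

```python
from wan.modules.posemb_layers import get_rotary_pos_embed

# Illustrative generation settings, not repository defaults.
frame_num, height, width = 81, 480, 832
freqs_cos, freqs_sin = get_rotary_pos_embed(frame_num, height, width, enable_RIFLEx=False)
# ((81 - 1) // 4 + 1) * (480 // 16) * (832 // 16) rows, 128 columns each.
print(freqs_cos.shape, freqs_sin.shape)
```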
wan/text2video.py CHANGED
@@ -8,7 +8,7 @@ import sys
8
  import types
9
  from contextlib import contextmanager
10
  from functools import partial
11
-
12
  import torch
13
  import torch.cuda.amp as amp
14
  import torch.distributed as dist
@@ -21,6 +21,7 @@ from .modules.vae import WanVAE
21
  from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
22
  get_sampling_sigmas, retrieve_timesteps)
23
  from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 
24
 
25
 
26
  class WanT2V:
@@ -236,13 +237,7 @@ class WanT2V:
236
  # sample videos
237
  latents = noise
238
 
239
- # from .modules.model import identify_k
240
- # for nf in range(20, 50):
241
- # k, N_k = identify_k(10000, 44, 26)
242
- # print(f"value nb latent frames={nf}, k={k}, n_k={N_k}")
243
-
244
- freqs = self.model.get_rope_freqs(nb_latent_frames = int((frame_num - 1)/4 + 1), RIFLEx_k = 6 if enable_RIFLEx else None )
245
-
246
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
247
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
248
 
@@ -252,7 +247,7 @@ class WanT2V:
252
  for i, t in enumerate(tqdm(timesteps)):
253
  latent_model_input = latents
254
  timestep = [t]
255
-
256
  timestep = torch.stack(timestep)
257
 
258
  # self.model.to(self.device)
 
8
  import types
9
  from contextlib import contextmanager
10
  from functools import partial
11
+ from mmgp import offload
12
  import torch
13
  import torch.cuda.amp as amp
14
  import torch.distributed as dist
 
21
  from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
22
  get_sampling_sigmas, retrieve_timesteps)
23
  from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
24
+ from wan.modules.posemb_layers import get_rotary_pos_embed
25
 
26
 
27
  class WanT2V:
 
237
  # sample videos
238
  latents = noise
239
 
240
+ freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
241
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
242
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
243
 
 
247
  for i, t in enumerate(tqdm(timesteps)):
248
  latent_model_input = latents
249
  timestep = [t]
250
+ offload.set_step_no_for_lora(i)
251
  timestep = torch.stack(timestep)
252
 
253
  # self.model.to(self.device)
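The remaining addition in both denoising loops is offload.set_step_no_for_lora(i) from DeepBeepMeep's mmgp package, which exposes the current diffusion step index to the offload layer, presumably so that any step-dependent LoRA handling can key off it. Only that call is taken from the diff; the loop below is an illustrative stub:

```python
from mmgp import offload

timesteps = range(30)  # placeholder for the scheduler's timesteps
for i, t in enumerate(timesteps):
    offload.set_step_no_for_lora(i)  # tell mmgp which denoising step is running
    # ... run the denoising model for timestep t here ...
```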