Spaces:
Runtime error
Runtime error
Commit
Β·
b6a5660
1
Parent(s):
d8df719
Add tensor_to_mp4
Browse files- lvdm/utils/saving_utils.py +18 -0
lvdm/utils/saving_utils.py
CHANGED
|
@@ -15,6 +15,24 @@ from torch import Tensor
|
|
| 15 |
from torchvision.transforms.functional import to_tensor
|
| 16 |
|
| 17 |
# ----------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def savenp2sheet(imgs, savepath, nrow=None):
|
| 19 |
""" save multiple imgs (in numpy array type) to a img sheet.
|
| 20 |
img sheet is one row.
|
|
|
|
| 15 |
from torchvision.transforms.functional import to_tensor
|
| 16 |
|
| 17 |
# ----------------------------------------------------------------------------------------------
|
| 18 |
+
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
|
| 19 |
+
"""
|
| 20 |
+
video: torch.Tensor, b,c,t,h,w, 0-1
|
| 21 |
+
if -1~1, enable rescale=True
|
| 22 |
+
"""
|
| 23 |
+
n = video.shape[0]
|
| 24 |
+
video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
|
| 25 |
+
nrow = int(np.sqrt(n)) if nrow is None else nrow
|
| 26 |
+
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video] # [3, grid_h, grid_w]
|
| 27 |
+
grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [T, 3, grid_h, grid_w]
|
| 28 |
+
grid = torch.clamp(grid.float(), -1., 1.)
|
| 29 |
+
if rescale:
|
| 30 |
+
grid = (grid + 1.0) / 2.0
|
| 31 |
+
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
|
| 32 |
+
#print(f'Save video to {savepath}')
|
| 33 |
+
torchvision.io.write_video(savepath, grid, fps=fps, video_codec='h264', options={'crf': '10'})
|
| 34 |
+
|
| 35 |
+
# ----------------------------------------------------------------------------------------------
|
| 36 |
def savenp2sheet(imgs, savepath, nrow=None):
|
| 37 |
""" save multiple imgs (in numpy array type) to a img sheet.
|
| 38 |
img sheet is one row.
|