from typing import Dict, List, Any
from datetime import datetime
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

DEFAULT_MEAN = IMAGENET_DEFAULT_MEAN
DEFAULT_STD = IMAGENET_DEFAULT_STD
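
# These are the standard torchvision ImageNet channel statistics. A minimal
# sketch of the normalization these constants assume (so that `unnorm` below
# inverts it); the torchvision usage here is illustrative, not part of this
# module:
#   from torchvision import transforms
#   normalize_tf = transforms.Normalize(mean=DEFAULT_MEAN, std=DEFAULT_STD)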


def unnorm(x):
    """Invert ImageNet normalization channel-wise: x * std + mean.

    Expects `x` of shape (B, C, H, W), normalized with DEFAULT_MEAN/DEFAULT_STD.
    """
    # Match x's dtype so fp16 inputs are not silently promoted.
    mean = torch.as_tensor(DEFAULT_MEAN, dtype=x.dtype, device=x.device)[None, ..., None, None]
    std = torch.as_tensor(DEFAULT_STD, dtype=x.dtype, device=x.device)[None, ..., None, None]
    return x.mul(std).add(mean)
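
# Round-trip sketch (hypothetical tensors, not part of the module):
#   imgs = torch.rand(4, 3, 224, 224)
#   mean = torch.as_tensor(DEFAULT_MEAN)[None, :, None, None]
#   std = torch.as_tensor(DEFAULT_STD)[None, :, None, None]
#   assert torch.allclose(unnorm((imgs - mean) / std), imgs, atol=1e-6)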


def check_nonfinite(x, name=""):
    """Report NaN/Inf counts in `x`, prefixed with the distributed rank.

    Returns True if any non-finite value is found.
    """
    # Fall back to rank 0 when torch.distributed is not initialized,
    # so the helper also works in single-process runs.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    n_nan = x.isnan().sum().item()
    n_inf = x.isinf().sum().item()
    if n_nan or n_inf:
        print(f"[RANK {rank}] {name} is not finite: #nan={n_nan}, #inf={n_inf}")
        return True

    print(f"[RANK {rank}] {name} is OK ...")
    return False
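
# Debugging sketch (hypothetical loss/batch variables):
#   if check_nonfinite(loss, "loss"):
#       torch.save(batch, "bad_batch.pt")  # dump inputs for post-mortem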


def normalize(t, dim, eps=1e-6):
    """L2-normalize with a large default eps for fp16.

    F.normalize defaults to eps=1e-12, which underflows to zero in fp16
    (smallest positive normal is ~6.1e-5) and therefore gives no protection
    against division by zero.
    """
    return F.normalize(t, dim=dim, eps=eps)


def timestamp(fmt="%y%m%d-%H%M%S"):
    """Current local time as a compact string, e.g. '240131-235959'."""
    return datetime.now().strftime(fmt)


def merge_dicts_by_key(dics: List[Dict]) -> Dict[Any, List]:
    """Merge dictionaries by key. All dicts must share the same key set."""
    ret = {key: [] for key in dics[0].keys()}
    for dic in dics:
        for key, value in dic.items():
            ret[key].append(value)

    return ret
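
# Example (hypothetical per-step logs):
#   merge_dicts_by_key([{"loss": 0.5, "acc": 0.8}, {"loss": 0.4, "acc": 0.9}])
#   # -> {"loss": [0.5, 0.4], "acc": [0.8, 0.9]}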


def flatten_2d_list(list2d):
    """Flatten one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return list(chain.from_iterable(list2d))


def num_params(module):
    """Total number of parameters in `module` (trainable or not)."""
    return sum(p.numel() for p in module.parameters())


def param_trace(name, module, depth=0, max_depth=999, threshold=0, printf=print):
    """Recursively print the parameter count of each submodule.

    Counts are reported in units of 2**20 parameters (the "M" column);
    submodules with at most `threshold` parameters are skipped.
    """
    if depth > max_depth:
        return
    prefix = "  " * depth
    n_params = num_params(module)
    if n_params > threshold:
        printf("{:60s}\t{:10.3f}M".format(prefix + name, n_params / 1024 / 1024))
    for n, m in module.named_children():
        if depth == 0:
            child_name = n
        else:
            child_name = "{}.{}".format(name, n)
        param_trace(child_name, m, depth + 1, max_depth, threshold, printf)
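
# Usage sketch (hypothetical model):
#   model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 10))
#   param_trace("model", model)  # prints the whole module tree
#   param_trace("model", model, threshold=100_000)  # only large submodules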


@torch.no_grad()
def hash_bn(module):
    """Fingerprint BatchNorm layers: means of affine params and running stats.

    Returns (p, s), where p summarizes (weight, bias) and s summarizes
    (running_mean, running_var). Note that nn.SyncBatchNorm is not matched
    by the isinstance check below.
    """
    summary = []
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            w = m.weight.detach().mean().item()
            b = m.bias.detach().mean().item()
            rm = m.running_mean.detach().mean().item()
            rv = m.running_var.detach().mean().item()
            summary.append((w, b, rm, rv))

    if not summary:
        return 0.0, 0.0

    w, b, rm, rv = [np.mean(col) for col in zip(*summary)]
    p = np.mean([w, b])
    s = np.mean([rm, rv])

    return p, s


@torch.no_grad()
def hash_params(module):
    """Scalar fingerprint of the parameters: mean of per-parameter means."""
    # torch.stack avoids the slow (and, on GPU, warning-prone) path of
    # building a tensor from a Python list of 0-dim tensors.
    return torch.stack([p.detach().float().mean() for p in module.parameters()]).mean().item()


@torch.no_grad()
def hashm(module):
    """Fingerprint a module by its parameters (p) and BN running stats (s)."""
    p = hash_params(module)
    _, s = hash_bn(module)

    return p, s
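
# DDP sanity-check sketch (hypothetical; assumes an initialized process group):
#   p, s = hashm(model)
#   print(f"[RANK {dist.get_rank()}] params={p:.6f} bn_stats={s:.6f}")
#   # identical (p, s) across ranks suggests weights and buffers are in sync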