Talk2DINO-ViTB / pamr.py
lorebianchi98's picture
Fixed error
d120439
# Copyright 2020 TU Darmstadt
# Licnese: Apache 2.0 License.
# https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py
import torch
import torch.nn.functional as F
import torch.nn as nn
from functools import partial
#
# Helper modules
#
class LocalAffinity(nn.Module):
def __init__(self, dilations=[1]):
super(LocalAffinity, self).__init__()
self.dilations = dilations
weight = self._init_aff()
self.register_buffer('kernel', weight)
def _init_aff(self):
# initialising the shift kernel
weight = torch.zeros(8, 1, 3, 3)
for i in range(weight.size(0)):
weight[i, 0, 1, 1] = 1
weight[0, 0, 0, 0] = -1
weight[1, 0, 0, 1] = -1
weight[2, 0, 0, 2] = -1
weight[3, 0, 1, 0] = -1
weight[4, 0, 1, 2] = -1
weight[5, 0, 2, 0] = -1
weight[6, 0, 2, 1] = -1
weight[7, 0, 2, 2] = -1
self.weight_check = weight.clone()
return weight
def forward(self, x):
self.weight_check = self.weight_check.type_as(x)
assert torch.all(self.weight_check.eq(self.kernel))
B,K,H,W = x.size()
x = x.view(B*K,1,H,W)
x_affs = []
for d in self.dilations:
x_pad = F.pad(x, [d]*4, mode='replicate')
x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
x_affs.append(x_aff)
x_aff = torch.cat(x_affs, 1)
return x_aff.view(B,K,-1,H,W)
class LocalAffinityCopy(LocalAffinity):
def _init_aff(self):
# initialising the shift kernel
weight = torch.zeros(8, 1, 3, 3)
weight[0, 0, 0, 0] = 1
weight[1, 0, 0, 1] = 1
weight[2, 0, 0, 2] = 1
weight[3, 0, 1, 0] = 1
weight[4, 0, 1, 2] = 1
weight[5, 0, 2, 0] = 1
weight[6, 0, 2, 1] = 1
weight[7, 0, 2, 2] = 1
self.weight_check = weight.clone()
return weight
class LocalStDev(LocalAffinity):
def _init_aff(self):
weight = torch.zeros(9, 1, 3, 3)
weight.zero_()
weight[0, 0, 0, 0] = 1
weight[1, 0, 0, 1] = 1
weight[2, 0, 0, 2] = 1
weight[3, 0, 1, 0] = 1
weight[4, 0, 1, 1] = 1
weight[5, 0, 1, 2] = 1
weight[6, 0, 2, 0] = 1
weight[7, 0, 2, 1] = 1
weight[8, 0, 2, 2] = 1
self.weight_check = weight.clone()
return weight
def forward(self, x):
# returns (B,K,P,H,W), where P is the number
# of locations
x = super(LocalStDev, self).forward(x)
return x.std(2, keepdim=True)
class LocalAffinityAbs(LocalAffinity):
def forward(self, x):
x = super(LocalAffinityAbs, self).forward(x)
return torch.abs(x)
#
# PAMR module
#
class PAMR(nn.Module):
def __init__(self, num_iter=1, dilations=[1]):
super(PAMR, self).__init__()
self.num_iter = num_iter
self.aff_x = LocalAffinityAbs(dilations)
self.aff_m = LocalAffinityCopy(dilations)
self.aff_std = LocalStDev(dilations)
def forward(self, x, mask):
mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)
# x: [BxKxHxW]
# mask: [BxCxHxW]
B,K,H,W = x.size()
_,C,_,_ = mask.size()
x_std = self.aff_std(x)
x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
x = x.mean(1, keepdim=True)
x = F.softmax(x, 2)
for _ in range(self.num_iter):
m = self.aff_m(mask) # [BxCxPxHxW]
mask = (m * x).sum(2)
# xvals: [BxCxHxW]
return mask