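"""Pixel-adaptive mask refinement (PAMR).

Refines a coarse segmentation mask by iteratively averaging it with an
affinity kernel computed from the local appearance of the input image.
"""
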
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocalAffinity(nn.Module):
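    """Difference between each pixel and its eight neighbours.

    Input:  [B, K, H, W]
    Output: [B, K, 8 * len(dilations), H, W]
    """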

    def __init__(self, dilations=(1,)):
        super(LocalAffinity, self).__init__()
        self.dilations = dilations
        weight = self._init_aff()
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        # one 3x3 kernel per neighbour: +1 at the centre and -1 at one of
        # the eight neighbours, so each output channel computes
        # (centre - neighbour)
        weight = torch.zeros(8, 1, 3, 3)

        for i in range(weight.size(0)):
            weight[i, 0, 1, 1] = 1

        weight[0, 0, 0, 0] = -1
        weight[1, 0, 0, 1] = -1
        weight[2, 0, 0, 2] = -1

        weight[3, 0, 1, 0] = -1
        weight[4, 0, 1, 2] = -1

        weight[5, 0, 2, 0] = -1
        weight[6, 0, 2, 1] = -1
        weight[7, 0, 2, 2] = -1

        self.weight_check = weight.clone()

        return weight

    def forward(self, x):
        # guard against the kernel buffer having been modified in place
        self.weight_check = self.weight_check.type_as(x)
        assert torch.all(self.weight_check.eq(self.kernel))

        B, K, H, W = x.size()
        x = x.view(B * K, 1, H, W)

        # neighbour differences at every dilation; replication padding keeps
        # border pixels with valid neighbours
        x_affs = []
        for d in self.dilations:
            x_pad = F.pad(x, [d] * 4, mode='replicate')
            x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
            x_affs.append(x_aff)

        # [B, K, 8 * len(dilations), H, W]
        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B, K, -1, H, W)


class LocalAffinityCopy(LocalAffinity):
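    """Copies the raw values of the eight neighbours instead of their differences."""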

    def _init_aff(self):
        # identity kernel at each of the eight neighbour positions: every
        # output channel copies one neighbour's value unchanged
        weight = torch.zeros(8, 1, 3, 3)

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 2] = 1

        weight[5, 0, 2, 0] = 1
        weight[6, 0, 2, 1] = 1
        weight[7, 0, 2, 2] = 1

        self.weight_check = weight.clone()

        return weight


class LocalStDev(LocalAffinity):
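    """Standard deviation over each pixel's full 3x3 neighbourhood."""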

    def _init_aff(self):
        # identity kernel at all nine positions, including the centre
        weight = torch.zeros(9, 1, 3, 3)

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 1] = 1
        weight[5, 0, 1, 2] = 1

        weight[6, 0, 2, 0] = 1
        weight[7, 0, 2, 1] = 1
        weight[8, 0, 2, 2] = 1

        self.weight_check = weight.clone()

        return weight

    def forward(self, x):
        # gather the 3x3 neighbourhood values (per dilation) and reduce
        # them to their standard deviation: [B, K, 1, H, W]
        x = super(LocalStDev, self).forward(x)
        return x.std(2, keepdim=True)


class LocalAffinityAbs(LocalAffinity):
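    """Absolute value of the neighbour differences."""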

    def forward(self, x):
        x = super(LocalAffinityAbs, self).forward(x)
        return torch.abs(x)


class PAMR(nn.Module):
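    """Pixel-adaptive mask refinement.

    Builds an affinity kernel from the image and uses it to iteratively
    re-weight the mask with its neighbours' scores.
    """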

    def __init__(self, num_iter=1, dilations=(1,)):
        super(PAMR, self).__init__()

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)
        self.aff_m = LocalAffinityCopy(dilations)
        self.aff_std = LocalStDev(dilations)

    def forward(self, x, mask):
        # bring the mask up to the image resolution
        mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)

        # affinity kernel from the image: absolute channel differences,
        # normalised by the local standard deviation and averaged over
        # channels; the softmax over the neighbour axis turns small
        # differences into high affinities
        x_std = self.aff_std(x)
        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
        x = x.mean(1, keepdim=True)
        x = F.softmax(x, 2)

        # each iteration replaces a pixel's mask scores with the
        # affinity-weighted average of its neighbours' scores
        for _ in range(self.num_iter):
            m = self.aff_m(mask)
            mask = (m * x).sum(2)

        return mask
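

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: the 3-channel image, mask size,
    # iteration count, and dilation schedule below are assumptions, not
    # values prescribed by this module.
    torch.manual_seed(0)
    image = torch.rand(2, 3, 64, 64)   # [B, 3, H, W] input image
    mask = torch.rand(2, 21, 32, 32)   # [B, C, h, w]; upsampled inside PAMR
    pamr = PAMR(num_iter=10, dilations=(1, 2, 4, 8))
    refined = pamr(image, mask)
    print(refined.shape)  # -> torch.Size([2, 21, 64, 64])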