# Copyright 2020 TU Darmstadt
# License: Apache 2.0.
# https://github.com/visinf/1-stage-wseg/blob/master/models/mods/pamr.py

import torch
import torch.nn as nn
import torch.nn.functional as F

#
# Helper modules
#

class LocalAffinity(nn.Module):
    # computes, for every dilation, the difference between each pixel
    # and each of its 8 neighbours (one neighbour per output channel)

    def __init__(self, dilations=[1]):
        super(LocalAffinity, self).__init__()
        self.dilations = dilations
        weight = self._init_aff()
        self.register_buffer('kernel', weight)

    def _init_aff(self):
        # initialising the shift kernel: each of the 8 filters reads
        # the centre pixel (+1) and subtracts one neighbour (-1)
        weight = torch.zeros(8, 1, 3, 3)

        for i in range(weight.size(0)):
            weight[i, 0, 1, 1] = 1

        weight[0, 0, 0, 0] = -1
        weight[1, 0, 0, 1] = -1
        weight[2, 0, 0, 2] = -1

        weight[3, 0, 1, 0] = -1
        weight[4, 0, 1, 2] = -1

        weight[5, 0, 2, 0] = -1
        weight[6, 0, 2, 1] = -1
        weight[7, 0, 2, 2] = -1

        self.weight_check = weight.clone()

        return weight

    def forward(self, x):
        # sanity check that the registered buffer was not modified
        self.weight_check = self.weight_check.type_as(x)
        assert torch.all(self.weight_check.eq(self.kernel))

        B, K, H, W = x.size()
        x = x.view(B * K, 1, H, W)

        x_affs = []
        for d in self.dilations:
            x_pad = F.pad(x, [d] * 4, mode='replicate')
            x_aff = F.conv2d(x_pad, self.kernel, dilation=d)
            x_affs.append(x_aff)

        # one map per neighbour and dilation: [B, K, 8 * len(dilations), H, W]
        x_aff = torch.cat(x_affs, 1)
        return x_aff.view(B, K, -1, H, W)
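
# Illustrative sanity check (an addition for clarity, not part of the
# original file): on a constant input every centre-to-neighbour difference
# is zero, and the output carries one channel per neighbour and dilation.
#
# >>> aff = LocalAffinity(dilations=[1])
# >>> out = aff(torch.full((1, 1, 5, 5), 3.0))
# >>> out.shape  # (B, K, 8 * len(dilations), H, W)
# torch.Size([1, 1, 8, 5, 5])
# >>> bool(torch.allclose(out, torch.zeros_like(out)))
# True
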
class LocalAffinityCopy(LocalAffinity):
    # same shift pattern, but the kernel copies each neighbour
    # instead of taking its difference to the centre pixel

    def _init_aff(self):
        # initialising the shift kernel
        weight = torch.zeros(8, 1, 3, 3)

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 2] = 1

        weight[5, 0, 2, 0] = 1
        weight[6, 0, 2, 1] = 1
        weight[7, 0, 2, 2] = 1

        self.weight_check = weight.clone()

        return weight

class LocalStDev(LocalAffinity):
    # local standard deviation over the full 3x3 window
    # (all 9 positions, including the centre pixel)

    def _init_aff(self):
        weight = torch.zeros(9, 1, 3, 3)

        weight[0, 0, 0, 0] = 1
        weight[1, 0, 0, 1] = 1
        weight[2, 0, 0, 2] = 1

        weight[3, 0, 1, 0] = 1
        weight[4, 0, 1, 1] = 1
        weight[5, 0, 1, 2] = 1

        weight[6, 0, 2, 0] = 1
        weight[7, 0, 2, 1] = 1
        weight[8, 0, 2, 2] = 1

        self.weight_check = weight.clone()

        return weight

    def forward(self, x):
        # super().forward returns (B, K, P, H, W), where P is the number
        # of sampled locations; reduce P to its standard deviation
        x = super(LocalStDev, self).forward(x)
        return x.std(2, keepdim=True)

class LocalAffinityAbs(LocalAffinity):
    # absolute value of the centre-to-neighbour differences

    def forward(self, x):
        x = super(LocalAffinityAbs, self).forward(x)
        return torch.abs(x)

#
# PAMR module
#

class PAMR(nn.Module):

    def __init__(self, num_iter=1, dilations=[1]):
        super(PAMR, self).__init__()

        self.num_iter = num_iter
        self.aff_x = LocalAffinityAbs(dilations)
        self.aff_m = LocalAffinityCopy(dilations)
        self.aff_std = LocalStDev(dilations)

    def forward(self, x, mask):
        # x: [B, K, H, W] image features
        # mask: [B, C, h, w] coarse masks, upsampled here to the feature size
        mask = F.interpolate(mask, size=x.size()[-2:], mode="bilinear", align_corners=True)

        B, K, H, W = x.size()
        _, C, _, _ = mask.size()

        # affinity weights: negated absolute feature differences, scaled by
        # the local standard deviation, averaged over feature channels, and
        # normalised with a softmax over the neighbourhood dimension
        x_std = self.aff_std(x)
        x = -self.aff_x(x) / (1e-8 + 0.1 * x_std)
        x = x.mean(1, keepdim=True)
        x = F.softmax(x, 2)

        # iteratively propagate mask scores to similar-looking neighbours
        for _ in range(self.num_iter):
            m = self.aff_m(mask)  # [B, C, P, H, W]
            mask = (m * x).sum(2)

        # mask: [B, C, H, W]
        return mask
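
# Minimal usage sketch (an addition, not from the original file): the
# shapes, dilations, and iteration count below are illustrative; the
# 1-stage-wseg training code configures PAMR with several dilations and
# around ten refinement iterations.
if __name__ == "__main__":
    # B=2 images with K=3 feature channels, and coarse scores for
    # C=21 classes at half the feature resolution
    x = torch.rand(2, 3, 64, 64)
    mask = torch.softmax(torch.rand(2, 21, 32, 32), dim=1)

    pamr = PAMR(num_iter=10, dilations=[1, 2, 4, 8, 12, 24])
    with torch.no_grad():
        refined = pamr(x, mask)

    # the mask is upsampled to the feature resolution before refinement
    print(refined.shape)  # torch.Size([2, 21, 64, 64])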