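"""Building blocks in the style of NAFNet (LayerNorm2d, SimpleGate, NAFBlock),
assembled into a small U-Net with skip connections, plus a log-based
reconstruction loss (hard_log_loss)."""
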
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormFunction(torch.autograd.Function):
    """LayerNorm over the channel dimension of an (N, C, H, W) tensor,
    with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        N, C, H, W = grad_output.size()
        # saved_tensors replaces the deprecated saved_variables attribute.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # Gradients w.r.t. x, weight and bias; None for the non-tensor eps argument.
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None


class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)

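# Optional sanity check (illustrative only, not part of the original file): the
# hand-written backward in LayerNormFunction can be compared against autograd
# with torch.autograd.gradcheck on a tiny double-precision input, e.g.
#
#     ln = LayerNorm2d(4).double()
#     x = torch.randn(2, 4, 3, 3, dtype=torch.double, requires_grad=True)
#     torch.autograd.gradcheck(lambda t: ln(t), (x,))

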
class SimpleGate(nn.Module):
    def forward(self, x):
        # Split the channels in half and use one half to gate the other.
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified channel attention: global average pooling followed by a 1x1 conv.
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
        )

        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        # Learnable per-channel scales for the two residual branches.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # Spatial branch: norm -> expand -> depthwise conv -> gate -> channel attention -> project.
        x = inp
        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)
        x = self.dropout1(x)
        y = inp + x * self.beta

        # Feed-forward branch on top of the first residual output.
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)
        x = self.dropout2(x)
        return y + x * self.gamma

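# Illustrative note (not from the original file): NAFBlock is shape-preserving,
# e.g. NAFBlock(c=32)(torch.rand(1, 32, 64, 64)) returns a (1, 32, 64, 64) tensor,
# since every conv uses stride 1 and the gating halves the expanded channels back to c.

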
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, dw_expand=2, ffn_expand=2, dropout=0.):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        self.naf_block = NAFBlock(
            c=out_ch,
            DW_Expand=dw_expand,
            FFN_Expand=ffn_expand,
            drop_out_rate=dropout
        )

    def forward(self, x):
        x = self.conv(x)
        x = self.naf_block(x)
        return x


class Encoder(nn.Module):
    def __init__(self, chs=(3, 64, 128, 256)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        # Collect the feature map of every stage; the decoder consumes the
        # deepest one and uses the rest as skip connections.
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs


class Decoder(nn.Module):
    def __init__(self, chs=(256, 128, 64)):
        super().__init__()
        self.chs = chs
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])

    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            # Upsample, concatenate the matching encoder feature, then refine.
            x = self.upconvs[i](x)
            enc_ftrs = encoder_features[i]
            x = torch.cat([x, enc_ftrs], dim=1)
            x = self.dec_blocks[i](x)
        return x


class UNet(nn.Module):
    def __init__(self, enc_chs=(3, 32, 64, 128), dec_chs=(128, 64, 32), out_ch=4, out_sz=(252, 252)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], out_ch, 1)
        self.out_sz = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)
        # Deepest feature is the decoder input; the remaining features
        # (ordered deep to shallow) serve as skip connections.
        out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out = self.head(out)
        out = F.interpolate(out, self.out_sz)
        out = torch.clamp(out, min=0., max=1.)
        return out

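# Note (derived from the default configuration above, not stated in the original
# file): with two effective 2x poolings in the encoder and two 2x transposed
# convolutions in the decoder, input height and width should be divisible by 4,
# otherwise torch.cat on the skip connections fails with a size mismatch.

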
class hard_log_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # Log-barrier-style penalty on the absolute error; the 1e-6 term avoids log(0).
        loss = (-1 * torch.log(1 - torch.clamp(torch.abs(x - y), 0, 1) + 1e-6)).mean()
        return loss


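# Minimal smoke test (an illustrative sketch, not part of the original module):
# builds the default UNet, runs a random 256x256 RGB batch through it, and
# evaluates hard_log_loss against a random target. The tensor sizes below are
# assumptions chosen so the 2x down/upsampling steps align with the skips.
if __name__ == "__main__":
    model = UNet()
    dummy = torch.rand(1, 3, 256, 256)
    target = torch.rand(1, 4, 252, 252)  # matches out_ch=4 and out_sz=(252, 252)
    pred = model(dummy)
    loss = hard_log_loss()(pred, target)
    print(pred.shape, loss.item())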