1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
| import numpy as np import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F
class ConvBNR(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU(inplace) block.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        kernel_size: Convolution kernel size (int or pair of ints).
        stride: Convolution stride.
        dilation: Convolution dilation (int or pair of ints).
        bias: Whether the convolution has a learnable bias.

    The padding is derived from both ``kernel_size`` and ``dilation`` so
    that, for odd kernel sizes and ``stride=1``, the spatial size of the
    input is preserved. For the default 3x3 kernel this is exactly
    ``padding == dilation``.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1,
                 dilation=1, bias=False):
        super().__init__()

        # Normalise scalars to (h, w) pairs so that padding can be computed
        # per axis; nn.Conv2d accepts either form.
        if isinstance(kernel_size, (tuple, list)):
            kernel_hw = tuple(kernel_size)
        else:
            kernel_hw = (kernel_size, kernel_size)
        if isinstance(dilation, (tuple, list)):
            dilation_hw = tuple(dilation)
        else:
            dilation_hw = (dilation, dilation)

        # Effective kernel extent is dilation * (k - 1) + 1; half of the
        # overhang on each side keeps the output aligned with the input.
        padding = tuple(
            d * (k - 1) // 2 for k, d in zip(kernel_hw, dilation_hw)
        )

        self.block = nn.Sequential(
            nn.Conv2d(
                inplanes,
                planes,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                bias=bias,
            ),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv/BN/ReLU stack to ``x`` and return the result."""
        return self.block(x)
class Conv1x1(nn.Module):
    """Pointwise (1x1) convolution followed by BatchNorm2d and in-place ReLU.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
    """

    def __init__(self, inplanes, planes):
        super().__init__()
        # Sub-modules are registered in conv -> bn -> relu order.
        self.conv = nn.Conv2d(inplanes, planes, 1)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
class EAM(nn.Module):
    """Fuse a 2048-channel and a 256-channel feature map into a 1-channel map.

    ``x4`` is reduced to 256 channels and ``x1`` to 64 channels with 1x1
    conv blocks. The reduced ``x4`` is bilinearly resized to the spatial
    size of ``x1``, and the two are concatenated along the channel axis.
    The result passes through two 3x3 ConvBNR blocks and a final 1x1
    convolution that outputs a single channel.
    """

    def __init__(self):
        super().__init__()
        self.reduce1 = Conv1x1(256, 64)
        self.reduce4 = Conv1x1(2048, 256)
        fusion_layers = [
            ConvBNR(256 + 64, 256, 3),
            ConvBNR(256, 256, 3),
            nn.Conv2d(256, 1, 1),
        ]
        self.block = nn.Sequential(*fusion_layers)

    def forward(self, x4, x1):
        """Return the fused single-channel map at the resolution of ``x1``.

        Args:
            x4: Feature map with 2048 channels.
            x1: Feature map with 256 channels; its height and width set the
                output resolution.
        """
        target_hw = x1.shape[2:]
        low = self.reduce1(x1)
        high = F.interpolate(
            self.reduce4(x4),
            size=target_hw,
            mode='bilinear',
            align_corners=False,
        )
        # Channel order in the concatenation is (reduced x4, reduced x1).
        return self.block(torch.cat((high, low), dim=1))
|