1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
| import torch import torch.nn as nn import torch.nn.functional as F import numpy as np
class ResBlock(nn.Module):
    """Basic residual block built from two (1, 3) convolutions.

    Main branch: Conv(1x3, ``stride``) -> BN -> ReLU -> Conv(1x3) -> BN.
    Shortcut: identity (an empty ``nn.Sequential``) when ``stride == 1`` and
    the channel count is unchanged; otherwise a strided 1x1 convolution
    followed by BN so that both branches have matching shapes.
    The two branches are summed and passed through a ReLU.

    Args:
        inchannel: Number of input channels.
        outchannel: Number of output channels.
        stride: Stride of the first main-branch conv and of the projection
            shortcut.
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()
        # padding=(0, 1) pads only the last (width) axis, so with stride 1 the
        # width is preserved by the (1, 3) kernels.
        self.left = nn.Sequential(
            nn.Conv2d(
                inchannel,
                outchannel,
                kernel_size=(1, 3),
                stride=stride,
                padding=(0, 1),
                bias=False,
            ),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                outchannel,
                outchannel,
                kernel_size=(1, 3),
                stride=1,
                padding=(0, 1),
                bias=False,
            ),
            nn.BatchNorm2d(outchannel),
        )

        # A projection is needed whenever the main branch changes the
        # spatial size (stride) or the number of channels.
        needs_projection = stride != 1 or inchannel != outchannel
        if needs_projection:
            projection = nn.Sequential(
                nn.Conv2d(
                    inchannel, outchannel, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(outchannel),
            )
        else:
            # An empty Sequential acts as the identity mapping.
            projection = nn.Sequential()
        self.shortcut = projection

    def forward(self, x):
        """Return ``relu(left(x) + shortcut(x))``."""
        return F.relu(self.left(x) + self.shortcut(x))
class ResNet(nn.Module):
    """ResNet classifier with (1, 3) kernels and four stages of two blocks each.

    Input must have a single channel (``conv1`` is ``Conv2d(1, 64, ...)``);
    presumably the input is a 1-D signal laid out as ``(N, 1, 1, L)`` (TODO:
    confirm against callers). Stages 2-4 downsample with stride 2, then a
    (1, 4) average pool is applied before the linear classifier.

    With the default ``fc_in_features=15872`` (= 512 channels * 31 positions)
    the input must reduce to 31 pooled positions, e.g. ``(N, 1, 1, 1000)``.
    Pass a different ``fc_in_features`` for other input lengths.

    Args:
        ResBlock: Block factory called as ``ResBlock(in_channels,
            out_channels, stride)``. The name shadows the module-level class
            but is kept for backward compatibility with keyword callers.
        num_classes: Number of output logits.
        fc_in_features: Size of the flattened feature vector fed to ``fc``.

    Returns (from ``forward``):
        Logits of shape ``(N, 1, 1, num_classes)``.
    """

    def __init__(self, ResBlock, num_classes=1000, fc_in_features=15872):
        super().__init__()
        # Running channel count; updated by make_layer as stages are built.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(1, 3), stride=1, padding=(0, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)
        self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)
        self.fc = nn.Linear(fc_in_features, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the rest use stride 1.
        Side effect: sets ``self.inchannel`` to ``channels``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            # Subsequent blocks (and the next stage) consume this stage's output.
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape ``(N, 1, 1, num_classes)``."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, kernel_size=(1, 4))
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        # Use the classifier's actual width; a hard-coded 1000 here made every
        # num_classes != 1000 fail at reshape time.
        return out.reshape(out.size(0), 1, 1, self.fc.out_features)
|