论文资料
U-Net: Convolutional Networks for Biomedical Image Segmentation
相关架构
编码器由常规的卷积层和池化层组成;解码器先将上一层的特征图上采样,再与U形结构对面编码器中对应层的特征图(论文中先经裁剪)在通道维度上拼接,然后送入当前解码器层的卷积。
代码实现
unet keras:该仓库使用 Keras 实现 U-Net。由于训练数据过少,仓库采用了数据增强方法来扩充数据集。下方代码为 PyTorch 实现,注意其中 `UNetDec` 实际是下采样(编码)模块,`UNetEnc` 实际是上采样(解码)模块。
class UNetEnc(nn.Module):
    """Up-sampling stage of the U-Net expansive path.

    Despite the ``Enc`` in its name, this module *enlarges* its input: two
    unpadded 3x3 convolutions (each followed by ReLU) and then a 2x2,
    stride-2 transposed convolution (also followed by ReLU) that doubles the
    spatial size.

    Args:
        in_channels: Number of channels of the input tensor.
        features: Width (output channels) of both 3x3 convolutions.
        out_channels: Number of channels produced by the transposed conv.
    """

    def __init__(self, in_channels, features, out_channels):
        super().__init__()
        stages = []
        # conv(in -> features) then conv(features -> features), each + ReLU.
        for c_in in (in_channels, features):
            stages.append(nn.Conv2d(c_in, features, 3))
            stages.append(nn.ReLU(inplace=True))
        stages.extend([
            nn.ConvTranspose2d(features, out_channels, 2, stride=2),
            nn.ReLU(inplace=True),
        ])
        # Kept as a single ``up`` Sequential so state_dict keys stay
        # ``up.0``, ``up.2``, ``up.4``.
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/conv/transposed-conv stack to ``x``."""
        upsampled = self.up(x)
        return upsampled
class UNetDec(nn.Module):
    """Down-sampling stage of the U-Net contracting path.

    Despite the ``Dec`` in its name, this module *shrinks* its input: two
    unpadded 3x3 convolutions (each followed by ReLU), an optional
    ``Dropout(p=0.5)``, and a 2x2, stride-2 max-pool with ``ceil_mode=True``
    (so odd spatial sizes round up instead of dropping the last row/column).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Width (output channels) of both 3x3 convolutions.
        dropout: If true, insert ``Dropout(0.5)`` before the pooling layer.
    """

    def __init__(self, in_channels, out_channels, dropout=False):
        super().__init__()
        convs = []
        for c_in in (in_channels, out_channels):
            convs += [nn.Conv2d(c_in, out_channels, 3), nn.ReLU(inplace=True)]
        regularize = [nn.Dropout(.5)] if dropout else []
        pool = [nn.MaxPool2d(2, stride=2, ceil_mode=True)]
        # Single ``down`` Sequential keeps state_dict keys ``down.0``/``down.2``.
        self.down = nn.Sequential(*convs, *regularize, *pool)

    def forward(self, x):
        """Apply the conv/conv/(dropout)/pool stack to ``x``."""
        pooled = self.down(x)
        return pooled
class UNet(nn.Module):
    """U-Net style segmentation network (Ronneberger et al., 2015).

    Naming note: in this file the ``UNetDec`` modules are the *down*-sampling
    (contracting) stages and the ``UNetEnc`` modules are the *up*-sampling
    (expansive) stages. The ``dec*`` / ``enc*`` attribute names follow that
    convention and are kept unchanged so existing state_dict keys still load.

    Every 3x3 convolution is unpadded, so each stage shrinks the feature
    map; the input must be large enough that no convolution ever sees a
    spatial size smaller than its kernel.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        in_channels: Number of channels of the input image. Defaults to 3,
            the previously hard-coded value; pass e.g. 1 for grayscale input.
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        # Contracting path (channels 64 -> 512); dec4 also applies dropout.
        self.dec1 = UNetDec(in_channels, 64)
        self.dec2 = UNetDec(64, 128)
        self.dec3 = UNetDec(128, 256)
        self.dec4 = UNetDec(256, 512, dropout=True)
        # Bottleneck: widen to 1024 channels, then transposed-conv back to
        # 512 channels at twice the spatial size.
        self.center = nn.Sequential(
            nn.Conv2d(512, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.ConvTranspose2d(1024, 512, 2, stride=2),
            nn.ReLU(inplace=True),
        )
        # Expansive path: each stage receives [upsampled, skip] concatenated
        # along channels, hence input width = 2x the previous output width.
        self.enc4 = UNetEnc(1024, 512, 256)
        self.enc3 = UNetEnc(512, 256, 128)
        self.enc2 = UNetEnc(256, 128, 64)
        self.enc1 = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(inplace=True),
        )
        # 1x1 conv mapping 64 features to per-pixel class scores.
        self.final = nn.Conv2d(64, num_classes, 1)

    @staticmethod
    def _upsample(t, size):
        """Bilinearly resize the spatial dims of ``t`` to ``size``.

        Equivalent to the deprecated ``F.upsample_bilinear`` (which is
        ``F.interpolate(..., mode="bilinear", align_corners=True)``) but
        without its deprecation warning on every call.
        """
        return F.interpolate(t, size=size, mode="bilinear", align_corners=True)

    def forward(self, x):
        """Return per-pixel class scores resized to the spatial size of ``x``.

        Args:
            x: Batched image tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, num_classes, H, W)``.
        """
        dec1 = self.dec1(x)
        dec2 = self.dec2(dec1)
        dec3 = self.dec3(dec2)
        dec4 = self.dec4(dec3)
        center = self.center(dec4)
        # Skip connections use the *pooled* output of each down stage,
        # bilinearly resized to the current up-path resolution.
        # NOTE: the paper instead concatenates the cropped, pre-pooling
        # feature map; this implementation deliberately differs.
        enc4 = self.enc4(torch.cat(
            [center, self._upsample(dec4, center.size()[2:])], 1))
        enc3 = self.enc3(torch.cat(
            [enc4, self._upsample(dec3, enc4.size()[2:])], 1))
        enc2 = self.enc2(torch.cat(
            [enc3, self._upsample(dec2, enc3.size()[2:])], 1))
        enc1 = self.enc1(torch.cat(
            [enc2, self._upsample(dec1, enc2.size()[2:])], 1))
        # Unpadded convs leave the map smaller than the input; resize back.
        return self._upsample(self.final(enc1), x.size()[2:])