我目前正在做一个医学图像分割项目,使用U-Net架构在我的数据集上进行训练并执行分割。我想把网络架构可视化出来,为此尝试了很多方法,但始终无法根据我的U-Net代码绘制出架构图。
以下是我的U-Net架构代码:
import torch
from torch.nn import (
    Conv2d,
    ConvTranspose2d,
    MaxPool2d,
    Module,
    ModuleList,
    ReLU,
)
from torch.nn import functional as F
from torchvision.transforms import CenterCrop
class Block(Module):
    """Basic building block: CONV => RELU => CONV.

    Args:
        inChannels: Number of channels expected by the first convolution.
        outChannels: Number of channels produced by both convolutions.
    """

    def __init__(self, inChannels, outChannels):
        super().__init__()
        # Two convolutions with kernel size 3; no padding argument is
        # passed to either of them.
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        """Apply CONV => RELU => CONV to ``x`` and return the result."""
        return self.conv2(self.relu(self.conv1(x)))
class Encoder(Module):
    """Stack of ``Block`` modules, each followed by 2x2 max pooling.

    Args:
        channels: Channel sizes; every consecutive pair
            ``(channels[i], channels[i + 1])`` defines one encoder block.
    """

    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # One Block per consecutive (in, out) channel pair.
        self.encBlocks = ModuleList(
            [Block(inCh, outCh)
             for inCh, outCh in zip(channels, channels[1:])])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Run ``x`` through every encoder block.

        Returns:
            A list with the output of each block, taken *before* pooling,
            in the order the blocks were applied.
        """
        blockOutputs = []
        for block in self.encBlocks:
            # Keep the un-pooled block output, then pool it to form the
            # input of the next block.
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        return blockOutputs
class Decoder(Module):
    """Upsampling path: transposed conv, skip concatenation, then ``Block``.

    Args:
        channels: Channel sizes; every consecutive pair
            ``(channels[i], channels[i + 1])`` defines one upsampler and
            one decoder block.
    """

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        self.channels = channels
        # One transposed convolution (kernel 2, stride 2) and one Block per
        # consecutive channel pair.
        self.upconvs = ModuleList(
            [ConvTranspose2d(inCh, outCh, 2, 2)
             for inCh, outCh in zip(channels, channels[1:])])
        self.dec_blocks = ModuleList(
            [Block(inCh, outCh)
             for inCh, outCh in zip(channels, channels[1:])])

    def forward(self, x, encFeatures):
        """Decode ``x`` using the encoder feature maps as skip connections.

        Args:
            x: Input to the first upsampler.
            encFeatures: Sequence of encoder feature maps; entry ``i`` is
                concatenated at decoder stage ``i``.

        Returns:
            The output of the last decoder block.
        """
        stages = zip(self.upconvs, self.dec_blocks)
        for i, (upconv, decBlock) in enumerate(stages):
            x = upconv(x)
            # Crop the encoder features to the upsampled size so the two
            # tensors can be concatenated along the channel dimension.
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = decBlock(x)
        return x

    def crop(self, encFeatures, x):
        """Center-crop ``encFeatures`` to the height and width of ``x``."""
        (_, _, H, W) = x.shape
        return CenterCrop([H, W])(encFeatures)
class UNet(Module):
    """Encoder/decoder network with a 1x1 convolution head.

    Args:
        encChannels: Channel sizes handed to ``Encoder``.
        decChannels: Channel sizes handed to ``Decoder``.
        nbClasses: Number of output channels of the head.
        retainDim: If true, the output is resized to ``outSize``.
        outSize: Target size used when ``retainDim`` is true.
    """

    # NOTE(review): the default for ``outSize`` reads ``config`` when the
    # class is defined, so ``config`` must already be importable here.
    def __init__(self, encChannels=(3, 16, 32, 64),
                 decChannels=(64, 32, 16),
                 nbClasses=1, retainDim=True,
                 outSize=(config.INPUT_IMAGE_HEIGHT,
                          config.INPUT_IMAGE_WIDTH)):
        super().__init__()
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # 1x1 convolution mapping the last decoder channels to nbClasses.
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        """Return the segmentation map predicted for ``x``."""
        encFeatures = self.encoder(x)
        # The last encoder output feeds the decoder; the remaining outputs,
        # in reverse order, serve as its skip connections.
        bottleneck, *skips = encFeatures[::-1]
        decFeatures = self.decoder(bottleneck, skips)
        segMap = self.head(decFeatures)
        # Optionally resize the prediction to the configured output size.
        if self.retainDim:
            segMap = F.interpolate(segMap, self.outSize)
        return segMap
我使用plot_model函数,但没有得到任何结果。
我使用了类似于此的代码来可视化U-Net架构。
# NOTE(review): plot_model is imported from Keras, while the UNet defined
# above is imported from / built with torch -- presumably this mismatch is
# why no diagram is produced; confirm.
from keras.utils import plot_model
# Build your model
# NOTE(review): this binds the UNet class itself; no instance is created.
model = UNet
# Visualize the model
plot_model(model, to_file='model2.png')
请帮助我根据我的U-Net代码绘制一个网络架构。
1条答案
按热度按时间aamkag611#
您可以使用TensorBoard绘制U-Net架构,下面是我使用的代码示例: