
Commit 410cb08

Add files via upload

1 parent 6104d40 commit 410cb08

22 files changed: +836 -48 lines changed

dataprocess/Augmain.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from dataprocess.Augmentation.ImageAugmentation import DataAug3D
 
 if __name__ == '__main__':
-    aug = DataAug3D(rotation=5, width_shift=0.01, height_shift=0.01, depth_shift=0, zoom_range=0,
+    aug = DataAug3D(rotation=10, width_shift=0.01, height_shift=0.01, depth_shift=0, zoom_range=0,
                     vertical_flip=True, horizontal_flip=True)
-    aug.DataAugmentation('data/traindata.csv', 15, aug_path='D:\challenge\data\KiPA2022\\trainstage\\augtrain/')
+    aug.DataAugmentation('data/traindata.csv', 10, aug_path='D:\challenge\data\KiPA2022\\trainstage\\augtrain/')
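This hunk raises the rotation range from 5 to 10 degrees and lowers the number of augmented copies per case from 15 to 10. For illustration, here is a minimal stand-in for one Keras-style 3D augmentation pass under these parameters, assuming rotation in degrees and shifts as fractions of the axis size; this is an illustrative sketch, not the repository's DataAug3D implementation:

import numpy as np
from scipy.ndimage import rotate, shift

def augment_volume(vol, rotation=10, width_shift=0.01, height_shift=0.01):
    # Random in-plane rotation, in degrees
    angle = np.random.uniform(-rotation, rotation)
    out = rotate(vol, angle, axes=(0, 1), reshape=False, order=1)
    # Random shifts, expressed as fractions of the axis size
    dx = np.random.uniform(-width_shift, width_shift) * vol.shape[0]
    dy = np.random.uniform(-height_shift, height_shift) * vol.shape[1]
    out = shift(out, (dx, dy, 0), order=1)
    # Random flips
    if np.random.rand() < 0.5:  # horizontal_flip
        out = out[::-1, :, :]
    if np.random.rand() < 0.5:  # vertical_flip
        out = out[:, ::-1, :]
    return out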
6.71 KB: Binary file not shown.

dataprocess/data/trainaugdata.csv

Lines changed: 651 additions & 0 deletions
Large diffs are not rendered by default.

dataprocess/data/traindata.csv

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+Image,Mask
+D:\challenge\data\KiPA2022\trainstage\train/Image/0.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/0.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/1.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/1.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/10.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/10.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/11.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/11.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/12.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/12.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/13.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/13.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/14.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/14.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/15.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/15.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/16.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/16.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/17.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/17.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/18.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/18.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/19.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/19.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/2.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/2.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/20.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/20.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/21.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/21.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/22.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/22.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/23.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/23.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/24.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/24.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/25.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/25.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/26.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/26.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/27.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/27.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/28.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/28.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/29.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/29.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/3.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/3.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/30.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/30.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/31.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/31.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/32.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/32.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/33.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/33.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/34.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/34.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/35.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/35.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/36.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/36.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/37.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/37.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/38.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/38.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/39.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/39.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/4.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/4.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/40.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/40.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/41.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/41.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/42.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/42.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/43.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/43.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/44.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/44.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/45.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/45.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/46.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/46.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/47.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/47.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/48.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/48.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/49.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/49.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/5.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/5.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/50.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/50.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/51.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/51.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/52.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/52.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/53.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/53.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/54.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/54.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/55.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/55.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/56.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/56.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/57.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/57.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/58.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/58.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/59.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/59.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/6.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/6.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/60.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/60.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/61.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/61.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/62.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/62.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/63.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/63.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/64.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/64.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/7.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/7.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/8.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/8.npy
+D:\challenge\data\KiPA2022\trainstage\train/Image/9.npy,D:\challenge\data\KiPA2022\trainstage\train/Mask/9.npy

dataprocess/data/validata.csv

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+Image,Mask
+D:\challenge\data\KiPA2022\trainstage\validation/Image/0.npy,D:\challenge\data\KiPA2022\trainstage\validation/Mask/0.npy
+D:\challenge\data\KiPA2022\trainstage\validation/Image/1.npy,D:\challenge\data\KiPA2022\trainstage\validation/Mask/1.npy
+D:\challenge\data\KiPA2022\trainstage\validation/Image/2.npy,D:\challenge\data\KiPA2022\trainstage\validation/Mask/2.npy
+D:\challenge\data\KiPA2022\trainstage\validation/Image/3.npy,D:\challenge\data\KiPA2022\trainstage\validation/Mask/3.npy
+D:\challenge\data\KiPA2022\trainstage\validation/Image/4.npy,D:\challenge\data\KiPA2022\trainstage\validation/Mask/4.npy

dataprocess/data/validata1.csv

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+Image,Mask
+D:\challenge\data\KiPA2022\trainstage\validation/Image/0.npy,D:\challenge\data\KiPA2022\trainstage\validation/Mask/0.npy

dataprocess/说明文档.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+1. Dataset statistics, mean size and mean spacing: (array([153.29230769, 153.29230769, 197.49230769]), array([0.63416487, 0.63416487, 0.63416487]))
+2. Images are resampled to a fixed size (112x112x128)
+3. Intensity normalization uses the (5, 95) percentile range
+4. The loss is focal loss
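Item 3 refers to percentile-based intensity normalization. An illustrative sketch of (5, 95) percentile normalization for a NumPy volume follows; the helper name is hypothetical and the repository's preprocessing may differ in detail:

import numpy as np

def percentile_normalize(volume, lower=5, upper=95):
    # Clip to the [lower, upper] percentile range, then rescale to [0, 1]
    lo, hi = np.percentile(volume, [lower, upper])
    clipped = np.clip(volume, lo, hi)
    return (clipped - lo) / (hi - lo + 1e-8)  # epsilon guards against a flat volume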

inference.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import torch
+import os
+from model import *
+from dataprocess.utils import file_name_path
+import SimpleITK as sitk
+
+# Use CUDA
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+use_cuda = torch.cuda.is_available()
+
+
+def inferencemutilunet3dtest():
+    newSize = (112, 112, 128)
+    Unet3d = MutilUNet3dModel(image_depth=128, image_height=112, image_width=112, image_channel=1, numclass=1,
+                              batch_size=1, loss_name='MutilFocalLoss', inference=True,
+                              model_path=r'log\MutilUNet3d\focalloss\BinaryVNet2dSegModel.pth')
+    datapath = r"F:\MedicalData\(ok)2022KiPA\dataset\test\image"
+    makspath = r"F:\MedicalData\(ok)2022KiPA\dataset\test\label"
+    image_path_list = file_name_path(datapath, False, True)
+    for i in range(len(image_path_list)):
+        imagepathname = datapath + "/" + image_path_list[i]
+        sitk_image = sitk.ReadImage(imagepathname)
+        sitk_mask = Unet3d.inference(sitk_image, newSize)
+        maskpathname = makspath + "/" + image_path_list[i]
+        sitk.WriteImage(sitk_mask, maskpathname)
+
+
+if __name__ == '__main__':
+    inferencemutilunet3dtest()
-610 Bytes: Binary file not shown.
651 Bytes: Binary file not shown.
245 Bytes: Binary file not shown.
222 Bytes: Binary file not shown.
102 Bytes: Binary file not shown.

model/losses.py

Lines changed: 2 additions & 18 deletions
@@ -202,7 +202,7 @@ def forward(self, y_pred_logits, y_true):
         intersection = torch.sum(y_true * y_pred, dim=(0, 2))
         denominator = torch.sum(y_true + y_pred, dim=(0, 2))
         gen_dice_coef = ((2. * intersection + smooth) / (denominator + smooth)).clamp_min(eps)
-        loss = 1 - gen_dice_coef
+        loss = - gen_dice_coef
         # Dice loss is undefined for non-empty classes
         # So we zero contribution of channel that does not have true pixels
         # NOTE: A better workaround would be to use loss term `mean(y_pred)`
@@ -212,27 +212,11 @@ def forward(self, y_pred_logits, y_true):
         return (loss * self.alpha).mean()
 
 
-class MutilCrossEntropyDiceLoss(nn.Module):
-    """
-    mutil ce and dice
-    """
-
-    def __init__(self, alpha):
-        super(MutilCrossEntropyDiceLoss, self).__init__()
-        self.alpha = alpha
-
-    def forward(self, y_pred_logits, y_true):
-        diceloss = MutilDiceLoss(self.alpha)
-        dice = diceloss(y_pred_logits, y_true)
-        bceloss = MutilCrossEntropyLoss(self.alpha)
-        bce = bceloss(y_pred_logits, y_true)
-        return bce + dice
-
-
 class LovaszLoss(nn.Module):
     """
     mutil LovaszLoss
     """
+
     def __init__(self, per_image=False, ignore=None):
         super(LovaszLoss, self).__init__()
         self.ignore = ignore
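For context on the first hunk: gen_dice_coef is a per-class Dice coefficient, so the conventional loss is 1 - gen_dice_coef; the change to -gen_dice_coef shifts the loss by a constant and leaves the gradients unchanged. An illustrative, self-contained sketch of the coefficient, assuming probability and one-hot tensors already flattened to shape (N, C, V); the names are illustrative, not the repository's exact code:

import torch

def generalized_dice_coef(y_pred, y_true, smooth=1e-5, eps=1e-7):
    # y_pred: per-class probabilities, y_true: one-hot labels, both (N, C, V)
    intersection = torch.sum(y_true * y_pred, dim=(0, 2))
    denominator = torch.sum(y_true + y_pred, dim=(0, 2))
    # One score per class, clamped away from zero
    return ((2. * intersection + smooth) / (denominator + smooth)).clamp_min(eps)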

model/modelUnet.py

Lines changed: 11 additions & 8 deletions
@@ -5,7 +5,7 @@
 from .dataset import datasetModelSegwithopencv, datasetModelSegwithnpy
 from torch.utils.data import DataLoader
 from .losses import BinaryDiceLoss, BinaryFocalLoss, BinaryCrossEntropyLoss, BinaryCrossEntropyDiceLoss, \
-    MutilDiceLoss, MutilFocalLoss, MutilCrossEntropyLoss, MutilCrossEntropyDiceLoss
+    MutilDiceLoss, MutilFocalLoss, MutilCrossEntropyLoss
 import torch.optim as optim
 import numpy as np
 from tqdm import tqdm
@@ -709,7 +709,8 @@ def __init__(self, image_depth, image_height, image_width, image_channel, numcla
         self.image_channel = image_channel
         self.numclass = numclass
 
-        self.alpha = [1.] * self.numclass
+        # self.alpha = [1.] * self.numclass
+        self.alpha = [1., 5., 1., 5., 3.]
         self.gamma = 3
 
         self.use_cuda = use_cuda
@@ -745,8 +746,6 @@ def _loss_function(self, lossname):
             return MutilDiceLoss(alpha=self.alpha)
         if lossname is 'MutilFocalLoss':
             return MutilFocalLoss(alpha=self.alpha, gamma=self.gamma)
-        if lossname is 'MutilCrossEntropyDiceLoss':
-            return MutilCrossEntropyDiceLoss(alpha=self.alpha)
 
     def _accuracy_function(self, accuracyname, input, target):
         if accuracyname is 'dice':
@@ -772,8 +771,9 @@ def trainprocess(self, trainimage, trainmask, validationimage, validationmask, m
         showpixelvalue = showpixelvalue // (self.numclass - 1)
         # 1、initialize loss function and optimizer
         lossFunc = self._loss_function(self.loss_name)
-        opt = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
-        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=2, verbose=True)
+        # opt = optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-4)
+        opt = optim.Adam(self.model.parameters(), lr=lr)
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=10, verbose=True)
         # 2、load data train and validation dataset
         train_loader = self._dataloder(trainimage, trainmask, True)
         val_loader = self._dataloder(validationimage, validationmask)
@@ -793,7 +793,7 @@ def trainprocess(self, trainimage, trainmask, validationimage, validationmask, m
         totalValidationLoss = []
         totalValiadtionAccu = []
         # 4.3、loop over the training set
-        for batch in train_loader:
+        for batch in val_loader:
             # x should tensor with shape (N,C,D,W,H)
             x = batch['image']
             # y should tensor with shape (N,C,D,W,H),
@@ -805,6 +805,9 @@ def trainprocess(self, trainimage, trainmask, validationimage, validationmask, m
             pred_logit, pred = self.model(x)
             loss = lossFunc(pred_logit, y)
             accu = self._accuracy_function(self.accuracyname, pred, y)
+            savepath = model_dir + '/' + str(e + 1) + "_train_EPOCH_"
+            save_images3d(torch.argmax(pred[0], 0), y[0], showwind, savepath,
+                          pixelvalue=showpixelvalue)
             # first, zero out any previously accumulated gradients,
             # then perform backpropagation,
             # and then update model parameters
@@ -841,7 +844,7 @@ def trainprocess(self, trainimage, trainmask, validationimage, validationmask, m
         avgValidationLoss = torch.mean(torch.stack(totalValidationLoss))
         avgTrainAccu = torch.mean(torch.stack(totalTrainAccu))
         avgValidationAccu = torch.mean(torch.stack(totalValiadtionAccu))
-        lr_scheduler.step(avgValidationLoss)
+        # lr_scheduler.step(avgValidationLoss)
         # 4.6、update our training history
         H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
         H["valdation_loss"].append(avgValidationLoss.cpu().detach().numpy())

networks/ResNet2d.py

Lines changed: 4 additions & 3 deletions
@@ -117,13 +117,14 @@ def __init__(self, image_channel, numclass, elu=True):
 
         self.down_tr32 = DownTransition2d(16, 32, 2, elu)
         self.down_tr64 = DownTransition2d(32, 64, 3, elu)
-        self.down_tr128 = DownTransition2d(64, 128, 3, elu, dropout=True)
-        self.down_tr256 = DownTransition2d(128, 256, 3, elu, dropout=True)
+        self.down_tr128 = DownTransition2d(64, 128, 3, elu)
+        self.down_tr256 = DownTransition2d(128, 256, 3, elu)
 
         self.avg = GlobalAveragePooling()
 
         self.fc_layers = nn.Sequential(
-            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(),
+            nn.Linear(256, 128),
+            nn.ReLU(inplace=True),
             nn.Linear(128, self.numclass))
 
     def forward(self, x):

networks/ResNet3d.py

Lines changed: 4 additions & 3 deletions
@@ -117,13 +117,14 @@ def __init__(self, image_channel, numclass, elu=True):
 
         self.down_tr32 = DownTransition3d(16, 32, 2, elu)
         self.down_tr64 = DownTransition3d(32, 64, 3, elu)
-        self.down_tr128 = DownTransition3d(64, 128, 3, elu, dropout=True)
-        self.down_tr256 = DownTransition3d(128, 256, 3, elu, dropout=True)
+        self.down_tr128 = DownTransition3d(64, 128, 3, elu)
+        self.down_tr256 = DownTransition3d(128, 256, 3, elu)
 
         self.avg = GlobalAveragePooling()
 
         self.fc_layers = nn.Sequential(
-            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(),
+            nn.Linear(256, 128),
+            nn.ReLU(inplace=True),
             nn.Linear(128, self.numclass))
 
     def forward(self, x):

networks/Unet3d.py

Lines changed: 5 additions & 6 deletions
@@ -18,15 +18,15 @@ def __init__(self, in_channels, out_channels, init_features=16):
         self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
         self.encoder2 = UNet3d._block(self.features, self.features * 2, name="enc2")
         self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
-        self.encoder3 = UNet3d._block(self.features * 2, self.features * 4, name="enc3", dropout=False)
+        self.encoder3 = UNet3d._block(self.features * 2, self.features * 4, name="enc3")
         self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
-        self.encoder4 = UNet3d._block(self.features * 4, self.features * 8, name="enc4", dropout=False)
+        self.encoder4 = UNet3d._block(self.features * 4, self.features * 8, name="enc4")
         self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
-        self.bottleneck = UNet3d._block(self.features * 8, self.features * 16, name="bottleneck", dropout=False)
+        self.bottleneck = UNet3d._block(self.features * 8, self.features * 16, name="bottleneck")
         self.upconv4 = nn.ConvTranspose3d(self.features * 16, self.features * 8, kernel_size=2, stride=2)
-        self.decoder4 = UNet3d._block((self.features * 8) * 2, self.features * 8, name="dec4", dropout=False)
+        self.decoder4 = UNet3d._block((self.features * 8) * 2, self.features * 8, name="dec4")
         self.upconv3 = nn.ConvTranspose3d(self.features * 8, self.features * 4, kernel_size=2, stride=2)
-        self.decoder3 = UNet3d._block((self.features * 4) * 2, self.features * 4, name="dec3", dropout=False)
+        self.decoder3 = UNet3d._block((self.features * 4) * 2, self.features * 4, name="dec3")
         self.upconv2 = nn.ConvTranspose3d(self.features * 4, self.features * 2, kernel_size=2, stride=2)
         self.decoder2 = UNet3d._block((self.features * 2) * 2, self.features * 2, name="dec2")
         self.upconv1 = nn.ConvTranspose3d(self.features * 2, self.features, kernel_size=2, stride=2)
@@ -73,7 +73,6 @@ def _block(in_channels, features, name, dropout=False):
                     bias=False, ),),
                 (name + "norm1", nn.BatchNorm3d(num_features=features)),
                 (name + "relu1", nn.ReLU(inplace=True)),
-                (name + "dropout1", nn.Dropout3d()),
                 (name + "conv2", nn.Conv3d(
                     in_channels=features,
                     out_channels=features,
