model.py
Go to the documentation of this file.
1 """
2 U-Nets for filling dead channels in collection-view and induction-view images. The architectures
3 of the two models are the same, but the padding differs because the views produce differently sized images.
4 
5 The architecture is modelled on the submanifold sparse U-Net (arXiv:1711.10275). Ideally this model would
6 make use of sparse convolutions, but current software restrictions mean it does not.
7 """
8 
9 import torch
10 from torch import nn
11 
12 
class UnetInduction(nn.Module):
    """U-Net that fills dead channels in induction-view images.

    Structure:
        - Four encoder levels and a bottleneck.
        - Downsampling uses kernel-3, stride-3 convolutions without padding.
        - Upsampling uses kernel-3, stride-3 transposed convolutions. Their
          ``output_padding`` is chosen so that each upsampled map has exactly
          the spatial size of the encoder map it is concatenated with.
        - Every block except the input convolution uses pre-activation
          ordering (BatchNorm -> ReLU -> Conv).

    Input sizes: because the skip connections must line up exactly, only
    certain spatial sizes work. From the layer arithmetic:
        - the first spatial dimension must be ``81 * k + 168``;
        - the second must be ``81 * j + 152``;
        - ``k, j >= 1`` are integers (for example 6000 x 800).
    Other sizes raise an error in ``forward``.

    Args:
        in_channels: Number of channels in the input image. Defaults to 1.
        out_channels: Number of channels in the output image. Defaults to 1.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Input stem and encoder (left side of the "U").
        # The channel arguments were previously ignored (hard-coded to 1).
        self.conv_in = self.single_conv_in(in_channels, 4, 3)
        self.convs1_L = self.conv_block(4, 4, 3)
        self.down_conv1 = self.down_conv(4, 8, 3, 3)
        self.convs2_L = self.conv_block(8, 8, 3)
        self.down_conv2 = self.down_conv(8, 16, 3, 3)
        # Unpadded 3x3 convolutions: each of these blocks shrinks the map by 4.
        self.convs3_L = self.conv_block(16, 16, 3, padding=((0, 0), (0, 0)))
        self.down_conv3 = self.down_conv(16, 32, 3, 3)
        self.convs4_L = self.conv_block(32, 32, 3, padding=((0, 0), (0, 0)))

        # Bottleneck.
        self.down_conv_bottom = self.down_conv(32, 64, 3, 3)
        self.convs_bottom = self.conv_block(64, 64, 3, padding=((1, 1), (1, 1)))
        self.up_conv_bottom = self.up_conv(64, 32, 3, 3, output_padding=(0, 0))

        # Decoder (right side). Input channels are doubled by the skip
        # concatenation. padding=2 blocks grow the map by 4, undoing the
        # unpadded encoder blocks. The output_padding values are what make the
        # upsampled maps match the induction-view encoder sizes.
        self.convs4_R = self.conv_block(32 * 2, 32, 3, padding=((2, 2), (2, 2)))
        self.up_conv1 = self.up_conv(32, 16, 3, 3, output_padding=(2, 0))
        self.convs3_R = self.conv_block(16 * 2, 16, 3, padding=((2, 2), (2, 2)))
        self.up_conv2 = self.up_conv(16, 8, 3, 3, output_padding=(2, 2))
        self.convs2_R = self.conv_block(8 * 2, 8, 3)
        self.up_conv3 = self.up_conv(8, 4, 3, 3, output_padding=(0, 2))
        self.convs1_R = self.conv_block(4 * 2, 4, 3)
        self.conv_out = self.single_conv_out(4, out_channels, 3)

    def forward(self, conv1):
        """Run the U-Net on a batch of images.

        Args:
            conv1: Input tensor of shape ``(N, in_channels, H, W)``. It must be
                4-D because skip connections are concatenated along dim 1.
                (The parameter name is kept for backward compatibility.)

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        # Encoder: keep each level's output for its skip connection.
        skip1 = self.convs1_L(self.conv_in(conv1))
        skip2 = self.convs2_L(self.down_conv1(skip1))
        skip3 = self.convs3_L(self.down_conv2(skip2))
        skip4 = self.convs4_L(self.down_conv3(skip3))

        x = self.down_conv_bottom(skip4)
        x = self.convs_bottom(x)
        x = self.up_conv_bottom(x)

        # Decoder: the upsampled map comes first, then the skip. This order
        # must be preserved so pretrained weights stay valid.
        x = self.up_conv1(self.convs4_R(torch.cat([x, skip4], 1)))
        x = self.up_conv2(self.convs3_R(torch.cat([x, skip3], 1)))
        x = self.up_conv3(self.convs2_R(torch.cat([x, skip2], 1)))
        x = self.conv_out(self.convs1_R(torch.cat([x, skip1], 1)))

        return x

    def single_conv_in(self, in_channels, out_channels, kernel_size, padding=1):
        """Return the input stem: one padded convolution with no norm or activation."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))

    def single_conv_out(self, in_channels, out_channels, kernel_size, padding=1):
        """Return the output head: BatchNorm -> ReLU -> padded convolution."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))

    def conv_block(self, in_channels, out_channels, kernel_size,
                   padding=((1, 1), (1, 1))):
        """Return two pre-activation convolutions (BN -> ReLU -> Conv, twice).

        Args:
            in_channels: Channels entering the first convolution.
            out_channels: Channels produced by both convolutions.
            kernel_size: Kernel size of both convolutions.
            padding: Pair of paddings, one per convolution, each passed
                straight to ``nn.Conv2d``. With ``kernel_size=3``:
                (1, 1) preserves size; (0, 0) shrinks each spatial dim by 2;
                (2, 2) grows it by 2. The default is a tuple (not a list) to
                avoid a shared mutable default; lists are still accepted.
        """
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding[0]),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding[1]))

    def down_conv(self, in_channels, out_channels, kernel_size, stride):
        """Return BN -> ReLU -> unpadded strided convolution (downsampling)."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride))

    def up_conv(self, in_channels, out_channels, kernel_size, stride, output_padding=0):
        """Return BN -> ReLU -> strided transposed convolution (upsampling).

        ``output_padding`` adds extra rows/columns to the output so its size
        can match the corresponding encoder map.
        """
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,
                               output_padding=output_padding))
104 
105 
class UnetCollection(nn.Module):
    """U-Net that fills dead channels in collection-view images.

    The architecture is the same as the induction-view network; only the
    decoder's ``output_padding`` values differ.

    Structure:
        - Four encoder levels and a bottleneck.
        - Downsampling uses kernel-3, stride-3 convolutions without padding.
        - Upsampling uses kernel-3, stride-3 transposed convolutions. Their
          ``output_padding`` is chosen so that each upsampled map has exactly
          the spatial size of the encoder map it is concatenated with.
        - Every block except the input convolution uses pre-activation
          ordering (BatchNorm -> ReLU -> Conv).

    Input sizes: because the skip connections must line up exactly, only
    certain spatial sizes work. From the layer arithmetic:
        - the first spatial dimension must be ``81 * k + 168``;
        - the second must be ``81 * j + 156``;
        - ``k, j >= 1`` are integers (for example 6000 x 480).
    Other sizes raise an error in ``forward``.

    Args:
        in_channels: Number of channels in the input image. Defaults to 1.
        out_channels: Number of channels in the output image. Defaults to 1.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # Input stem and encoder (left side of the "U").
        # The channel arguments were previously ignored (hard-coded to 1).
        self.conv_in = self.single_conv_in(in_channels, 4, 3)
        self.convs1_L = self.conv_block(4, 4, 3)
        self.down_conv1 = self.down_conv(4, 8, 3, 3)
        self.convs2_L = self.conv_block(8, 8, 3)
        self.down_conv2 = self.down_conv(8, 16, 3, 3)
        # Unpadded 3x3 convolutions: each of these blocks shrinks the map by 4.
        self.convs3_L = self.conv_block(16, 16, 3, padding=((0, 0), (0, 0)))
        self.down_conv3 = self.down_conv(16, 32, 3, 3)
        self.convs4_L = self.conv_block(32, 32, 3, padding=((0, 0), (0, 0)))

        # Bottleneck.
        self.down_conv_bottom = self.down_conv(32, 64, 3, 3)
        self.convs_bottom = self.conv_block(64, 64, 3, padding=((1, 1), (1, 1)))
        self.up_conv_bottom = self.up_conv(64, 32, 3, 3, output_padding=(0, 0))

        # Decoder (right side). Input channels are doubled by the skip
        # concatenation. padding=2 blocks grow the map by 4, undoing the
        # unpadded encoder blocks. The output_padding values are what make the
        # upsampled maps match the collection-view encoder sizes.
        self.convs4_R = self.conv_block(32 * 2, 32, 3, padding=((2, 2), (2, 2)))
        self.up_conv1 = self.up_conv(32, 16, 3, 3, output_padding=(2, 1))
        self.convs3_R = self.conv_block(16 * 2, 16, 3, padding=((2, 2), (2, 2)))
        self.up_conv2 = self.up_conv(16, 8, 3, 3, output_padding=(2, 1))
        self.convs2_R = self.conv_block(8 * 2, 8, 3)
        self.up_conv3 = self.up_conv(8, 4, 3, 3, output_padding=(0, 0))
        self.convs1_R = self.conv_block(4 * 2, 4, 3)
        self.conv_out = self.single_conv_out(4, out_channels, 3)

    def forward(self, conv1):
        """Run the U-Net on a batch of images.

        Args:
            conv1: Input tensor of shape ``(N, in_channels, H, W)``. It must be
                4-D because skip connections are concatenated along dim 1.
                (The parameter name is kept for backward compatibility.)

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``.
        """
        # Encoder: keep each level's output for its skip connection.
        skip1 = self.convs1_L(self.conv_in(conv1))
        skip2 = self.convs2_L(self.down_conv1(skip1))
        skip3 = self.convs3_L(self.down_conv2(skip2))
        skip4 = self.convs4_L(self.down_conv3(skip3))

        x = self.down_conv_bottom(skip4)
        x = self.convs_bottom(x)
        x = self.up_conv_bottom(x)

        # Decoder: the upsampled map comes first, then the skip. This order
        # must be preserved so pretrained weights stay valid.
        x = self.up_conv1(self.convs4_R(torch.cat([x, skip4], 1)))
        x = self.up_conv2(self.convs3_R(torch.cat([x, skip3], 1)))
        x = self.up_conv3(self.convs2_R(torch.cat([x, skip2], 1)))
        x = self.conv_out(self.convs1_R(torch.cat([x, skip1], 1)))

        return x

    def single_conv_in(self, in_channels, out_channels, kernel_size, padding=1):
        """Return the input stem: one padded convolution with no norm or activation."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))

    def single_conv_out(self, in_channels, out_channels, kernel_size, padding=1):
        """Return the output head: BatchNorm -> ReLU -> padded convolution."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding))

    def conv_block(self, in_channels, out_channels, kernel_size,
                   padding=((1, 1), (1, 1))):
        """Return two pre-activation convolutions (BN -> ReLU -> Conv, twice).

        Args:
            in_channels: Channels entering the first convolution.
            out_channels: Channels produced by both convolutions.
            kernel_size: Kernel size of both convolutions.
            padding: Pair of paddings, one per convolution, each passed
                straight to ``nn.Conv2d``. With ``kernel_size=3``:
                (1, 1) preserves size; (0, 0) shrinks each spatial dim by 2;
                (2, 2) grows it by 2. The default is a tuple (not a list) to
                avoid a shared mutable default; lists are still accepted.
        """
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding[0]),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding[1]))

    def down_conv(self, in_channels, out_channels, kernel_size, stride):
        """Return BN -> ReLU -> unpadded strided convolution (downsampling)."""
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride))

    def up_conv(self, in_channels, out_channels, kernel_size, stride, output_padding=0):
        """Return BN -> ReLU -> strided transposed convolution (upsampling).

        ``output_padding`` adds extra rows/columns to the output so its size
        can match the corresponding encoder map.
        """
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,
                               output_padding=output_padding))
197 
def down_conv(self, in_channels, out_channels, kernel_size, stride)
Definition: model.py:182
def single_conv_out(self, in_channels, out_channels, kernel_size, padding=1)
Definition: model.py:163
def down_conv(self, in_channels, out_channels, kernel_size, stride)
Definition: model.py:89
def forward(self, conv1)
Definition: model.py:39
def conv_block(self, in_channels, out_channels, kernel_size, padding=[(1, 1))
Definition: model.py:78
def up_conv(self, in_channels, out_channels, kernel_size, stride, output_padding=0)
Definition: model.py:97
def conv_block(self, in_channels, out_channels, kernel_size, padding=[(1, 1))
Definition: model.py:171
def __init__(self, in_channels=1, out_channels=1)
Definition: model.py:107
def forward(self, conv1)
Definition: model.py:132
def __init__(self, in_channels=1, out_channels=1)
Definition: model.py:14
def single_conv_in(self, in_channels, out_channels, kernel_size, padding=1)
Definition: model.py:157
def up_conv(self, in_channels, out_channels, kernel_size, stride, output_padding=0)
Definition: model.py:190
def single_conv_out(self, in_channels, out_channels, kernel_size, padding=1)
Definition: model.py:70
def single_conv_in(self, in_channels, out_channels, kernel_size, padding=1)
Definition: model.py:64