-
Notifications
You must be signed in to change notification settings - Fork 32
/
model.py
80 lines (63 loc) · 2.51 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
class CDAutoEncoder(nn.Module):
    r"""
    Convolutional denoising autoencoder layer for stacked autoencoders.

    When ``self.training`` is True, every call to :meth:`forward` also runs
    one greedy local optimization step on this layer's own reconstruction
    loss (the layer trains itself; gradients never propagate to earlier
    layers because the input is detached).

    Args:
        input_size: number of input channels.
        output_size: number of output (feature) channels.
        stride: stride of the encoding convolution.
    """
    def __init__(self, input_size, output_size, stride):
        super(CDAutoEncoder, self).__init__()
        # Encoder: kernel 2 / caller-chosen stride halves spatial dims when stride=2.
        self.forward_pass = nn.Sequential(
            nn.Conv2d(input_size, output_size, kernel_size=2, stride=stride, padding=0),
            nn.ReLU(),
        )
        # Decoder: transposed conv with stride 2 restores the spatial size
        # produced by a stride-2 encoder.
        self.backward_pass = nn.Sequential(
            nn.ConvTranspose2d(output_size, input_size, kernel_size=2, stride=2, padding=0),
            nn.ReLU(),
        )
        self.criterion = nn.MSELoss()
        # Each layer owns its optimizer: training is per-layer, not end-to-end.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x):
        """Encode ``x`` (and, in training mode, take one local SGD step).

        Returns the detached feature map so the next layer trains greedily
        on a fixed input.
        """
        # Detach so this layer's loss cannot backpropagate into earlier layers.
        x = x.detach()
        # Multiplicative binary noise: drop units where a N(0, 0.1) sample is
        # <= -0.1 (about a 16% drop rate). The clean x stays the target.
        # (torch.randn_like(x) * 0.1 replaces the deprecated
        # Variable(x.data.new(x.size()).normal_(0, 0.1)) idiom.)
        keep_mask = (torch.randn_like(x) * 0.1 > -0.1).type_as(x)
        x_noisy = x * keep_mask
        y = self.forward_pass(x_noisy)

        if self.training:
            x_reconstruct = self.backward_pass(y)
            # x is already detached above, so no Variable(..., requires_grad=False)
            # wrapper is needed (Variable is deprecated since PyTorch 0.4).
            loss = self.criterion(x_reconstruct, x)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return y.detach()

    def reconstruct(self, x):
        """Decode a feature map back to this layer's input space."""
        return self.backward_pass(x)
class StackedAutoEncoder(nn.Module):
    r"""
    A stacked autoencoder built from the convolutional denoising
    autoencoders above.

    Each layer trains itself greedily (and all layers train simultaneously)
    on its own reconstruction loss whenever the model is in training mode.
    """
    def __init__(self):
        super(StackedAutoEncoder, self).__init__()
        self.ae1 = CDAutoEncoder(3, 128, 2)
        self.ae2 = CDAutoEncoder(128, 256, 2)
        self.ae3 = CDAutoEncoder(256, 512, 2)

    def forward(self, x):
        """Encode ``x`` through all three layers.

        Returns the top-level features in training mode; in eval mode,
        returns ``(features, reconstruction)``.
        """
        features = x
        # Each call may also run that layer's local training step.
        for encoder in (self.ae1, self.ae2, self.ae3):
            features = encoder(features)
        if self.training:
            return features
        return features, self.reconstruct(features)

    def reconstruct(self, x):
        """Decode top-level features back to image space, deepest layer first."""
        decoded = x
        for decoder in (self.ae3, self.ae2, self.ae1):
            decoded = decoder.reconstruct(decoded)
        return decoded