forked from CoinCheung/pytorch-loss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
swish.py
97 lines (73 loc) · 2.49 KB
/
swish.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
## plain autograd version
class SwishV1(nn.Module):
    """Swish activation ``feat * sigmoid(feat)``.

    The gradient is derived entirely by autograd from the two tensor ops in
    ``forward``; no custom backward is involved.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat):
        """Return ``feat`` gated element-wise by its own sigmoid."""
        gate = feat.sigmoid()
        return feat * gate
## self-computed back-propagation (intended to use less memory and be faster
## than the autograd version)
class SwishFunction(torch.autograd.Function):
    """Swish ``feat * sigmoid(feat)`` with an analytic, hand-written gradient.

    The local derivative ``sigmoid(x) * (1 + x * (1 - sigmoid(x)))`` is
    computed once in ``forward`` and saved, so ``backward`` is one multiply.

    NOTE: the saved derivative is treated as a constant in ``backward``, so
    second-order gradients through this Function are not the true ones.
    """

    @staticmethod
    def forward(ctx, feat):
        """Return ``feat * sigmoid(feat)`` and save its derivative for backward."""
        sig = torch.sigmoid(feat)
        out = feat * sig  # reuse sig rather than evaluating sigmoid twice
        grad = sig * (1 + feat * (1 - sig))
        ctx.save_for_backward(grad)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule: upstream gradient times the saved local derivative.

        Deliberately out-of-place: mutating the saved tensor would make any
        later backward through the same graph (``retain_graph=True``,
        ``gradcheck``) start from an already-scaled derivative.
        """
        grad, = ctx.saved_tensors
        return grad_output * grad
class SwishV2(nn.Module):
    """Swish activation whose gradient comes from ``SwishFunction``'s
    hand-written backward instead of autograd."""

    def __init__(self):
        super().__init__()

    def forward(self, feat):
        """Apply the custom autograd Function to ``feat``."""
        swish = SwishFunction.apply
        return swish(feat)
if __name__ == "__main__":
    # Side-by-side check: two identically initialised small nets, one using
    # SwishV1 (autograd) and one using SwishV2 (custom backward), are trained
    # on the same random batches; the per-step loss difference is printed.
    import torchvision

    # Pretrained ResNet-50 stem (conv1 + bn1) used to initialise both nets.
    backbone = torchvision.models.resnet50(pretrained=True)
    sd = {
        name: tensor
        for name, tensor in backbone.state_dict().items()
        if name.startswith(('conv1.', 'bn1.'))
    }
    print(sd)

    class Net(nn.Module):
        """Conv stem -> BN -> Swish -> spatial mean -> linear -> CE loss."""

        def __init__(self, act='swishv1'):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 7, 2, 3)
            self.bn1 = nn.BatchNorm2d(64)
            # any value other than 'swishv1' selects the custom-backward variant
            self.act = SwishV1() if act == 'swishv1' else SwishV2()
            self.dense = nn.Linear(64, 10, bias=False)
            self.crit = nn.CrossEntropyLoss()
            # overwrite the random conv1/bn1 init with the pretrained stem
            merged = self.state_dict()
            merged.update(sd)
            self.load_state_dict(merged)
            # constant classifier weights so both nets start identical
            torch.nn.init.constant_(self.dense.weight, 1)

        def forward(self, feat, label):
            """Return the cross-entropy loss of the logits against ``label``."""
            hidden = self.act(self.bn1(self.conv1(feat)))
            pooled = hidden.mean(dim=(2, 3))
            return self.crit(self.dense(pooled), label)

    nets = [Net(act='swishv1'), Net(act='swishv2')]
    optims = [torch.optim.SGD(model.parameters(), lr=1e-3) for model in nets]
    for _ in range(10):
        inten = torch.randn(16, 3, 512, 512)
        label = torch.randint(0, 10, (16, ))
        step_losses = []
        for model, optim in zip(nets, optims):
            loss = model(inten, label)
            optim.zero_grad()
            loss.backward()
            optim.step()
            step_losses.append(loss.item())
        print(step_losses[0] - step_losses[1])