-
Notifications
You must be signed in to change notification settings - Fork 2
/
loss.py
58 lines (46 loc) · 2.66 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
from torch.nn import functional as F
class NormalVectorLoss(nn.Module):
    """Penalize predicted mesh edges that are not perpendicular to ground-truth face normals.

    For every face, the three predicted edges (v0->v1, v0->v2, v1->v2) are
    normalized to unit length. The absolute cosine between each edge and the
    ground-truth face normal is returned. A predicted edge lying in the
    ground-truth face plane gives 0; an edge parallel to the normal gives 1.

    Args:
        face: Per-face vertex indices. Columns 0, 1 and 2 of each row are used.
            Anything ``torch.as_tensor`` accepts works (nested lists, NumPy
            array, tensor).
    """

    def __init__(self, face):
        super(NormalVectorLoss, self).__init__()
        self.face = face

    @staticmethod
    def _unit_edge(coord, face, start, end):
        """Return unit vectors from vertex column ``start`` to column ``end`` of every face."""
        edge = coord[:, face[:, end], :] - coord[:, face[:, start], :]
        # F.normalize divides by max(norm, eps), so degenerate edges become
        # zero vectors instead of NaNs.
        return F.normalize(edge, p=2, dim=2)

    def forward(self, coord_out, coord_gt):
        """Compute the unreduced normal-consistency loss.

        Args:
            coord_out: Predicted vertex positions, rank 3, indexed as
                ``[:, vertex, coord]``. Presumably (batch, num_verts, 3);
                TODO confirm with callers.
            coord_gt: Ground-truth vertex positions with the same layout.
                ``torch.cross(..., dim=2)`` requires dim 2 to have size 3.

        Returns:
            Tensor of shape (B, 3 * F, 1), where F is the number of faces.
            Along dim 1 it holds all faces' v0->v1 terms, then all v0->v2
            terms, then all v1->v2 terms. No reduction is applied.
        """
        # Build the index on the input's device rather than hard-coding .cuda(),
        # so the loss also works on CPU and on non-default GPUs.
        face = torch.as_tensor(self.face, dtype=torch.long, device=coord_out.device)

        v1_out = self._unit_edge(coord_out, face, 0, 1)
        v2_out = self._unit_edge(coord_out, face, 0, 2)
        v3_out = self._unit_edge(coord_out, face, 1, 2)

        v1_gt = self._unit_edge(coord_gt, face, 0, 1)
        v2_gt = self._unit_edge(coord_gt, face, 0, 2)
        normal_gt = F.normalize(torch.cross(v1_gt, v2_gt, dim=2), p=2, dim=2)

        # abs(): penalize any component along the normal, whatever its sign.
        cos1 = torch.abs(torch.sum(v1_out * normal_gt, 2, keepdim=True))
        cos2 = torch.abs(torch.sum(v2_out * normal_gt, 2, keepdim=True))
        cos3 = torch.abs(torch.sum(v3_out * normal_gt, 2, keepdim=True))
        return torch.cat((cos1, cos2, cos3), 1)
class EdgeLengthLoss(nn.Module):
    """Penalize differences between predicted and ground-truth mesh edge lengths.

    For every face, the Euclidean lengths of edges (v0, v1), (v0, v2) and
    (v1, v2) are measured in both meshes. The absolute difference is returned
    without reduction.

    Args:
        face: Per-face vertex indices. Columns 0, 1 and 2 of each row are used.
            Anything ``torch.as_tensor`` accepts works (nested lists, NumPy
            array, tensor).
    """

    def __init__(self, face):
        super(EdgeLengthLoss, self).__init__()
        self.face = face

    @staticmethod
    def _edge_lengths(coord, face):
        """Return edge lengths (0,1), (0,2), (1,2) of every face, concatenated to shape (B, 3 * F, 1)."""
        # torch.norm's backward returns 0 (not NaN) at zero length, unlike
        # sqrt(sum(d ** 2)). Collapsed predicted edges therefore do not poison
        # the gradients.
        lengths = [
            torch.norm(coord[:, face[:, i], :] - coord[:, face[:, j], :], p=2, dim=2, keepdim=True)
            for i, j in ((0, 1), (0, 2), (1, 2))
        ]
        return torch.cat(lengths, 1)

    def forward(self, coord_out, coord_gt):
        """Compute the unreduced edge-length loss.

        Args:
            coord_out: Predicted vertex positions, rank 3, indexed as
                ``[:, vertex, coord]``. Presumably (batch, num_verts, 3);
                TODO confirm with callers.
            coord_gt: Ground-truth vertex positions with the same layout.

        Returns:
            Tensor of shape (B, 3 * F, 1), where F is the number of faces.
            Along dim 1 it holds all faces' (v0, v1) terms, then all (v0, v2)
            terms, then all (v1, v2) terms.
        """
        # Build the index on the input's device rather than hard-coding .cuda(),
        # so the loss also works on CPU and on non-default GPUs.
        face = torch.as_tensor(self.face, dtype=torch.long, device=coord_out.device)
        return torch.abs(self._edge_lengths(coord_out, face) - self._edge_lengths(coord_gt, face))