class ChamferDistance nn Module def __init__ self super __init__ pass

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class ChamferDistance(nn.Module):
    """Chamfer distance between batches of point sets, using squared Euclidean distances.

    For each batch element, every point in one set is matched to its nearest
    point in the other set (by squared Euclidean distance). Those minimum
    distances are summed in both directions and the two sums are added.
    """

    def __init__(self):
        super().__init__()

    def pairwise_dist(self, x, y):
        """Return squared Euclidean distances between every row of ``x`` and every row of ``y``.

        Args:
            x: 2-D tensor of shape ``(N, D)``.
            y: 2-D tensor of shape ``(M, D)``; ``N`` and ``M`` may differ.

        Returns:
            Tensor ``P`` of shape ``(N, M)`` where
            ``P[i, j] = |x_i|^2 + |y_j|^2 - 2 * <x_i, y_j>``.
            Because of floating-point cancellation, entries for (near-)identical
            points may come out slightly negative; they are not clamped.
        """
        cross = torch.mm(x, y.t())  # (N, M) inner products
        # Squared norms as a column (N, 1) and a row (1, M). Broadcasting them
        # against the (N, M) cross term supports N != M. The previous
        # expand_as-based version required square inputs.
        x_sq = x.pow(2).sum(dim=1, keepdim=True)
        y_sq = y.pow(2).sum(dim=1).unsqueeze(0)
        return x_sq + y_sq - 2 * cross

    def batch_pairwise_dist(self, a, b):
        """Batched version of :meth:`pairwise_dist`.

        Args:
            a: 3-D tensor of shape ``(B, N, D)``.
            b: 3-D tensor of shape ``(B, M, D)``; ``N`` and ``M`` may differ.

        Returns:
            Tensor of shape ``(B, N, M)`` holding squared Euclidean distances
            between every point of ``a[k]`` and every point of ``b[k]``.
            Entries may be slightly negative due to floating-point cancellation.
        """
        cross = torch.bmm(a, b.transpose(2, 1))  # (B, N, M) inner products
        # Computing the norms directly avoids building a CUDA-only index tensor.
        # The old code did that, so it failed on CPU. It also avoids the
        # (B, N, N) / (B, M, M) Gram matrices that were only used for their diagonals.
        a_sq = a.pow(2).sum(dim=2, keepdim=True)  # (B, N, 1)
        b_sq = b.pow(2).sum(dim=2).unsqueeze(1)   # (B, 1, M)
        return a_sq + b_sq - 2 * cross

    def forward(self, input, target):
        """Compute the per-sample Chamfer distance.

        Args:
            input: tensor of shape ``(B, N, D)``.
            target: tensor of shape ``(B, M, D)``.

        Returns:
            Tensor of shape ``(B,)``. Each entry is the sum, over target points,
            of the squared distance to the nearest input point, plus the sum,
            over input points, of the squared distance to the nearest target point.
        """
        dist = self.batch_pairwise_dist(input, target)  # (B, N, M)
        # Nearest input point for each target point -> (B, M).
        nearest_for_target = dist.min(dim=1).values
        # Nearest target point for each input point -> (B, N).
        nearest_for_input = dist.min(dim=2).values
        return nearest_for_target.sum(dim=1) + nearest_for_input.sum(dim=1)