# class ChamferDistance(nn.Module): def __init__(self): super().__init__(); pass

 ``` 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35``` ```class ChamferDistance(nn.Module): def __init__(self): super().__init__() pass def pairwise_dist(self, x, y): xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) rx = (xx.diag().unsqueeze(0).expand_as(xx)) ry = (yy.diag().unsqueeze(0).expand_as(yy)) P = (rx.t() + ry - 2 * zz) return P def batch_pairwise_dist(self, a, b): x, y = a, b bs, num_points, points_dim = x.size() xx = torch.bmm(x, x.transpose(2, 1)) yy = torch.bmm(y, y.transpose(2, 1)) zz = torch.bmm(x, y.transpose(2, 1)) diag_ind = torch.arange(0, num_points).type(torch.cuda.LongTensor) rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) P = (rx.transpose(2, 1) + ry - 2 * zz) return P def forward(self, input, target): dist = self.batch_pairwise_dist(input, target) # dist = chamfer_distance_with_batch(input, target) # return dist values_1, indices = dist.min(dim=1) values_2, indices = dist.min(dim=2) return torch.sum(values_1, dim=1) + torch.sum(values_2, dim=1) ```