import torch utils data as torch_data import numpy as np from os path

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch.utils.data as torch_data
import numpy as np
from os.path import join
from torchvision import transforms
from .transforms import RandomJitterTransform, RandomRotateTransform, ScaleTransform
# Text files listing the sample names of the ModelNet40 train / test splits;
# ModelNet.__init__ reads them with np.loadtxt(..., dtype=str).
# NOTE(review): both paths are absolute and machine-specific. Because they are
# absolute, os.path.join(root, <path>) in ModelNet.__init__ returns them
# unchanged, i.e. `root` has no effect on where the listings are read from.
MODEL_NET_40_TRAIN_LABELS = '/home/dvolkhonskiy/pcsr/data/modelnet40-normal_numpy/modelnet40_train.txt'
MODEL_NET_40_TEST_LABELS = '/home/dvolkhonskiy/pcsr/data/modelnet40-normal_numpy/modelnet40_test.txt'
class ModelNet(torch_data.Dataset):
    """ModelNet40 point-cloud dataset backed by one ``.npy`` file per sample.

    Sample names are read from the train or test listing file. A name has the
    form ``<class>_<suffix>`` (the class itself may contain underscores, e.g.
    ``night_stand_0001``) and its data is loaded from
    ``<root>/<class>/<name>.npy``. Only the first three columns of each array
    are used as the point cloud (presumably xyz coordinates; any further
    columns are ignored).

    For every sample a fixed random subset of ``n_points`` row indices is
    drawn once at construction time, so repeated fetches of the same item
    always select the same rows (before ``transform`` is applied).

    Args:
        root: Directory containing one sub-directory per class.
        mode: ``'train'`` selects the train listing; any other value selects
            the test listing.
        n_points: Number of points sampled (without replacement) per cloud.
        transform: Optional callable applied to the sampled ``(n_points, 3)``
            array before it is returned.
        n_source_points: Number of rows every stored cloud is assumed to have;
            indices are drawn from ``range(n_source_points)``. Defaults to
            10000, the value that used to be hard-coded.

    Raises:
        ValueError: If ``n_points`` exceeds ``n_source_points`` (raised by
            ``np.random.choice`` with ``replace=False``).
    """

    # Category name -> integer label returned by __getitem__.
    classes = {
        'airplane': 0, 'bathtub': 1, 'bed': 2, 'bench': 3,
        'bookshelf': 4, 'bottle': 5, 'bowl': 6, 'car': 7,
        'chair': 8, 'cone': 9, 'cup': 10, 'curtain': 11,
        'desk': 12, 'door': 13, 'dresser': 14, 'flower_pot': 15,
        'glass_box': 16, 'guitar': 17, 'keyboard': 18, 'lamp': 19,
        'laptop': 20, 'mantel': 21, 'monitor': 22, 'night_stand': 23,
        'person': 24, 'piano': 25, 'plant': 26, 'radio': 27,
        'range_hood': 28, 'sink': 29, 'sofa': 30, 'stairs': 31,
        'stool': 32, 'table': 33, 'tent': 34, 'toilet': 35,
        'tv_stand': 36, 'vase': 37, 'wardrobe': 38, 'xbox': 39
    }

    def __init__(self, root, mode, n_points=1024, transform=None,
                 n_source_points=10000):
        super().__init__()
        self.root = root
        self.n_points = n_points
        self.transform = transform
        self.n_source_points = n_source_points

        if mode == 'train':
            listing = MODEL_NET_40_TRAIN_LABELS
        else:
            listing = MODEL_NET_40_TEST_LABELS
        # NOTE: when the listing constant is an absolute path, join() returns
        # it unchanged and `root` does not influence this lookup.
        # atleast_1d: loadtxt yields a 0-d array for a one-line listing, which
        # would break both len() and integer indexing.
        self.files = np.atleast_1d(np.loadtxt(join(root, listing), dtype=str))

        # One fixed index subset per sample, drawn up front in sample order.
        self.choice_idx = [
            np.random.choice(self.n_source_points, self.n_points, replace=False)
            for _ in range(len(self.files))
        ]

    def load_npy(self, f, idx):
        """Load sample file ``f`` (relative to ``root``) and subsample it.

        Args:
            f: Path of the ``.npy`` file relative to ``self.root``.
            idx: Dataset index, used to pick the precomputed row subset.

        Returns:
            The selected rows of the first three columns, passed through
            ``self.transform`` when one is set.
        """
        data = np.load(join(self.root, f))
        # Keep the first three columns only, then the rows chosen for `idx`.
        pc = data[:, :3][self.choice_idx[idx], :]
        if self.transform is not None:
            pc = self.transform(pc)
        return pc

    def __getitem__(self, idx):
        """Return ``(point_cloud, class_label)`` for sample ``idx``."""
        name = self.files[idx]
        # Everything before the last underscore is the class name.
        cls = '_'.join(name.split('_')[:-1])
        pc = self.load_npy(f'{cls}/{name}.npy', idx)
        return pc, self.classes[cls]

    def __len__(self):
        """Return the number of samples listed for the selected split."""
        return len(self.files)
def get_model_net_40(datadir, batch_size, n_points):
    """Build shuffled train and test DataLoaders over ModelNet40.

    Args:
        datadir: Passed to ``ModelNet`` as its ``root`` directory.
        batch_size: Batch size used by both loaders.
        n_points: Number of points sampled from each cloud.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # NOTE(review): the same scale/rotate/jitter pipeline and shuffle=True are
    # applied to the test split as well; confirm this is intended for
    # evaluation.
    augment = transforms.Compose([
        ScaleTransform(),
        RandomRotateTransform(),
        RandomJitterTransform(),
    ])

    # Build both datasets first (train, then test), then wrap each in a loader.
    datasets = [
        ModelNet(datadir, mode=split, n_points=n_points, transform=augment)
        for split in ('train', 'test')
    ]
    loaders = [
        torch_data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for dataset in datasets
    ]
    train_loader, test_loader = loaders
    return train_loader, test_loader