forked from torch-points3d/torch-points3d
-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathmock.py
74 lines (59 loc) · 2.33 KB
/
mock.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from torch_geometric.data import Data, Batch
from torch_points3d.datasets.batch import SimpleBatch
from torch_points3d.core.data_transform import MultiScaleTransform
from torch_points3d.datasets.multiscale_data import MultiScaleBatch
class MockDatasetConfig:
    """Stand-in for a dataset configuration that holds no entries."""

    def __init__(self):
        """Create an empty mock configuration (no state is stored)."""

    def keys(self):
        """Return a new empty list: this configuration defines no keys."""
        return list()

    def get(self, dataset_name, default):
        """Look up ``dataset_name``.

        Always returns None; note that ``default`` is ignored, unlike
        ``dict.get``.
        """
        return None
class MockDataset(torch.utils.data.Dataset):
    """Small synthetic point-cloud dataset used by tests.

    Every item is the same collated batch of ``batch_size`` random clouds,
    each holding ``num_points`` points. The global torch RNG is re-seeded
    on every access, so the positions are identical from call to call.
    """

    def __init__(self, feature_size=0, transform=None, num_points=100):
        """
        Args:
            feature_size: number of per-point feature channels; when it is
                not positive, clouds carry no features (``x`` is None).
            transform: optional callable applied to a clone of each cloud.
            num_points: points per cloud; also reported as the dataset length.
        """
        self.feature_dimension = feature_size
        self.num_classes = 10
        self.num_points = num_points
        self.batch_size = 2
        self.weight_classes = None

        # Every point gets the same feature row [0, 1, ..., feature_size - 1].
        if feature_size > 0:
            feature_row = list(range(feature_size))
            self._feature = torch.tensor([feature_row] * self.num_points, dtype=torch.float,)
        else:
            self._feature = None
        # All labels are 0; all categories are 1.
        self._y = torch.tensor([0] * self.num_points, dtype=torch.long)
        self._category = torch.ones((self.num_points,), dtype=torch.long)

        # Multi-scale transform stays unset until set_strategies() is called.
        self._ms_transform = None
        self._transform = transform

    def __len__(self):
        return self.num_points

    def len(self):
        """Method-style alias for ``len(self)``."""
        return len(self)

    @property
    def datalist(self):
        """Build ``batch_size`` point clouds, applying any configured transforms.

        Side effect: calls ``torch.manual_seed(0)``, resetting torch's global
        random state on every access.
        """
        torch.manual_seed(0)
        # This draw is thrown away; it only advances the RNG before positions
        # are sampled. NOTE(review): presumably kept so generated positions
        # stay unchanged for existing tests -- confirm before removing.
        torch.randn((self.num_points, 3))

        clouds = []
        for _ in range(self.batch_size):
            positions = torch.randn((self.num_points, 3))
            clouds.append(
                Data(pos=positions, x=self._feature, y=self._y, category=self._category)
            )

        # Transforms operate on clones so the shared feature/label tensors
        # referenced by every cloud are not modified in place.
        if self._transform:
            clouds = [self._transform(cloud.clone()) for cloud in clouds]
        if self._ms_transform:
            clouds = [self._ms_transform(cloud.clone()) for cloud in clouds]
        return clouds

    def __getitem__(self, index):
        """Return the whole collated batch; ``index`` is ignored."""
        clouds = self.datalist
        return SimpleBatch.from_data_list(clouds)

    @property
    def class_to_segments(self):
        """Map each of the two mock classes to its segment label ids."""
        return {"class1": list(range(6)), "class2": list(range(6, 10))}

    def set_strategies(self, model):
        """Install a multi-scale transform built from ``model.get_spatial_ops()``."""
        self._ms_transform = MultiScaleTransform(model.get_spatial_ops())
class MockDatasetGeometric(MockDataset):
    """MockDataset variant that collates with torch_geometric-style batches."""

    def __getitem__(self, index):
        """Collate the generated clouds into one batch; ``index`` is ignored.

        Uses ``MultiScaleBatch`` when a multi-scale transform has been set
        (via ``set_strategies``), and plain ``Batch`` otherwise.
        """
        collate_cls = MultiScaleBatch if self._ms_transform else Batch
        return collate_cls.from_data_list(self.datalist)