import os
from abc import ABC, abstractmethod
import logging
import functools
from functools import partial
import numpy as np
import torch
import torch_geometric
from torch_geometric.transforms import Compose, FixedPoints
import copy

from torch_points3d.models import model_interface
from torch_points3d.core.data_transform import instantiate_transforms, MultiScaleTransform
from torch_points3d.core.data_transform import instantiate_filters
from torch_points3d.datasets.batch import SimpleBatch
from torch_points3d.datasets.multiscale_data import MultiScaleBatch
from torch_points3d.utils.enums import ConvolutionFormat
from torch_points3d.utils.config import ConvolutionFormatFactory
from torch_points3d.utils.colors import COLORS, colored_print

# A logger for this file
log = logging.getLogger(__name__)


def explode_transform(transforms):
    """Returns a flattened list of transforms

    Arguments:
        transforms {list | T.Compose} -- transforms to flatten

    Returns:
        list -- list of transforms
    """
    out = []
    if transforms is not None:
        if isinstance(transforms, Compose):
            out = copy.deepcopy(transforms.transforms)
        elif isinstance(transforms, list):
            out = copy.deepcopy(transforms)
        else:
            raise Exception("Transforms should be provided either within a list or a Compose")
    return out


def save_used_properties(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Save used_properties for mocking the dataset when calling the pretrained registry
        result = func(self, *args, **kwargs)
        if isinstance(result, torch.Tensor):
            self.used_properties[func.__name__] = result.numpy().tolist()
        elif isinstance(result, np.ndarray):
            self.used_properties[func.__name__] = result.tolist()
        else:
            self.used_properties[func.__name__] = result
        return result

    return wrapper


class BaseDataset:
    def __init__(self, dataset_opt):
        self.dataset_opt = dataset_opt

        # Default dataset path
        dataset_name = dataset_opt.get("dataset_name", None)
        if dataset_name:
            self._data_path = os.path.join(dataset_opt.dataroot, dataset_name)
        else:
            class_name = self.__class__.__name__.lower().replace("dataset", "")
            self._data_path = os.path.join(dataset_opt.dataroot, class_name)
        self._batch_size = None
        self.strategies = {}
        self._contains_dataset_name = False

        self.train_sampler = None
        self.test_sampler = None
        self.val_sampler = None

        self._train_dataset = None
        self._test_dataset = None
        self._val_dataset = None

        self.pre_batch_collate_transform = None
        BaseDataset.set_transform(self, dataset_opt)
        self.set_filter(dataset_opt)

        self.used_properties = {}

    @staticmethod
    def remove_transform(transform_in, list_transform_class):
        """Removes any transform that is an instance of one of the classes
        in list_transform_class

        Arguments:
            transform_in {Compose | list} -- transforms to filter
            list_transform_class {list} -- transform classes to be removed

        Returns:
            Compose | list -- the filtered transforms
        """
        transform_out = transform_in  # Fallback: nothing to filter or unsupported container
        if isinstance(transform_in, (Compose, list)) and len(list_transform_class) > 0:
            transforms = transform_in.transforms if isinstance(transform_in, Compose) else transform_in
            transform_out = Compose([t for t in transforms if not isinstance(t, tuple(list_transform_class))])
        return transform_out
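
    # A minimal usage sketch for remove_transform (SomeNormalize is a hypothetical
    # transform; FixedPoints is imported above):
    #
    #   transform = Compose([FixedPoints(2048), SomeNormalize()])
    #   cleaned = BaseDataset.remove_transform(transform, [FixedPoints])
    #   # -> Compose containing only SomeNormalize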

    @staticmethod
    def set_transform(obj, dataset_opt):
        """Creates the transforms found in dataset_opt and sets them as attributes of obj"""
        obj.pre_transform = None
        obj.test_transform = None
        obj.train_transform = None
        obj.val_transform = None
        obj.inference_transform = None
        obj.pre_batch_collate_transform = None

        for key_name in dataset_opt.keys():
            if "transform" in key_name:
                new_name = key_name.replace("transforms", "transform")
                try:
                    transform = instantiate_transforms(getattr(dataset_opt, key_name))
                except Exception:
                    log.exception("Error trying to create {}, {}".format(new_name, getattr(dataset_opt, key_name)))
                    continue
                setattr(obj, new_name, transform)

        inference_transform = explode_transform(obj.pre_transform)
        inference_transform += explode_transform(obj.test_transform)
        obj.inference_transform = Compose(inference_transform) if len(inference_transform) > 0 else None

    def set_filter(self, dataset_opt):
        """Creates the filters found in dataset_opt and sets them as attributes"""
        self.pre_filter = None
        for key_name in dataset_opt.keys():
            if "filter" in key_name:
                new_name = key_name.replace("filters", "filter")
                try:
                    filt = instantiate_filters(getattr(dataset_opt, key_name))
                except Exception:
                    log.exception("Error trying to create {}, {}".format(new_name, getattr(dataset_opt, key_name)))
                    continue
                setattr(self, new_name, filt)

    @staticmethod
    def _collate_fn(batch, collate_fn=None, pre_collate_transform=None):
        if pre_collate_transform:
            batch = pre_collate_transform(batch)
        return collate_fn(batch)

    @staticmethod
    def _get_collate_function(conv_type, is_multiscale, pre_collate_transform=None):
        is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)
        if is_multiscale:
            if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower():
                fn = MultiScaleBatch.from_data_list
            else:
                raise NotImplementedError(
                    "MultiscaleTransform is activated and supported only for partial_dense format"
                )
        else:
            if is_dense:
                fn = SimpleBatch.from_data_list
            else:
                fn = torch_geometric.data.batch.Batch.from_data_list
        return partial(BaseDataset._collate_fn, collate_fn=fn, pre_collate_transform=pre_collate_transform)

    @staticmethod
    def get_num_samples(batch, conv_type):
        is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)
        if is_dense:
            return batch.pos.shape[0]
        else:
            return batch.batch.max() + 1

    @staticmethod
    def get_sample(batch, key, index, conv_type):
        assert hasattr(batch, key)
        is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type)
        if is_dense:
            return batch[key][index]
        else:
            return batch[key][batch.batch == index]
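
    # Sketch of how batching is resolved (data_0 / data_1 stand for hypothetical
    # torch_geometric Data objects):
    #
    #   collate = BaseDataset._get_collate_function("dense", is_multiscale=False)
    #   batch = collate([data_0, data_1])  # -> SimpleBatch, stacking fixed-size clouds
    #
    # With conv_type="partial_dense" and is_multiscale=True, the returned collate
    # function builds MultiScaleBatch objects instead.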

    def create_dataloaders(
        self,
        model: model_interface.DatasetInterface,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
        precompute_multi_scale: bool,
    ):
        """Creates the data loaders. Must be called in order to complete the setup of the Dataset"""
        conv_type = model.conv_type
        self._batch_size = batch_size
        batch_collate_function = self.__class__._get_collate_function(
            conv_type, precompute_multi_scale, self.pre_batch_collate_transform
        )
        dataloader = partial(
            torch.utils.data.DataLoader, collate_fn=batch_collate_function, worker_init_fn=np.random.seed
        )

        if self.train_sampler:
            log.info(self.train_sampler)

        if self.train_dataset:
            self._train_loader = dataloader(
                self.train_dataset,
                batch_size=batch_size,
                shuffle=shuffle and not self.train_sampler,
                num_workers=num_workers,
                sampler=self.train_sampler,
            )

        if self.test_dataset:
            self._test_loaders = [
                dataloader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=num_workers,
                    sampler=self.test_sampler,
                )
                for dataset in self.test_dataset
            ]

        if self.val_dataset:
            self._val_loader = dataloader(
                self.val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                sampler=self.val_sampler,
            )

        if precompute_multi_scale:
            self.set_strategies(model)

    @property
    def has_train_loader(self):
        return hasattr(self, "_train_loader")

    @property
    def has_val_loader(self):
        return hasattr(self, "_val_loader")

    @property
    def has_test_loaders(self):
        return hasattr(self, "_test_loaders")
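
    # Typical setup sequence (the model is any object exposing conv_type and
    # get_spatial_ops; values are illustrative):
    #
    #   dataset.create_dataloaders(model, batch_size=8, shuffle=True,
    #                              num_workers=4, precompute_multi_scale=False)
    #   for batch in dataset.train_dataloader:
    #       ...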

    @property
    def train_dataset(self):
        return self._train_dataset

    @train_dataset.setter
    def train_dataset(self, value):
        self._train_dataset = value
        if not hasattr(self._train_dataset, "name"):
            setattr(self._train_dataset, "name", "train")

    @property
    def val_dataset(self):
        return self._val_dataset

    @val_dataset.setter
    def val_dataset(self, value):
        self._val_dataset = value
        if not hasattr(self._val_dataset, "name"):
            setattr(self._val_dataset, "name", "val")

    @property
    def test_dataset(self):
        return self._test_dataset

    @test_dataset.setter
    def test_dataset(self, value):
        if isinstance(value, list):
            self._test_dataset = value
        else:
            self._test_dataset = [value]

        for i, dataset in enumerate(self._test_dataset):
            if not hasattr(dataset, "name"):
                if self.num_test_datasets > 1:
                    setattr(dataset, "name", "test_%i" % i)
                else:
                    setattr(dataset, "name", "test")
            else:
                self._contains_dataset_name = True

        # Check for uniqueness
        all_names = [d.name for d in self.test_dataset]
        if len(set(all_names)) != len(all_names):
            raise ValueError("Datasets need to have unique names. Current names are {}".format(all_names))

    @property
    def train_dataloader(self):
        return self._train_loader

    @property
    def val_dataloader(self):
        return self._val_loader

    @property
    def test_dataloaders(self):
        if self.has_test_loaders:
            return self._test_loaders
        else:
            return []

    @property
    def _loaders(self):
        loaders = []
        if self.has_train_loader:
            loaders += [self.train_dataloader]
        if self.has_val_loader:
            loaders += [self.val_dataloader]
        if self.has_test_loaders:
            loaders += self.test_dataloaders
        return loaders

    @property
    def num_test_datasets(self):
        return len(self._test_dataset) if self._test_dataset else 0

    @property
    def _test_datatset_names(self):
        if self.test_dataset:
            return [d.name for d in self.test_dataset]
        else:
            return []

    @property
    def available_stage_names(self):
        out = self._test_datatset_names
        if self.has_val_loader:
            out += [self._val_dataset.name]
        return out

    @property
    def available_dataset_names(self):
        return ["train"] + self.available_stage_names

    def get_raw_data(self, stage, idx, **kwargs):
        assert stage in self.available_dataset_names
        dataset = self.get_dataset(stage)
        if hasattr(dataset, "get_raw_data"):
            return dataset.get_raw_data(idx, **kwargs)
        else:
            raise Exception("Dataset {} doesn't have a get_raw_data function implemented".format(dataset))

    def has_labels(self, stage: str) -> bool:
        """Tests if a given dataset has labels or not

        Parameters
        ----------
        stage : str
            name of the dataset to test
        """
        assert stage in self.available_dataset_names
        dataset = self.get_dataset(stage)
        if hasattr(dataset, "has_labels"):
            return dataset.has_labels

        sample = dataset[0]
        if hasattr(sample, "y"):
            return sample.y is not None
        return False

    @property  # type: ignore
    @save_used_properties
    def is_hierarchical(self):
        """Used by the metric trackers to log hierarchical metrics"""
        return False

    @property  # type: ignore
    @save_used_properties
    def class_to_segments(self):
        """Use this property to return the hierarchical map between classes and segment ids, example:
        {
            'Airplane': [0, 1, 2],
            'Boat': [3, 4, 5],
        }
        """
        return None

    @property  # type: ignore
    @save_used_properties
    def num_classes(self):
        return self.train_dataset.num_classes

    @property
    def weight_classes(self):
        return getattr(self.train_dataset, "weight_classes", None)

    @property  # type: ignore
    @save_used_properties
    def feature_dimension(self):
        if self.train_dataset:
            return self.train_dataset.num_features
        elif self.test_dataset is not None:
            if isinstance(self.test_dataset, list):
                return self.test_dataset[0].num_features
            else:
                return self.test_dataset.num_features
        elif self.val_dataset is not None:
            return self.val_dataset.num_features
        else:
            raise NotImplementedError()

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def num_batches(self):
        out = {
            self.train_dataset.name: len(self._train_loader),
            "val": len(self._val_loader) if self.has_val_loader else 0,
        }
        if self.test_dataset:
            for loader in self._test_loaders:
                stage_name = loader.dataset.name
                out[stage_name] = len(loader)
        return out

    def get_dataset(self, name):
        """Get a dataset by name. Raises an exception if no dataset was found

        Parameters
        ----------
        name : str
        """
        all_datasets = [self.train_dataset, self.val_dataset]
        if self.test_dataset:
            all_datasets += self.test_dataset
        for dataset in all_datasets:
            if dataset is not None and dataset.name == name:
                return dataset
        raise ValueError("No dataset with name %s was found." % name)
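
    # Stage names follow the defaults assigned by the setters above, e.g.:
    #
    #   dataset.get_dataset("val")
    #   dataset.get_dataset("test_0")  # when several test splits are registered
    #   dataset.has_labels("test_0")   # -> False for unlabeled benchmark splits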

    def _set_composed_multiscale_transform(self, attr, transform):
        current_transform = getattr(attr.dataset, "transform", None)
        if current_transform is None:
            setattr(attr.dataset, "transform", transform)
        else:
            if (
                isinstance(current_transform, Compose) and transform not in current_transform.transforms
            ):  # The transform contains several transformations
                current_transform.transforms += [transform]
            elif current_transform != transform:
                setattr(
                    attr.dataset,
                    "transform",
                    Compose([current_transform, transform]),
                )

    def _set_multiscale_transform(self, transform):
        for _, attr in self.__dict__.items():
            if isinstance(attr, torch.utils.data.DataLoader):
                self._set_composed_multiscale_transform(attr, transform)
        for loader in self.test_dataloaders:
            self._set_composed_multiscale_transform(loader, transform)

    def set_strategies(self, model):
        strategies = model.get_spatial_ops()
        transform = MultiScaleTransform(strategies)
        self._set_multiscale_transform(transform)

    @abstractmethod
    def get_tracker(self, wandb_log: bool, tensorboard_log: bool):
        pass

    def resolve_saving_stage(self, selection_stage):
        """Determines whether the best model selection is performed on the
        validation dataset or on a test dataset
        """
        log.info(
            "Available stage selection datasets: {} {} {}".format(
                COLORS.IPurple, self.available_stage_names, COLORS.END_NO_TOKEN
            )
        )
        if self.num_test_datasets > 1 and not self._contains_dataset_name:
            msg = "If you want to have better trackable names for your test datasets, add a "
            msg += COLORS.IPurple + "name" + COLORS.END_NO_TOKEN
            msg += " attribute to them"
            log.info(msg)

        if selection_stage == "":
            if self.has_val_loader:
                selection_stage = self.val_dataset.name
            else:
                selection_stage = self.test_dataset[0].name
        log.info(
            "The models will be selected using the metrics on the following dataset: {} {} {}".format(
                COLORS.IPurple, selection_stage, COLORS.END_NO_TOKEN
            )
        )
        return selection_stage

    def add_weights(self, dataset_name="train", class_weight_method="sqrt"):
        """Adds class weights to a given dataset; they are then accessible
        through its `weight_classes` attribute
        """
        L = self.num_classes
        weights = torch.ones(L)
        dataset = self.get_dataset(dataset_name)
        idx_classes, counts = torch.unique(dataset.data.y, return_counts=True)
        dataset.idx_classes = torch.arange(L).long()
        weights[idx_classes] = counts.float()
        weights = weights.float()
        weights = weights.mean() / weights
        if class_weight_method == "sqrt":
            weights = torch.sqrt(weights)
        elif str(class_weight_method).startswith("log"):
            weights = torch.log(1.1 + weights / weights.sum())
        else:
            raise ValueError("Method %s not supported" % class_weight_method)

        weights /= torch.sum(weights)
        log.info("CLASS WEIGHT : {}".format([np.round(weight.item(), 4) for weight in weights]))
        setattr(dataset, "weight_classes", weights)
        return dataset
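
    # Hedged example of consuming the computed weights (the loss choice is
    # illustrative, not prescribed by this class):
    #
    #   train_set = dataset.add_weights(class_weight_method="sqrt")
    #   criterion = torch.nn.NLLLoss(weight=train_set.weight_classes)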
{}\n".format(COLORS.IPurple, key, COLORS.END_NO_TOKEN, attr) message += "{}Batch size ={} {}".format(COLORS.IPurple, COLORS.END_NO_TOKEN, self.batch_size) return message