# coding=utf-8 # dataset object definition import cv2 import os import logging import math import numpy as np from nn import resizeImage # dataset object need to implement the following function # get_sample(self, idx) # collect_batch(self, datalist) class ImageDataset(object): def __init__(self, cfg, split, imglst, annotations=None): """ imglst: a file containing a list of absolute path to all the images """ self.cfg = cfg # this should include short_edge_size, max_size, etc. self.split = split self.imglst = imglst self.annotations = annotations # machine-specific config self.num_gpu = cfg.gpu self.batch_size = cfg.im_batch_size self.batch_size_per_gpu = self.batch_size // cfg.gpu assert self.batch_size % cfg.gpu == 0, "bruh" if self.split == "train": self.num_epochs = cfg.num_epochs else: self.num_epochs = 1 # load the img file list self.imgs = [line.strip() for line in open(self.imglst).readlines()] self.num_samples = len(self.imgs) # one epoch length self.num_batches_per_epoch = int( math.ceil(self.num_samples / float(self.batch_size))) self.num_batches = int(self.num_batches_per_epoch * self.num_epochs) self.valid_idxs = range(self.num_samples) logging.info("Loaded %s imgs", len(self.imgs)) def get_sample(self, idx): """ preprocess one sample from the list """ cfg = self.cfg img_file_path = self.imgs[idx] imgname = os.path.splitext(os.path.basename(img_file_path))[0] frame = cv2.imread(img_file_path) im = frame.astype("float32") resized_image = resizeImage(im, cfg.short_edge_size, cfg.max_size) scale = (resized_image.shape[0] * 1.0 / im.shape[0] + \ resized_image.shape[1] * 1.0 / im.shape[1]) / 2.0 return resized_image, scale, imgname, (im.shape[0], im.shape[1]) def collect_batch(self, data, idxs=None): """ collect the idxs of the data list into a dictionary """ if idxs is None: idxs = range(len(data)) imgs, scales, imgnames, shapes = zip(*[data[idx] for idx in idxs]) return { "imgs": imgs, "scales": scales, "imgnames": imgnames, "ori_shapes": shapes }