1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
class LeavesData(Dataset):
    """Leaves image-classification dataset backed by a CSV file.

    The CSV is read with ``header=None``, so its header line stays in the
    table as row 0 and the samples are rows 1..N. Column 0 of each sample row
    is an image path (joined onto ``file_path``); outside test mode, column 1
    is a class label that is mapped to an index via ``class_to_num``.
    In 'train'/'valid' mode the samples are split in file order: the first
    ``train_len`` samples form the training split, the rest the validation
    split.
    """

    # Accepted values for the ``mode`` argument.
    _MODES = ('train', 'valid', 'test')

    def __init__(self, csv_path, file_path, mode='train', valid_ratio=0.2,
                 resize_height=256, resize_width=256):
        """
        Args:
            csv_path (string): path to the CSV file.
            file_path (string): directory containing the image files; each
                image path read from the CSV is joined onto it.
            mode (string): which split to load: 'train', 'valid' or 'test'.
            valid_ratio (float): fraction of samples held out for validation
                (has no effect in 'test' mode).
            resize_height (int): stored as ``self.resize_height``; currently
                NOT used by the transforms, which always resize to 224x224.
            resize_width (int): stored as ``self.resize_width``; same caveat.

        Raises:
            ValueError: if ``mode`` is not one of 'train', 'valid', 'test'.
        """
        # Fail fast: an unknown mode previously left image_arr unset and only
        # surfaced later as an AttributeError.
        if mode not in self._MODES:
            raise ValueError(
                'mode must be one of {}, got {!r}'.format(self._MODES, mode))

        # NOTE(review): kept for interface compatibility; the transforms below
        # hard-code 224x224, so these values have no effect on the output.
        self.resize_height = resize_height
        self.resize_width = resize_width
        self.file_path = file_path
        self.mode = mode

        # header=None keeps the CSV header line as row 0, so the number of
        # samples is one less than the number of table rows.
        self.data_info = pd.read_csv(csv_path, header=None)
        self.data_len = len(self.data_info.index) - 1
        self.train_len = int(self.data_len * (1 - valid_ratio))

        rows = self._row_slice(mode, self.train_len)
        self.image_arr = np.asarray(self.data_info.iloc[rows, 0])
        # Keep the legacy per-split attribute names (train_image, valid_image,
        # test_image, train_label, valid_label) for existing callers.
        setattr(self, '{}_image'.format(mode), self.image_arr)
        if mode != 'test':
            self.label_arr = np.asarray(self.data_info.iloc[rows, 1])
            setattr(self, '{}_label'.format(mode), self.label_arr)

        # The pipeline is the same for every sample, so build it once here
        # instead of on every __getitem__ call.
        self.transform = self._build_transform(mode)

        self.real_len = len(self.image_arr)
        print('Finished reading the {} set of Leaves Dataset ({} samples found)'
              .format(mode, self.real_len))

    @staticmethod
    def _row_slice(mode, train_len):
        """Return the slice of raw table rows (header at row 0) used by ``mode``.

        Sample rows start at index 1, so the training split is rows
        1..train_len inclusive and the validation split is everything after.
        """
        if mode == 'train':
            return slice(1, train_len + 1)
        if mode == 'valid':
            return slice(train_len + 1, None)
        return slice(1, None)  # 'test': every sample row

    @staticmethod
    def _build_transform(mode):
        """Return the image pipeline for ``mode``; random flips only when training."""
        steps = [transforms.Resize((224, 224))]
        if mode == 'train':
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    def __getitem__(self, index):
        """Return ``(image, class_index)``, or only ``image`` in test mode.

        ``image`` is the output of the mode's transform pipeline (ending in
        ``ToTensor``).
        """
        import os  # local import: the file's import block is not visible here

        image_path = os.path.join(self.file_path, self.image_arr[index])
        # Context manager closes the underlying file as soon as the pixels
        # have been consumed by the transforms.
        with Image.open(image_path) as img:
            img_tensor = self.transform(img)

        if self.mode == 'test':
            return img_tensor
        # class_to_num is defined outside this class (not visible here);
        # presumably a label -> integer index mapping.
        return img_tensor, class_to_num[self.label_arr[index]]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.real_len
|