$ ls -ltrh
total 2.4M
-rwxrwxrwx 1 xt xt 2.1M Mar  2  2023 train.csv
-rwxrwxrwx 1 xt xt 332K Mar  2  2023 test.csv
drwxrwxrwx 1 xt xt 4.0K Jan 11 15:21 test
drwxrwxrwx 1 xt xt 4.0K Jan 11 15:24 train

$ less test.csv
datasets\MNIST\test\0\1.jpeg,0
datasets\MNIST\test\0\10.jpeg,0
datasets\MNIST\test\0\100.jpeg,0
datasets\MNIST\test\0\101.jpeg,0
datasets\MNIST\test\0\102.jpeg,0
datasets\MNIST\test\0\103.jpeg,0
datasets\MNIST\test\0\104.jpeg,0
datasets\MNIST\test\0\105.jpeg,0
datasets\MNIST\test\0\106.jpeg,0
datasets\MNIST\test\0\107.jpeg,0
datasets\MNIST\test\0\108.jpeg,0
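Each line of the csv is a relative image path followed by its digit label. The index files here were produced beforehand; a minimal sketch of how such a csv could be generated from the directory layout shown above (hypothetical, standard library only, not the script actually used) would be:

import csv
import os

# Hypothetical sketch: build an index csv (path,label) from a layout like
# datasets/MNIST/test/<label>/<n>.jpeg. The original csv uses Windows "\"
# separators; ImgSet below normalizes them to "/" when loading.
root = os.path.join("datasets", "MNIST", "test")
with open("test.csv", "w", newline="") as f:
    writer = csv.writer(f)
    for label in sorted(os.listdir(root)):
        label_dir = os.path.join(root, label)
        for name in sorted(os.listdir(label_dir)):
            writer.writerow([os.path.join(label_dir, name), label])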
Because there are too many individual image files, convert the dataset to pkl storage.
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
from tpf.params import ImgPath


class ImgSet(Dataset):
    """Handwritten digit dataset (reads the jpeg files listed in a csv)
    """
    def __init__(self, fil=None, pfil=ImgPath.mnist_run, img_size=(32, 32)):
        """Hyperparameter initialization
        pfil: parent directory of all dataset directories
        """
        with open(file=fil, mode="r", encoding="utf8") as f:
            images = [line.strip().split(",") for line in f.readlines()]
        for img in images:
            # The csv uses Windows separators; convert them so the paths work on Linux
            img[0] = os.path.join(pfil, img[0].replace("\\", "/"))
        self.images = images

        # Image-to-Tensor transform
        self.trans = transforms.Compose([
            transforms.Grayscale(),
            transforms.Resize(size=img_size),
            transforms.ToTensor()
        ])

    def __getitem__(self, idx):
        """Read one image and return it with its label
        """
        # Paths were already joined with pfil in __init__, so they can be opened directly
        img_path, label = self.images[idx]
        img = Image.open(fp=img_path)
        return self.trans(img), torch.tensor(data=int(label)).long()

    def __len__(self):
        """Dataset size
        """
        return len(self.images)


train_csv = os.path.join(ImgPath.mnist_run, "train.csv")
train_dataset = ImgSet(fil=train_csv, img_size=(32, 32), pfil=ImgPath.mnist_run)
print(len(train_dataset))  # 60000

test_csv = os.path.join(ImgPath.mnist_run, "test.csv")
test_dataset = ImgSet(fil=test_csv, img_size=(32, 32), pfil=ImgPath.mnist_run)
print(len(test_dataset))  # 10000


from tpf.data.d1 import pkl_save

def save_dataset(dataset, save_file):
    img_list = []
    for x, y in dataset:
        img_list.append((x, y))
    pkl_save(data=img_list, file_path=save_file)

# 267M on disk
train_pkl = os.path.join(ImgPath.mnist, "img0-9_train.pkl")
save_dataset(train_dataset, save_file=train_pkl)

# 45M on disk
test_pkl = os.path.join(ImgPath.mnist, "img0-9_test.pkl")
save_dataset(test_dataset, save_file=test_pkl)


from tpf.data.d1 import pkl_load

# Data loading test
class MyDataset(Dataset):
    """Handwritten digit dataset (reads pre-converted tensors from a pkl file)
    """
    def __init__(self, fil):
        """Hyperparameter initialization
        """
        self.img_list = pkl_load(file_path=fil)

    def __getitem__(self, idx):
        """Return the image tensor and its label
        """
        _img, _label = self.img_list[idx]
        return _img, _label

    def __len__(self):
        """Dataset size
        """
        return len(self.img_list)


train_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_train.pkl"))
print(len(train_dataset))  # 60000

test_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_test.pkl"))
print(len(test_dataset))  # 10000
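pkl_save and pkl_load come from the author's tpf package. Assuming they are thin wrappers around the standard pickle module (an assumption, not the actual tpf implementation), equivalent helpers would look roughly like this:

import pickle

# Hedged sketch of what tpf.data.d1.pkl_save / pkl_load presumably do:
# serialize an arbitrary Python object to disk and read it back.
def pkl_save(data, file_path):
    with open(file_path, "wb") as f:
        pickle.dump(data, f)

def pkl_load(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)

torch.save / torch.load would serve the same purpose here for a list of tensors.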
Load the pkl
import os
from tpf.data.d1 import pkl_load
from tpf.params import ImgPath
from torch.utils.data import DataLoader, Dataset

# Data loading test
class MyDataset(Dataset):
    """Handwritten digit dataset (backed by the pkl file)
    """
    def __init__(self, fil):
        """Hyperparameter initialization
        """
        self.img_list = pkl_load(file_path=fil)

    def __getitem__(self, idx):
        """Return the image tensor and its label
        """
        _img, _label = self.img_list[idx]
        return _img, _label

    def __len__(self):
        """Dataset size
        """
        return len(self.img_list)


train_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_train.pkl"))
print(len(train_dataset))  # 60000

test_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_test.pkl"))
print(len(test_dataset))  # 10000
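The point of the pkl conversion is to avoid opening and decoding tens of thousands of jpeg files on every run. A rough timing comparison (a sketch, assuming ImgSet and MyDataset from the previous block are both available in the same session; exact numbers depend on disk and CPU) could look like this:

import os
import time

from tpf.params import ImgPath

# Rough sketch: one full pass over the jpeg-backed dataset vs. one full pass
# over the pkl-backed dataset.
def time_one_pass(dataset):
    start = time.perf_counter()
    for _x, _y in dataset:
        pass
    return time.perf_counter() - start

jpeg_ds = ImgSet(fil=os.path.join(ImgPath.mnist_run, "train.csv"), pfil=ImgPath.mnist_run)
pkl_ds = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_train.pkl"))

print("jpeg:", time_one_pass(jpeg_ds))
print("pkl :", time_one_pass(pkl_ds))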
# Build the data loaders
# Note: drop_last=True on the test loader discards the final partial batch
# (10000 % 512 = 272 samples are never evaluated)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=512, shuffle=True, drop_last=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=512, shuffle=False, drop_last=True)

for X, y in test_dataloader:
    print(X.shape, y.shape)        # torch.Size([512, 1, 32, 32]) torch.Size([512])
    print(X[0], y[0])
    print(X[0].dtype, y[0].dtype)  # torch.float32 torch.int64
    break
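As a final sanity check that the batches are directly usable, here is a minimal sketch feeding one batch through a toy linear classifier (the model is illustrative only, not part of the original pipeline):

import torch
from torch import nn

# Toy classifier for 1x32x32 inputs, just to confirm shapes and dtypes line up.
model = nn.Sequential(
    nn.Flatten(),            # [512, 1, 32, 32] -> [512, 1024]
    nn.Linear(32 * 32, 10),  # 10 digit classes
)

X, y = next(iter(test_dataloader))
logits = model(X)
loss = nn.CrossEntropyLoss()(logits, y)
print(logits.shape, loss.item())  # torch.Size([512, 10]) and a scalar loss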