MNIST

 
$ ls -ltrh
total 2.4M
-rwxrwxrwx 1 xt xt 2.1M Mar  2  2023 train.csv
-rwxrwxrwx 1 xt xt 332K Mar  2  2023 test.csv
drwxrwxrwx 1 xt xt 4.0K Jan 11 15:21 test
drwxrwxrwx 1 xt xt 4.0K Jan 11 15:24 train

$ less test.csv
datasets\MNIST\test\0\1.jpeg,0
datasets\MNIST\test\0\10.jpeg,0
datasets\MNIST\test\0\100.jpeg,0
datasets\MNIST\test\0\101.jpeg,0
datasets\MNIST\test\0\102.jpeg,0
datasets\MNIST\test\0\103.jpeg,0
datasets\MNIST\test\0\104.jpeg,0
datasets\MNIST\test\0\105.jpeg,0
datasets\MNIST\test\0\106.jpeg,0
datasets\MNIST\test\0\107.jpeg,0
datasets\MNIST\test\0\108.jpeg,0

minist转pkl

由于图像文件太多,转为pkl存储

 
import os
import torch
from torch.utils.data import Dataset,DataLoader
from PIL import Image
from torchvision import transforms
from tpf.params import ImgPath 

class ImgSet(Dataset):
    """Handwritten-digit (MNIST) image dataset backed by a CSV index file.

    The CSV lists one ``relative\\path,label`` pair per line, written with
    Windows-style path separators (see the sample ``test.csv`` above).
    """

    def __init__(self, fil=None, pfil=ImgPath.mnist_run, img_size=(32, 32)):
        """Initialize the dataset.

        Parameters
        ----------
        fil : str
            Path to the CSV index file (``image_path,label`` per line).
        pfil : str
            Parent directory of the dataset; relative image paths from the
            CSV are resolved against it.
        img_size : tuple[int, int]
            Target (height, width) each image is resized to.
        """
        with open(file=fil, mode="r", encoding="utf8") as f:
            images = [line.strip().split(",") for line in f.readlines()]

        for img in images:
            # The CSV was produced on Windows; convert the separators so the
            # paths also work on Linux, then anchor them at ``pfil``.
            img[0] = os.path.join(pfil, img[0].replace("\\", "/"))
        self.images = images

        # Pipeline turning a PIL image into a single-channel float tensor.
        self.trans = transforms.Compose([
                transforms.Grayscale(),
                transforms.Resize(size=img_size),
                transforms.ToTensor()
            ])

    def __getitem__(self, idx):
        """Load one sample and return ``(image_tensor, label_tensor)``.

        BUG FIX: the old code joined ``ImgPath.mnist_run`` onto the stored
        path again, although ``__init__`` has already anchored every path at
        ``pfil``.  That double-prefixed relative paths and silently ignored a
        non-default ``pfil``.  The stored path is now used as-is.
        """
        img_path, label = self.images[idx]
        img = Image.open(fp=img_path)
        return self.trans(img), torch.tensor(data=int(label)).long()

    def __len__(self):
        """Number of samples listed in the CSV index."""
        return len(self.images)

# Build the train/test datasets from their CSV index files.
train_csv = os.path.join(ImgPath.mnist_run, "train.csv")
train_dataset = ImgSet(fil=train_csv, pfil=ImgPath.mnist_run, img_size=(32, 32))
print(len(train_dataset))  # expected: 60000

test_csv = os.path.join(ImgPath.mnist_run, "test.csv")
test_dataset = ImgSet(fil=test_csv, pfil=ImgPath.mnist_run, img_size=(32, 32))
print(len(test_dataset))  # expected: 10000

from tpf.data.d1 import pkl_save
def save_dataset(dataset, save_file):
    """Materialize every ``(image, label)`` pair of *dataset* and pickle them.

    Iterating the dataset drives its ``__getitem__``, so each image is
    decoded and transformed exactly once before the list is written to
    ``save_file``.
    """
    samples = [(img, lbl) for img, lbl in dataset]
    pkl_save(data=samples, file_path=save_file)




# Serialize both datasets; file sizes noted from a previous run.
train_pkl = os.path.join(ImgPath.mnist, "img0-9_train.pkl")  # ~267M
save_dataset(train_dataset, save_file=train_pkl)

test_pkl = os.path.join(ImgPath.mnist, "img0-9_test.pkl")  # ~45M
save_dataset(test_dataset, save_file=test_pkl)


from tpf.data.d1 import pkl_load
# 数据加载测试
class MyDataset(Dataset):
    """Handwritten-digit dataset loaded from a pre-built pickle file."""

    def __init__(self, fil):
        """Load the full ``(image, label)`` list from the pickle file *fil*."""
        self.img_list = pkl_load(file_path=fil)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position *idx*."""
        img, label = self.img_list[idx]
        return img, label

    def __len__(self):
        """Number of samples held in memory."""
        return len(self.img_list)


# Smoke-test: reload both pickles and check the expected sample counts.
train_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_train.pkl"))
print(len(train_dataset))  # expected: 60000

test_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_test.pkl"))
print(len(test_dataset))  # expected: 10000

MNIST pkl 使用

加载pkl

 
import os 
from tpf.data.d1 import pkl_load
from tpf.params import ImgPath 
from torch.utils.data import DataLoader,Dataset 


# 数据加载测试
class MyDataset(Dataset):
    """Handwritten-digit dataset served from an in-memory pickled list."""

    def __init__(self, fil):
        """Deserialize the ``(image, label)`` list stored in *fil*."""
        self.img_list = pkl_load(file_path=fil)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at index *idx*."""
        sample, target = self.img_list[idx]
        return sample, target

    def __len__(self):
        """Total number of cached samples."""
        return len(self.img_list)


# Reload the pickled datasets and verify their sizes.
train_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_train.pkl"))
print(len(train_dataset))  # expected: 60000

test_dataset = MyDataset(fil=os.path.join(ImgPath.mnist, "img0-9_test.pkl"))
print(len(test_dataset))  # expected: 10000

# Wrap the datasets in loaders; incomplete final batches are dropped so
# every batch has exactly 512 samples.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=512, shuffle=True, drop_last=True)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=512, shuffle=False, drop_last=True)

# Inspect a single batch to confirm shapes and dtypes.
for X, y in test_dataloader:
    print(X.shape, y.shape)  # torch.Size([512, 1, 32, 32]) torch.Size([512])
    print(X[0], y[0])
    print(X[0].dtype, y[0].dtype)  # torch.float32 torch.int64
    break
    

 

    

 

 

 

参考