记录一下在学习图片分类时遇到的一些问题,还有一些函数、参数做一些笔记。
常用的库引入
from torchvision.datasets import ImageFolder from torchvision import transforms import matplotlib.pyplot as plt
数据集的结构
自己要建立一个图片数据集的话,假如我有一个名字为dog_cat的文件夹,那么这个文件夹就是root,也就是作为root参数传入到ImageFolder函数中。
一定要建立子文件夹,label是根据子文件夹的序号自动标记的(从0开始标记)。
比如下面的0dog中的所有图片的label都是0,而1cat中的所有图片的label都是1。
这时候我们可以定义一个路径path。
# 定义数据集根目录 path = "./data/dog_cat"
定义Transform方法
在加载图片的时候为了使得图片可以更好的处理,通常需要将图片标准化,这里需要用到torchversion库中的transforms.Compose方法来建立一个处理列表,也就是每一张图片都会经过这样的处理,转化成一个易于处理的Tensor(张量)。
tfs = transforms.Compose([ transforms.Resize((256, 256)), # 规定图形大小 transforms.ToTensor(), # 转化为张量数据 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), # 数据归中 ])
做一些解释:ToTensor会使得图片的每一个像素tuple归一化,也就是原本像素值从[0, 255]变为[0.0, 1.0],但是transforms.Normalize又会使得图片的像素值从[0.0, 1.0]变为一个均值为0的像素值,这样可以加速图片的处理,加速函数的收敛,获得更好的训练效果。
加载图片数据
train_data = ImageFolder(path, transform=tfs) # 加载图片
在加载完成后的train_data是一个由二元组组成的可迭代对象,每一个元组都是一个二元组(Tensor, label(int))。
打印图片数据
需要用到plt,所以需要引入库文件。
import matplotlib.pyplot as plt
如果直接用下面这几行代码来打印图片的话,会出错
for img in train_data: plt_img = img[0] plt.imshow(plt_img) plt.show() # 报错

注意此时Tensor的格式为(3, 256, 256),其中第一个表示RGB,这样的格式是不可以直接给plt显示的,需要通过tensor的swapaxes函数来进行维度转化后方可显示。
所以打印图片数据的方法为:
for img in train_data: print('图片的label:', img[1]) plt_img = img[0] plt_img = plt_img.swapaxes(0, 1) plt_img = plt_img.swapaxes(1, 2) plt.imshow(plt_img) plt.show()
这样就可以正常的打印图片数据了,但是仍然会报错,因为这时候图片的像素值区间为一个正负对称的区间,并非plt支持的[0.0, 1.0],所以会报错。

其实在之前ToTensor这个函数已经将图片进行归一化了,但是又用transforms.Normalize进行了归中化来加速函数收敛。
为了正常地展示图片,可以通过将图片反归中化来显示。
for img in train_data: print('图片的label:', img[1]) plt_img = img[0] plt_img = plt_img.swapaxes(0, 1) plt_img = plt_img.swapaxes(1, 2) plt.imshow((plt_img + 3) / 6) plt.show()
注意这样只是方便了plt的显示,并不会影响Tensor真正的矩阵,也就不会影响机器学习的效率。

本章总结

Comments NOTHING