[Pytorch]机器学习入门项目笔记:图片分类(1)

发布于 2022-08-31  394 次阅读


记录一下在学习图片分类时遇到的一些问题,还有一些函数、参数做一些笔记。

常用的库引入

from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt

数据集的结构

自己要建立一个图片数据集的话,假如我有一个名字为dog_cat的文件夹,那么这个文件夹就是root,也就是作为root参数传入到ImageFolder函数中。

一定要建立子文件夹,label是根据子文件夹的序号自动标记的(从0开始标记)。

比如下面的0dog中的所有图片的label都是0,而1cat中的所有图片的label都是1。

数据集文件夹的结构

这时候我们可以定义一个路径path。

# 定义数据集根目录
path = "./data/dog_cat"

定义Transform方法

在加载图片的时候为了使得图片可以更好的处理,通常需要将图片标准化,这里需要用到torchversion库中的transforms.Compose方法来建立一个处理列表,也就是每一张图片都会经过这样的处理,转化成一个易于处理的Tensor(张量)。

tfs = transforms.Compose([
    transforms.Resize((256, 256)),  # 规定图形大小
    transforms.ToTensor(),  # 转化为张量数据
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # 数据归中
])

做一些解释:ToTensor会使得图片的每一个像素tuple归一化,也就是原本像素值从[0, 255]变为[0.0, 1.0],但是transforms.Normalize又会使得图片的像素值从[0.0, 1.0]变为一个均值为0的像素值,这样可以加速图片的处理,加速函数的收敛,获得更好的训练效果。

加载图片数据

train_data = ImageFolder(path, transform=tfs)  # 加载图片

在加载完成后的train_data是一个由二元组组成的可迭代对象,每一个元组都是一个二元组(Tensor, label(int))。

打印图片数据

需要用到plt,所以需要引入库文件。

import matplotlib.pyplot as plt

如果直接用下面这几行代码来打印图片的话,会出错

for img in train_data:
    plt_img = img[0]
    plt.imshow(plt_img)
    plt.show()  # 报错
报错内容

注意此时Tensor的格式为(3, 256, 256),其中第一个表示RGB,这样的格式是不可以直接给plt显示的,需要通过tensor的swapaxes函数来进行维度转化后方可显示。

所以打印图片数据的方法为:

for img in train_data:
    print('图片的label:', img[1])
    plt_img = img[0]
    plt_img = plt_img.swapaxes(0, 1)
    plt_img = plt_img.swapaxes(1, 2)
    plt.imshow(plt_img)
    plt.show()

这样就可以正常的打印图片数据了,但是仍然会报错,因为这时候图片的像素值区间为一个正负对称的区间,并非plt支持的[0.0, 1.0],所以会报错。

未进行归一化报错

其实在之前ToTensor这个函数已经将图片进行归一化了,但是又用transforms.Normalize进行了归中化来加速函数收敛。

为了正常地展示图片,可以通过将图片反归中化来显示。

for img in train_data:
    print('图片的label:', img[1])
    plt_img = img[0]
    plt_img = plt_img.swapaxes(0, 1)
    plt_img = plt_img.swapaxes(1, 2)
    plt.imshow((plt_img + 3) / 6)
    plt.show()

注意这样只是方便了plt的显示,并不会影响Tensor真正的矩阵,也就不会影响机器学习的效率。

最终结果

本章总结