ErikTse Runtime

  • 首页 / Home
  • | 算法学习 / Algorithm
    • 所有 / All
    • 简单 / Easy
    • 中等 / Medium
    • 困难 / Hard
  • | 技术分享 / Technology
    • 所有 / All
    • 网络技术 / NetWork
    • 资源共享 / Resource
    • 项目实践 / Event
  • ETOJ在线评测系统
Keep Going.
温故而知新.
  1. 首页
  2. 技术分享
  3. 学科学习
  4. 正文

[Pytorch]机器学习入门项目笔记:图片分类(1)

2022年8月31日 146点热度 0人点赞 0条评论

记录一下在学习图片分类时遇到的一些问题,还有一些函数、参数做一些笔记。

常用的库引入

from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt

数据集的结构

自己要建立一个图片数据集的话,假如我有一个名字为dog_cat的文件夹,那么这个文件夹就是root,也就是作为root参数传入到ImageFolder函数中。

一定要建立子文件夹,label是根据子文件夹的序号自动标记的(从0开始标记)。

比如下面的0dog中的所有图片的label都是0,而1cat中的所有图片的label都是1。

数据集文件夹的结构

这时候我们可以定义一个路径path。

# 定义数据集根目录
path = "./data/dog_cat"

定义Transform方法

在加载图片的时候为了使得图片可以更好的处理,通常需要将图片标准化,这里需要用到torchversion库中的transforms.Compose方法来建立一个处理列表,也就是每一张图片都会经过这样的处理,转化成一个易于处理的Tensor(张量)。

tfs = transforms.Compose([
    transforms.Resize((256, 256)),  # 规定图形大小
    transforms.ToTensor(),  # 转化为张量数据
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # 数据归中
])

做一些解释:ToTensor会使得图片的每一个像素tuple归一化,也就是原本像素值从[0, 255]变为[0.0, 1.0],但是transforms.Normalize又会使得图片的像素值从[0.0, 1.0]变为一个均值为0的像素值,这样可以加速图片的处理,加速函数的收敛,获得更好的训练效果。

加载图片数据

train_data = ImageFolder(path, transform=tfs)  # 加载图片

在加载完成后的train_data是一个由二元组组成的可迭代对象,每一个元组都是一个二元组(Tensor, label(int))。

打印图片数据

需要用到plt,所以需要引入库文件。

import matplotlib.pyplot as plt

如果直接用下面这几行代码来打印图片的话,会出错

for img in train_data:
    plt_img = img[0]
    plt.imshow(plt_img)
    plt.show()  # 报错
报错内容

注意此时Tensor的格式为(3, 256, 256),其中第一个表示RGB,这样的格式是不可以直接给plt显示的,需要通过tensor的swapaxes函数来进行维度转化后方可显示。

所以打印图片数据的方法为:

for img in train_data:
    print('图片的label:', img[1])
    plt_img = img[0]
    plt_img = plt_img.swapaxes(0, 1)
    plt_img = plt_img.swapaxes(1, 2)
    plt.imshow(plt_img)
    plt.show()

这样就可以正常的打印图片数据了,但是仍然会报错,因为这时候图片的像素值区间为一个正负对称的区间,并非plt支持的[0.0, 1.0],所以会报错。

未进行归一化报错

其实在之前ToTensor这个函数已经将图片进行归一化了,但是又用transforms.Normalize进行了归中化来加速函数收敛。

为了正常地展示图片,可以通过将图片反归中化来显示。

for img in train_data:
    print('图片的label:', img[1])
    plt_img = img[0]
    plt_img = plt_img.swapaxes(0, 1)
    plt_img = plt_img.swapaxes(1, 2)
    plt.imshow((plt_img + 3) / 6)
    plt.show()

注意这样只是方便了plt的显示,并不会影响Tensor真正的矩阵,也就不会影响机器学习的效率。

最终结果

本章总结

本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可
标签: CV python pytorch tensor torchvision 机器学习 深度学习 计算机视觉
最后更新:2022年8月31日

Eriktse

18岁,性别未知,ACM-ICPC现役选手,ICPC亚洲区域赛银牌摆烂人,CCPC某省赛铜牌蒟蒻,武汉某院校计算机科学与技术专业本科在读。

点赞
< 上一篇
下一篇 >

文章评论

取消回复

Eriktse

18岁,性别未知,ACM-ICPC现役选手,ICPC亚洲区域赛银牌摆烂人,CCPC某省赛铜牌蒟蒻,武汉某院校计算机科学与技术专业本科在读。

文章目录
  • 常用的库引入
  • 数据集的结构
  • 定义Transform方法
  • 加载图片数据
  • 打印图片数据
  • 本章总结

友情链接 | 站点地图

COPYRIGHT © 2022 ErikTse Runtime. ALL RIGHTS RESERVED.

Theme Kratos | Hosted In TENCENT CLOUD

赣ICP备2022001555号-1

赣公网安备 36092402000057号