PyTorch 深度学习框架快速上手指南
PyTorch 可以说是目前最常用的深度学习框架 , 常应用于搭建深度学习网络 , 完成一些深度学习任务 (CV、NLP领域)
要想快速上手 PyTorch , 你需要知道什么 :
- 一个项目的完整流程 , 即到什么点该干什么事
- 几个常用 (或者说必备的) 组件
剩下的时间你就需要了解 , 完成什么任务 , 需要什么网络 , 而且需要用大量的时间去做这件事情
$^{(e.g.)}$例如 : 你现在有一个图像分类任务 , 完成该任务需要什么网络, 你需要通过查找资料来了解需要查找什么网络。
需要注意的是 , 有一些常识性的问题你必须知道 , 例如: 图像层面无法或很难使用机器学习方法 , 卷积神经网络最多的是应用于图像领域等
下面我将通过一个具体的分类项目流程来讲述到什么点该干什么事
一个完整的 PyTorch 分类项目需要以下几个方面:
- 准备数据集
- 加载数据集
- 使用变换(Transforms模块)
- 构建模型
- 训练模型 + 验证模型
- 推理模型
- 准备数据集 一般来说 , 比赛会给出你数据集, 不同数据集的组织方式不同 , 我们要想办法把他构造成我们期待的样子
- 分类数据集一般比较简单, 一般是将某个分类的文件全都放在一个文件夹中, 例如:
- 二分类问题 : Fake(文件夹) / Real(文件夹)
- 多分类问题 : 分类 1(文件夹) / 分类 2(文件夹) / … / 分类 N(文件夹)
- 当然有些时候他们会给出其他方式 , 如 UBC-OCEAN , 他们将所有的图片放在一个文件夹中 , 并用 csv 文件存储这些文件的路径(或者是文件名) , 然后在 csv 文件中进行标注(如下):
- 以后你可能还会遇到更复杂的目标检测的数据集, 这种数据集会有一些固定格式 , 如 VOC格式 , COCO格式等
- 在数据集方面 , 需要明确三个概念——训练集、验证集和测试集 , 请务必明确这三个概念 , 这是基本中的基本
- 训练集(Train) : 字如其名 , 简单来说就是知道数据 , 也知道标签 的数据 , 我们用其进行训练
- 验证集(Valid) : 验证集 和 测试集 是非常容易混淆的概念 , 简单来说 , 验证集就是我们也知道数据和标签 , 但是我们的一般不将这些数据用于训练 , 而是将他们当作我们的测试集 , 即我们已经站在了出题人的角度 , 给出参赛者输入数据 , 而我们知道这个数据对应的输出 , 但是我们不让模型知道
- 测试集(Test) : 测试集就是 , 我们不知道输入数据的输出标签 , 只有真正的出题人知道 , 一般来说 , 我们无法拿到测试集 , 测试集是由出题人掌控的
- 需要注意的是 , 如果你通过某种途径知道了所有的测试集的标签时 , 不可使用测试集进行训练 , 这是非常严重的学术不端行为 , 会被学术界和工业界唾弃
1 | # 现在我们已经有了一个数据集 , 我将以 FAKE_OR_REAL 数据集为例 , 展示我们数据集的结构 |
- 加载数据集
- 请务必记住 , 不管是什么数据集 , 数据集是如何构成的 , 在使用 PyTorch 框架时 , 我们都要像尽办法将他们加载入
Dataset
类中 - 简单来说 ,
Dataset
类就是描述了我们数据的组成的类 - 需要注意 , PyTorch 实现了许多自己的
Dataset
类 , 这些类可以轻松的加载特定格式的数据集 , 但是我强烈建议所有的数据集都要自己继承Dataset类 , 自行加载 , 这样我们可以跟清晰的指导数据集的组成方式 , 也可以使得我们加载任意格式的数据集 - 实现
DataSet
类需要我们先继承Dataset
类 , 在继承Dataset
类后, 我们只需要实现其中的__init__
、__len__
和__getitem__
三个方法 , 即可完成对数据集的加载 , 这三个方法就和他的名字一样 :__init__
方法是构造函数 , 用于初始化__len__
方法用于获取数据集的大小__getitem__
方法用于获取数据集的元素 , 我将从下面的代码中进行更详细的解释
- 有些数据集并不分别提供 Train训练集 和 Valid验证集, 我们可以使用
random_split()
方法对数据集进行划分需要注意的是, 每次重新划分数据集时, 必须重新训练模型, 因为
random_split()
方法随机性, 划分后的数据不可能和之前的数据完全重合, 因此会导致数据交叉的情况, 下面一段使用random_split()
进行划分的 Python 代码示例 :1
2
3
4
5
6
7
8# 下面演示使用 random_split 来划分数据集的操作
# 我们假设已经定义了 CustomImageDataSet
split_ratio = 0.8 # 表示划分比例为 8 : 2
dataset = CustomImageDataSet(fake_dir, real_dir) # 定义 CustomImageDataSet 类, 假设此时没有划分训练集和验证集
train_dataset_num = int(dataset.lens * split_ratio) # 定义训练集的大小
valid_dataset_num = dataset.lens - train_dataset_num # 定义验证集的大小
# random_split(dataset, [train_dataset_num, valid_dataset_num]) 表示将 dataset 按照 [train_dataset_num: valid_dataset_num] 的比例进行划分
train_dataset, valid_dataset = random_split(dataset, [train_dataset_num, valid_dataset_num])当数据集不是很大的时, 推荐人为的将数据集进行划分, 可以写一个 Python 脚本(.py) 或者 批处理脚本(.bat) 来完成这个操作
- 请务必记住 , 不管是什么数据集 , 数据集是如何构成的 , 在使用 PyTorch 框架时 , 我们都要像尽办法将他们加载入
完整的数据集加载代码如下:
1 | import torch |
- 使用 Transforms
- 不要简单的使用原始图片进行训练 , 当然如果一定要使用原始图片进行训练, 也可以使用 transforms 模块
- 一般来说, 训练集和验证集的 transforms 是不同的, 因为我们希望验证集和测试集的图片贴合真实的情况
- 下面的代码演示了如何定义 transforms
- 在定义完
transforms
我们就可以完全定义我们的Dataset
和Dataloader
了
1 | import torch |
1 | from torch.utils.data import DataLoader |
1 | # 查看Dataloader数据 |
- 构建模型
- 构建模型是比较重要的一部分, 一般来说做好数据集之后, 最重要的事情就是修改模型, 通过训练结果改进模型, 判断自己的模型的正确性, 这里就是整个你要用到的神经网络的部分 , 需要注意的是 , 这里指定什么输入 , 推理的时候就要指定什么输入
- 简单用几个符号说明一下就是: $^{Train} model (inputX, inputY, …)$ → $^{Valid} model (inputX, inputY, …)$
- 如何确定输入是什么: 看
forward()
的输入是啥模型的输入就是啥
- 如何确定输入是什么: 看
- 我下面展现了我复现的 ResNet50 , 用这种方式可以顺便教你如何复现网络结构
1 | import torch.nn as nn |
1 | import torch |
- 训练模型 + 验证模型
- 这里需要直接对模型进行训练 , 一般来说 , 在训练的过程中我们会加入 tqdm 库使得训练过程可视化 , 有时我们还会在训练过程中保存更好的训练结果 , 并且设置断点训练等操作 , 我只使用最简单的方式进行预测
train
部分的代码因人而异, 基本上每个人的写法都可能不同, 没有固定的写法- 对于训练完的模型我们需要对其进行评价, 一般来说, 训练和验证都是放在一起的, 不可分开的
- 记得保存一下训练后的模型, 使用如下代码保存/加载整个模型
1
2
3
4
5
6# 保存模型
model_path = "xxxx.pth" # xxxx 表示一个你喜欢的名字
torch.save(model, model_path) # 使用 torch.save(model, model_path) 保存模型
# 加载模型
model = torch.load(model_path) # 使用 torch.load(model_path) 即可加载模型
完整的”训练模型 + 验证模型”代码如下:
1 | from tqdm import tqdm |
当然我们也可以使用绘图函数,来展示过程中的相关数据。
1 | import matplotlib.pyplot as plt |
- 推理模型
- 很高兴, 如果你到这一步, 你的水平肯定已经有了质的飞跃, 这里已经是最后一步了, 结束这个部分, 你就要开始自己的探索之路了
- 推理模型很简单, 我在上面说过, 构造模型时指定什么输入 , 推理的时候就要指定什么输入, 这里就是对应的部分了
1 | from torchvision.transforms import ToTensor, Resize, Normalize |