BELKA_2024

[TOC]

Leash Bio - Predict New Medicines with BELKA

BELKA 预测新药

Predict small molecule-protein interactions using the Big Encoded Library for Chemical Assessment (BELKA)

使用化学评估大编码库(BELKA)预测小分子蛋白质相互作用

Overview

In this competition, you’ll develop machine learning (ML) models to predict the binding affinity of small molecules to specific protein targets – a critical step in drug development for the pharmaceutical industry that would pave the way for more accurate drug discovery. You’ll help predict which drug-like small molecules (chemicals) will bind to three possible protein targets.

在这场比赛中,你将开发机器学习(ML)模型来预测小分子与特定蛋白质靶标(目标蛋白)的结合亲和力——这是制药行业药物开发的关键一步,将为更准确的药物发现铺平道路。你将帮助预测哪种药物样的小分子(化学物质)将与三种可能的蛋白质靶点结合。

Description

Small molecule drugs are chemicals that interact with cellular protein machinery and affect the functions of this machinery in some way. Often, drugs are meant to inhibit the activity of single protein targets, and those targets are thought to be involved in a disease process. A classic approach to identify such candidate molecules is to physically make them, one by one, and then expose them to the protein target of interest and test if the two interact. This can be a fairly laborious and time-intensive process.

小分子药物是与细胞蛋白质机制相互作用并以某种方式影响该机制功能的化学物质。通常,药物旨在抑制单个蛋白质靶标的活性,而这些靶标被认为与疾病过程有关。识别这类候选分子的一种经典方法是一个接一个地进行物理制造,然后将其暴露于感兴趣的蛋白质靶点,并测试两者是否相互作用。这可能是一个相当费力和耗时的过程。

The US Food and Drug Administration (FDA) has approved roughly 2,000 novel molecular entities in its entire history. However, the number of chemicals in druglike space has been estimated to be 10^60, a space far too big to physically search. There are likely effective treatments for human ailments hiding in that chemical space, and better methods to find such treatments are desirable to us all.

美国食品药品监督管理局(FDA)已经批准了大约2000种新型分子实体在其整个历史. 然而,类药物领域的化学物质数量估计为$10^60$,这个空间太大了,无法进行物理搜索。在这个化学空间里,可能有有效的治疗人类疾病的方法,而找到更好的治疗方法对我们所有人来说都是可取的。

To evaluate potential search methods in small molecule chemistry, competition host Leash Biosciences physically tested some 133M small molecules for their ability to interact with one of three protein targets using DNA-encoded chemical library (DEL) technology. This dataset, the Big Encoded Library for Chemical Assessment (BELKA), provides an excellent opportunity to develop predictive models that may advance drug discovery.

为了评估小分子化学中潜在的搜索方法,比赛主办方Leash Biosciences使用DNA编码化学文库(DEL)技术对约133M个小分子进行了物理测试,以确定它们与三个蛋白质靶标之一相互作用的能力。该数据集,即化学评估大编码库(BELKA),为开发可能促进药物发现的预测模型提供了极好的机会。

Datasets of this size are rare and restricted to large pharmaceutical companies. The current best-curated public dataset of this kind is perhaps bindingdb, which, at 2.8M binding measurements, is much smaller than BELKA.

这种规模的数据集非常罕见,仅限于大型制药公司。目前这类最好的公共数据集可能是bindingdb,在2.8M的结合测量值下,比BELKA小得多。

This competition aims to revolutionize small molecule binding prediction by harnessing ML techniques. Recent advances in ML approaches suggest it might be possible to search chemical space by inference using well-trained computational models rather than running laboratory experiments. Similar progress in other fields suggest using ML to search across vast spaces could be a generalizable approach applicable to many domains. We hope that by providing BELKA we will democratize aspects of computational drug discovery and assist the community in finding new lifesaving medicines.

这项竞赛旨在通过利用ML技术彻底改变小分子结合预测。ML方法的最新进展表明,使用训练有素的计算模型而不是进行实验室 实验,通过推理搜索化学空间是可能的。其他 领域的类似进展表明,使用ML在广阔的空间中搜索可能是一种适用于许多领域的通用方法。我们希望通过提供BELKA,我们将使计算药物发现的各个方面民主化,并帮助社区寻找新的救命药物。

Here, you’ll build predictive models to estimate the binding affinity of unknown chemical compounds to specified protein targets. You may use the training data provided; alternatively, there are a number of methods to make small molecule binding predictions without relying on empirical binding data (e.g. DiffDock, and this contest was designed to allow for such submissions).

在这里,你将建立预测模型来估计未知化合物与特定蛋白质靶标的结合亲和力。您可以使用提供的培训数据;或者,有许多方法可以在不依赖经验结合数据的情况下进行小分子结合预测(例如DiffDock,而本次竞赛旨在允许此类提交)。

Your work will contribute to advances in small molecule chemistry used to accelerate drug discovery.

你的工作将有助于促进用于加速药物发现的小分子化学的进步。

Evaluation

This metric for this competition is the average precision calculated for each (protein, split group) and then averaged for the final score. Please see this forum post for important details.

这项比赛的指标是为每个(蛋白质、分组)计算的平均精度,然后为最终得分取平均值。请参阅此论坛帖子了解重要细节。

Here’s the code for the implementation.

这是代码以供实施。

Submission File

For each id in the test set, you must predict a probability for the binary target binds target. The file should contain a header and have the following format:

对于测试集中的每个id您必须预测二进制目标“绑定”目标的概率。该文件应包含一个标头,并具有以下格式:

1
2
3
4
5
id,binds
295246830,0.5
295246831,0.5
295246832,0.5
etc.

Timeline

  • April 4, 2024 - Start Date.
  • July 1, 2024 - Entry Deadline. You must accept the competition rules before this date in order to compete.
  • July 1, 2024 - Team Merger Deadline. This is the last day participants may join or merge teams.
  • July 8, 2024 - Final Submission Deadline.

All deadlines are at 11:59 PM UTC on the corresponding day unless otherwise noted. The competition organizers reserve the right to update the contest timeline if they deem it necessary.

Prizes

  • First Prize: $12,000
  • Second Prize: $10,000
  • Third Prize: $10,000
  • Fourth Prize: $8,000
  • Fifth Prize: $5,000
  • Top Student Group: $5,000 to the highest performing student team. A team would be considered a student team if majority members (e.g. at least 3 out of a 5 member team) are students enrolled in a high school or university degree. In the case of an even number of members, half of them must be students.

Competition Host

Leash Biosciences is a discovery-stage biotechnology company that seeks to improve medicinal chemistry with machine learning approaches and massive data collection. Leash is comprised of wet lab scientists and dry lab scientists in equal numbers, and is proudly headquartered in Salt Lake City, Utah, USA.

Additional Details

Chemical Representations

One of the goals of this competition is to explore and compare many different ways of representing molecules. Small molecules have been [represented](https://pubs.acs.org/doi/10.1021/acsinfocus.7e7006?ref=infocus%2FAI_& Machine Learning) with SMILES, graphs, 3D structures, and more, including more esoteric methods such as spherical convolutional neural nets. We encourage competitors to explore not only different methods of making predictions but also to try different ways of representing the molecules.

We provide the molecules in SMILES format.

这场比赛的目标之一是探索和比较许多不同的分子表现方式。小分子已经用SMILES、图形、3D结构等表示,包括更深奥的方法,如球形卷积神经网络。我们鼓励竞争对手不仅探索不同的预测方法,还尝试不同的分子表示方法。

我们提供SMILES格式的分子。

SMILES

SMILES is a concise string notation used to represent the structure of chemical molecules. It encodes the molecular graph, including atoms, bonds, connectivity, and stereochemistry as a linear sequence of characters, by traversing the molecule graph. SMILES is widely used in machine learning applications for chemistry, such as molecular property prediction, drug discovery, and materials design, as it provides a standardized and machine-readable format for representing and manipulating chemical structures.

The SMILES in this dataset should be sufficient to be translated into any other chemical representation format that you want to try. A simple way to perform some of these translations is with RDKit.

SMILES是一种简明的字符串表示法,用于表示化学分子的结构。它通过遍历分子图,将分子图(包括原子、键、连接性和立体化学)编码为线性字符序列。SMILES广泛用于化学的机器学习应用,如分子性质预测、药物发现和材料设计,因为它为表示和操纵化学结构提供了标准化和机器可读的格式。
该数据集中的SMILES应该足以转换为您想要尝试的任何其他化学表示格式。执行其中一些翻译的一种简单方法是使用RDKit.

Details about the experiments

DELs are libraries of small molecules with unique DNA barcodes covalently attached

Traditional high-throughput screening requires keeping individual small molecules in separate, identifiable tubes and demands a lot of liquid handling to test each one of those against the protein target of interest in a separate reaction. The logistical overhead of these efforts tends to restrict screening collections, called libraries, to 50K-5M small molecules. A scalable solution to this problem, DNA-encoded chemical libraries, was described in 2009. As DNA sequencing got cheaper and cheaper, it became clear that DNA itself could be used as a label to identify, and deconvolute, collections of molecules in a complex mixture. DELs leverage this DNA sequencing technology.

These barcoded small molecules are in a pool (many in a single tube, rather than one tube per small molecule) and are exposed to the protein target of interest in solution. The protein target of interest is then rinsed to remove small molecules in the DEL that don’t bind the target, and the remaining binders are collected and their DNA sequenced.

DEL是共价连接有独特DNA条形码的小分子库
传统高通量筛选需要将单个小分子保持在单独的、可识别的管中,并且需要大量的液体处理来在单独的反应中针对感兴趣的蛋白质靶标测试其中的每一个。这些工作的后勤开销往往将筛选收藏(称为文库)限制在5000万至500万个小分子以内。这个问题的一个可扩展的解决方案,DNA编码的化学文库,在2009年描述. 随着DNA测序变得越来越便宜,很明显,DNA本身可以用作标签来识别和消除复杂混合物中分子的聚集。DELs这种DNA测序技术。
这些条形码小分子在一个池中(许多在单管中,而不是每个小分子一管),并暴露于溶液中感兴趣的蛋白质靶标。然后冲洗感兴趣的蛋白质靶标,以去除DEL中不与靶标结合的小分子,收集剩余的结合物并对其DNA进行测序。

DELs are manufactured by combining different building blocks

An intuitive way to think about DELs is to imagine a Mickey Mouse head as an example of a small molecule in the DEL. We attach the DNA barcode to Mickey’s chin. Mickey’s left ear is connected by a zipper; Mickey’s right ear is connected by velcro. These attachment points of zippers and velcro are analogies to different chemical reactions one might use to construct the DEL.

We could purchase ten different Mickey Mouse faces, ten different zipper ears, and ten different velcro ears, and use them to construct our small molecule library. By creating every combination of these three, we’ll have 1,000 small molecules, but we only needed thirty building blocks (faces and ears) to make them. This combinatorial approach is what allows DELs to have so many members: the library in this competition is composed of 133M small molecules. The 133M small molecule library used here, AMA014, was provided by AlphaMa. It has a triazine core and superficially resembles the DELs described here.

DEL是通过组合不同的构建块来制造的
一个思考DEL的直观方法是想象一个米老鼠的头作为DEL中一个小分子的例子。我们把DNA条形码贴在米奇的下巴上。米奇的左耳由拉链连接;米奇的右耳是用尼龙搭扣连接的。拉链和尼龙搭扣的这些连接点类似于可能用于构建DEL的不同化学反应。
我们可以购买十个不同的米老鼠脸、十个不同拉链耳朵和十个不同尼龙搭扣耳朵,并用它们来构建我们的小分子库。通过创建这三者的每一个组合,我们将拥有1000个小分子,但我们只需要30个构建块(脸和耳朵)就可以制造它们。这种组合方法使DEL能够拥有如此多的成员:这场竞争中的文库由133M个小分子组成。这里使用的133M小分子文库AMA014由AlphaMa提供。它有一个三嗪核心,表面上类似于此处描述的DEL。

Acknowledgements

Leash Biosciences is grateful for the generous cosponsorship of Top Harvest Capital and AlphaMa.

Citation

Andrew Blevins, Ian K Quigley, Brayden J Halverson, Nate Wilkinson, Rebecca S Levin, Agastya Pulapaka, Walter Reade, Addison Howard. (2024). Leash Bio - Predict New Medicines with BELKA. Kaggle. https://kaggle.com/competitions/leash-BELKA


Dataset Description

Overview

The examples in the competition dataset are represented by a binary classification of whether a given small molecule is a binder or not to one of three protein targets. The data were collected using DNA-encoded chemical library (DEL) technology.

比赛数据集中的例子由给定小分子是否与三个蛋白质靶标之一结合的二元分类表示。使用DNA编码化学文库(DEL)技术收集数据。

We represent chemistry with SMILES (Simplified Molecular-Input Line-Entry System) and the labels as binary binding classifications, one per protein target of three targets.

我们用SMILES(简化分子输入 行输入系统)和二元绑定分类来表示化学,三个靶标中的每个蛋白质靶标都有一个。

Files

[train/test].[csv/parquet] - The train or test data, available in both the csv and parquet formats.

  • id - A unique example_id that we use to identify the molecule-binding target pair.
  • buildingblock1_smiles - The structure, in SMILES, of the first building block
  • buildingblock2_smiles - The structure, in SMILES, of the second building block
  • buildingblock3_smiles - The structure, in SMILES, of the third building block
  • molecule_smiles - The structure of the fully assembled molecule, in SMILES. This includes the three building blocks and the triazine core. Note we use a [Dy] as the stand-in for the DNA linker.
  • protein_name - The protein target name
  • binds - The target column. A binary class label of whether the molecule binds to the protein. Not available for the test set.

sample_submission.csv - A sample submission file in the correct format

[train/test].[csv/parquet] - 训练或测试数据,csv和parquet格式均可。

  • id - 我们用来识别分子结合靶标对的唯一示例_id。
  • buildingblock1_smiles - 第一个构建块的结构,以SMILES表示
  • buildingblock2_smiles - 第二个构建块的结构,以SMILES表示
  • buildingblock3_smiles - 第三个构建块的结构,以SMILES表示
  • molecule_smiles - 完全组装的分子的结构,以SMILES表示。这包括三个构建块和三嗪核心。请注意,我们使用[Dy]作为DNA连接子的替代。
  • protein_name - 蛋白质靶标名称
  • binds - 目标列。分子是否与蛋白质结合的二进制类标签。不适用于测试集。

Competition data

All data were generated in-house at Leash Biosciences. We are providing roughly 98M training examples per protein, 200K validation examples per protein, and 360K test molecules per protein. To test generalizability, the test set contains building blocks that are not in the training set. These datasets are very imbalanced: roughly 0.5% of examples are classified as binders; we used 3 rounds of selection in triplicate to identify binders experimentally. Following the competition, Leash will make all the data available for future use (3 targets × 3 rounds of selection × 3 replicates × 133M molecules, or 3.6B measurements).

所有数据均由Leash Biosciences公司内部生成。我们为每种蛋白质提供了大约 98M 个训练实例,为每种蛋白提供了 200K 个验证实例,为每个蛋白质提供了 360K 个测试分子。为了测试可推广性,测试集包含不在训练集中的构建块。这些数据集非常不平衡:大约0.5%的示例被归类为绑定;我们使用了三轮一式三份的选择来实验鉴定粘合剂。比赛结束后,Leash将提供所有数据供未来使用(3个靶标×3轮选择×3个重复×3.33M个分子,或3.6B测量值)。

Targets

Proteins are encoded in the genome, and names of the genes encoding those proteins are typically bestowed by their discoverers and regulated by the Hugo Gene Nomenclature Committee. The protein products of these genes can sometimes have different names, often due to the history of their discovery.

We screened three protein targets for this competition.

蛋白质在基因组中编码,编码这些蛋白质的基因的名称通常由其发现者命名,并由雨果基因命名委员会监管。这些基因的蛋白质产物有时可能有不同的名称,通常是由于它们的发现历史。
我们为这次比赛筛选了三个蛋白质靶点。

EPHX2 (sEH)

The first target, epoxide hydrolase 2, is encoded by the EPHX2 genetic locus, and its protein product is commonly named “soluble epoxide hydrolase”, or abbreviated to sEH. Hydrolases are enzymes that catalyze certain chemical reactions, and EPHX2/sEH also hydrolyzes certain phosphate groups. EPHX2/sEH is a potential drug target for high blood pressure and diabetes progression, and small molecules inhibiting EPHX2/sEH from earlier DEL efforts made it to clinical trials.

EPHX2/sEH was also screened with DELs, and hits predicted with ML approaches, in a recent study but the screening data were not published. We included EPHX2/sEH to allow contestants an external gut check for model performance by comparing to these previously-published results.

We screened EPHX2/sEH purchased from Cayman Chemical, a life sciences commercial vendor. For those contestants wishing to incorporate protein structural information in their submissions, the amino sequence is positions 2-555 from UniProt entry P34913, the crystal structure can be found in PDB entry 3i28, and predicted structure can be found in AlphaFold2 entry 34913. Additional EPHX2/sEH crystal structures with ligands bound can be found in PDB.

第一个靶标环氧化物水解酶2由EPHX2基因座编码,其蛋白产物通常被命名为“可溶性环氧化物水解酶”,或缩写为sEH。水解酶是催化某些化学反应的酶,EPHX2/sEH也水解某些磷酸基团。EPHX2/sEH是高血压和糖尿病进展的潜在药物靶点,早期DEL研究中抑制EPHX2/s EH的小分子已进入临床试验.
EPHX2/sEH也用DEL进行了筛选,并用ML方法预测了命中率(https://blog.research.google/2020/06/unlocking-chemome-with-dna-encoded.html学习https://pubs.acs.org/doi/10.1021/acs.jmedchem.0c00452)但筛选数据没有公布。我们纳入了EPHX2/sEH,通过与之前公布的结果进行比较,让参赛者能够对模型性能进行外部检查。
我们筛选了EPHX2/sEH购自开曼化学. 在PDB中可以发现具有结合配体的额外的EPHX2/sEH晶体结构。

BRD4

The second target, bromodomain 4, is encoded by the BRD4 locus and its protein product is also named BRD4. Bromodomains bind to protein spools in the nucleus that DNA wraps around (called histones) and affect the likelihood that the DNA nearby is going to be transcribed, producing new gene products. Bromodomains play roles in cancer progression and a number of drugs have been discovered to inhibit their activities.

BRD4 has been screened with DEL approaches previously but the screening data were not published. We included BRD4 to allow contestants to evaluate candidate molecules for oncology indications.

We screened BRD4 purchased from Active Motif, a life sciences commercial vendor. For those contestants wishing to incorporate protein structural information in their submissions, the amino acid sequence is positions 44-460 from UniProt entry O60885-1, the crystal structure (for a single domain) can be found in PDB entry 7USK and predicted structure can be found in AlphaFold2 entry O60885. Additional BRD4 crystal structures with ligands bound can be found in PDB.

ALB (HSA)

The third target, serum albumin, is encoded by the ALB locus and its protein product is also named ALB. The protein product is sometimes abbreviated as HSA, for “human serum albumin”. ALB, the most common protein in the blood, is used to drive osmotic pressure (to bring fluid back from tissues into blood vessels) and to transport many ligands, hormones, fatty acids, and more.

Albumin, being the most abundant protein in the blood, often plays a role in absorbing candidate drugs in the body and sequestering them from their target tissues. Adjusting candidate drugs to bind less to albumin and other blood proteins is a strategy to help these candidate drugs be more effective.

ALB has been screened with DEL approaches previously but the screening data were not published. We included ALB to allow contestants to build models that might have a larger impact on drug discovery across many disease types. The ability to predict ALB binding well would allow drug developers to improve their candidate small molecule therapies much more quickly than physically manufacturing many variants and testing them against ALB empirically in an iterative process.

We screened ALB purchased from Active Motif. For those contestants wishing to incorporate protein structural information in their submissions, the amino acid sequence is positions 25 to 609 from UniProt entry P02768, the crystal structure can be found in PDB entry 1AO6, and predicted structure can be found in AlphaFold2 entry P02768. Additional ALB crystal structures with ligands bound can be found in PDB.

Good luck!

PyTorch快速上手指南

PyTorch 深度学习框架快速上手指南

PyTorch 可以说是目前最常用的深度学习框架 , 常应用于搭建深度学习网络 , 完成一些深度学习任务 (CV、NLP领域)

要想快速上手 PyTorch , 你需要知道什么 :

  1. 一个项目的完整流程 , 即到什么点该干什么事
  2. 几个常用 (或者说必备的) 组件

剩下的时间你就需要了解 , 完成什么任务 , 需要什么网络 , 而且需要用大量的时间去做这件事情


$^{(e.g.)}$例如 : 你现在有一个图像分类任务 , 完成该任务需要什么网络, 你需要通过查找资料来了解需要查找什么网络。

需要注意的是 , 有一些常识性的问题你必须知道 , 例如: 图像层面无法或很难使用机器学习方法 , 卷积神经网络最多的是应用于图像领域等


下面我将通过一个具体的分类项目流程来讲述到什么点该干什么事

一个完整的 PyTorch 分类项目需要以下几个方面:

  1. 准备数据集
  2. 加载数据集
  3. 使用变换(Transforms模块)
  4. 构建模型
  5. 训练模型 + 验证模型
  6. 推理模型

  1. 准备数据集 一般来说 , 比赛会给出你数据集, 不同数据集的组织方式不同 , 我们要想办法把他构造成我们期待的样子
    • 分类数据集一般比较简单, 一般是将某个分类的文件全都放在一个文件夹中, 例如:
    • 二分类问题 : Fake(文件夹) / Real(文件夹)
    • 多分类问题 : 分类 1(文件夹) / 分类 2(文件夹) / … / 分类 N(文件夹)
    • 当然有些时候他们会给出其他方式 , 如 UBC-OCEAN , 他们将所有的图片放在一个文件夹中 , 并用 csv 文件存储这些文件的路径(或者是文件名) , 然后在 csv 文件中进行标注(如下):

    • 以后你可能还会遇到更复杂的目标检测的数据集, 这种数据集会有一些固定格式 , 如 VOC格式 , COCO格式
    • 在数据集方面 , 需要明确三个概念——训练集、验证集和测试集 , 请务必明确这三个概念 , 这是基本中的基本
      • 训练集(Train) : 字如其名 , 简单来说就是知道数据 , 也知道标签 的数据 , 我们用其进行训练
      • 验证集(Valid) : 验证集测试集 是非常容易混淆的概念 , 简单来说 , 验证集就是我们也知道数据和标签 , 但是我们的一般不将这些数据用于训练 , 而是将他们当作我们的测试集 , 即我们已经站在了出题人的角度 , 给出参赛者输入数据 , 而我们知道这个数据对应的输出 , 但是我们不让模型知道
      • 测试集(Test) : 测试集就是 , 我们不知道输入数据的输出标签 , 只有真正的出题人知道 , 一般来说 , 我们无法拿到测试集 , 测试集是由出题人掌控的
      • 需要注意的是 , 如果你通过某种途径知道了所有的测试集的标签时 , 不可使用测试集进行训练 , 这是非常严重的学术不端行为 , 会被学术界和工业界唾弃
1
2
3
4
5
6
7
8
9
10
11
# 现在我们已经有了一个数据集 , 我将以 FAKE_OR_REAL 数据集为例 , 展示我们数据集的结构
# D:\REAL_OR_FAKE\DATASET
# ├─test --------- 测试集路径, 这里可以放你自己的数据, 你甚至可以将他们分类, 但是请注意, 实际情况下你只能通过这种方式来“得到”测试集
# │ ├─fake ------ 你自己分的类, 开心就好
# │ └─real ------ 同上
# ├─train -------- 训练集路径, 这里面放的是题目给出的数据, 下面有 fake 和 real 两个文件夹, 这两个文件夹中就是两个类别, 我们要用这里面的图片进行分类
# │ ├─fake
# │ └─real
# └─valid -------- 验证集路径, 这里面放的是题目给出的数据, 下面有 fake 和 real 两个文件夹, 这两个文件夹中就是两个类别, 这里面的图片不需要进行训练
# ├─fake
# └─real
  1. 加载数据集
    • 请务必记住 , 不管是什么数据集 , 数据集是如何构成的 , 在使用 PyTorch 框架时 , 我们都要像尽办法将他们加载入 Dataset 类中
    • 简单来说 , Dataset 类就是描述了我们数据的组成的类
    • 需要注意 , PyTorch 实现了许多自己的 Dataset 类 , 这些类可以轻松的加载特定格式的数据集 , 但是我强烈建议所有的数据集都要自己继承Dataset类 , 自行加载 , 这样我们可以跟清晰的指导数据集的组成方式 , 也可以使得我们加载任意格式的数据集
    • 实现 DataSet 类需要我们先继承 Dataset 类 , 在继承 Dataset 类后, 我们只需要实现其中的__init____len____getitem__三个方法 , 即可完成对数据集的加载 , 这三个方法就和他的名字一样 :
      • __init__ 方法是构造函数 , 用于初始化
      • __len__ 方法用于获取数据集的大小
      • __getitem__ 方法用于获取数据集的元素 , 我将从下面的代码中进行更详细的解释
    • 有些数据集并不分别提供 Train训练集 和 Valid验证集, 我们可以使用 random_split() 方法对数据集进行划分
      • 需要注意的是, 每次重新划分数据集时, 必须重新训练模型, 因为 random_split() 方法随机性, 划分后的数据不可能和之前的数据完全重合, 因此会导致数据交叉的情况, 下面一段使用 random_split() 进行划分的 Python 代码示例 :

        1
        2
        3
        4
        5
        6
        7
        8
        # 下面演示使用 random_split 来划分数据集的操作
        # 我们假设已经定义了 CustomImageDataSet
        split_ratio = 0.8 # 表示划分比例为 8 : 2
        dataset = CustomImageDataSet(fake_dir, real_dir) # 定义 CustomImageDataSet 类, 假设此时没有划分训练集和验证集
        train_dataset_num = int(dataset.lens * split_ratio) # 定义训练集的大小
        valid_dataset_num = dataset.lens - train_dataset_num # 定义验证集的大小
        # random_split(dataset, [train_dataset_num, valid_dataset_num]) 表示将 dataset 按照 [train_dataset_num: valid_dataset_num] 的比例进行划分
        train_dataset, valid_dataset = random_split(dataset, [train_dataset_num, valid_dataset_num])
      • 当数据集不是很大的时, 推荐人为的将数据集进行划分, 可以写一个 Python 脚本(.py) 或者 批处理脚本(.bat) 来完成这个操作

完整的数据集加载代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torch.utils.data import Dataset
import os
from PIL import Image

# 这里我们定义了一个 CustomImageDataset(...) 类, 括号中的内容表示我们继承了 ... 类
# 因此我们这里 CustomImageDataset(Dataset): 表示我们定义了一个“自定义图片”类, 这个“自定义图片”类继承自 Dataset 类
class CustomImageDataset(Dataset):
# 这里我们实现 __init__ 方法, __init__ 方法其实就是一个类的构造函数, 他也分有参构造和无参构造, 只是在这里我们说无参构造基本没啥意义
# 因此我们常常实现这个类, 使得可以指定这个类的输入输出
# 比如下面我们写的 def __init__(self, fake_dir, real_dir, transform=None):
# self : 自己, 我一般直接理解为 this 指针, 如果有兴趣了解更深层的东西可以查阅一些资料, 这个是必填的
# fake_dir : 用于指定 fake 类型图片的位置的
# real_dir : 用于指定 real 类型图片的位置的
# transform : 用于指定变换, 简单来说就是对输入进行某些操作, 我会在下面的板块中进行详细叙述
def __init__(self, fake_dir, real_dir, transform=None):
self.fake_dir = fake_dir # 这里表示这个类内定义了一个 fake_dir, 其值为传入的 fake_dir
self.real_dir = real_dir # 这里表示这个类内定义了一个 real_dir, 其值为传入的 real_dir
self.transform = transform # 这里表示这个类内定义了一个 transform, 其值为传入的 transform, 当没有传入时, 这个变量为 None

self.fake_images = os.listdir(fake_dir) # 传入的 fake_dir 是一个路径, 我们使用 os.listdir(fake_dir) 可以加载 fake_dir 文件夹下的内容, 也就是所有 fake 图片
self.real_images = os.listdir(real_dir) # 传入的 real_dir 是一个路径, 我们使用 os.listdir(real_dir) 可以加载 real_dir 文件夹下的内容, 也就是所有 real 图片

self.total_images = self.fake_images + self.real_images # 总图片列表, 就是将 fake 图片列表和 real 图片列表进行组合
self.labels = [0]*len(self.fake_images) + [1]*len(self.real_images) # 对图片打标签, fake 为 0, real 为 1
# [0] * 10 得到的结果为 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# [1] * 10 得到的结果为 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

# 这里我们实现 __len__ 方法, 这个方法用于获取数据集的大小
def __len__(self):
return len(self.total_images) # 这里我们直接返回总图片列表的长度即可, 这里的实现方式不唯一, 只要能做到表示数据集大小即可

# 这里我们实现 __getitem__ 方法, 这个方法用于获取数据集中的某个元素
# 其中 idx 表示索引, 这个参数是必须的, 当然可以起其他名字, 不过最好还是使用 idx
# __getitem__(self, idx) 表示获取 idx 位置的元素
def __getitem__(self, idx):
# 这里表示获取一个元素的逻辑
# 当 idx 位置的标签为 0 时, 图片的路径为 fake_dir + self.total_images[idx], idx 即为图片的索引位置
# 当 idx 位置的标签为 1 时, 图片的路径为 real_dir + self.total_images[idx]
image_path = os.path.join(self.fake_dir if self.labels[idx] == 0 else self.real_dir, self.total_images[idx])

# 使用 PIL 库加载图片, 通过 image_path 打开图片, 并且将图片转化为 RGB 格式
image = Image.open(image_path).convert('RGB')

# 这里是 transform, 表示变换, 当其值为 None 时不进行操作, 当传入自己的 transform 时即为非空, 即对输入数据进行变换
if self.transform:
# 我们将变换后的图片直接保存在原位置
image = self.transform(image)

# 最后函数的返回值为 image 和 self.labels[idx], 即表示索引位置 idx 处的图片和标签
return image, self.labels[idx]
  1. 使用 Transforms
    • 不要简单的使用原始图片进行训练 , 当然如果一定要使用原始图片进行训练, 也可以使用 transforms 模块
    • 一般来说, 训练集和验证集的 transforms 是不同的, 因为我们希望验证集和测试集的图片贴合真实的情况
    • 下面的代码演示了如何定义 transforms
    • 在定义完 transforms 我们就可以完全定义我们的 DatasetDataloader
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torchvision import transforms

# 定义transform
# transforms.Compose(transforms) 实际上就是将多个 transform 方法变为逐步执行, 一般我们直接使用这种方式来对图片进行连续的变换
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomVerticalFlip(), # 随机垂直翻转
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1), # 改变图像的属性, 将图像的brightness亮度/contrast对比度/saturation饱和度/hue色相 随机变化为原图亮度的 10%
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), # 对图片先进行随机采集, 然后对裁剪得到的图像缩放为同一大小, 意义是即使只是该物体的一部分, 我们也认为这是该类物体
transforms.RandomRotation(40), # 在[-40, 40]范围内随机旋转
transforms.RandomAffine(degrees=0, shear=10, scale=(0.8,1.2)), # 随机仿射变换
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # 色彩抖动
transforms.ToTensor(), # [重点] 将图片转化为 Tensor 张量, 在 PyTorch 中, 一切的运算都基于张量, 请一定将你的输入数据转化为张量
# 请理解什么是张量 : 我们在线性代数中有向量的概念, 简单来说就是张量就是向量, 只不过张量往往具有更高的维度
# 而大家一般习惯将高于三维的向量称为张量, 某些人(比如我)也习惯所有的向量统称为张量
# 可以简单的将数组的维数来界定张量的维度
# 例如 [ ] 为一维张量, [[ ]] 为二维张量, [[[ ]]]为三维张量, [[[[ ]]]]为四维张量
# 对于图像来说, jpg 图像实际为三维矩阵, png 图像实际为四维矩阵, 这个维数是根据图像的通道数进行划分的
# 例如 jpg 有 R、G、B三个通道, png 具有
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 归一化, 可以对矩阵进行归一化
# 详细查看这个Blog : https://blog.csdn.net/qq_38765642/article/details/109779370
transforms.RandomErasing() # 随机擦除
])

valid_transform = transforms.Compose([
transforms.Resize((256, 256)), # Resize 操作, 将图片转换到指定的大小
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torch.utils.data import DataLoader

# 定义 Dataset 实例
train_dataset = CustomImageDataset(fake_dir="./dataset/train/fake", real_dir="./dataset/train/real", transform=train_transform)
valid_dataset = CustomImageDataset(fake_dir="./dataset/valid/fake", real_dir="./dataset/valid/real", transform=valid_transform)

# 创建 DataLoader 实例
# 这里将要涉及到超参数的概念, 什么是超参数: 简单来将, 超参数就是我们自己能指定的一些数据, 超参数的选择将很大程度上影响模型的性能
# 因此 深度学习领域的工程师 常称自己为 炼丹师、调参师等
batch_size = 32 # batch_size 就是一个超参数, batch 即为 “批次”, 表示一次使用 DataLoader 加载多少张图片进行运算
# 这个数值并不是越大越好, 也不是越小越好, 但是往往大一些比较好, 这个数字最大能选择多大和你的图片大小和显卡显存有很大的关系
# 当出现 [Out Of Memery] 错误时往往表明你选取了过大的 batch_size, 导致显卡出现了爆显存的问题
# batch_size : 每次训练时,模型所看到的数据数量。它是决定训练速度和内存使用的重要参数。
# shuffle : 是否在每个训练周期之前打乱数据集的顺序。这对于许多模型(如卷积神经网络)是很有帮助的,因为它可以帮助模型避免模式识别。
# sampler : 定义如何从数据集中抽样。默认情况下,它使用随机采样。但你可以使用其他更复杂的采样策略,如学习率调度采样。
# batch_sampler : 与sampler类似,但它在批处理级别上进行采样,而不是在整个数据集上。这对于内存使用效率更高的场景很有用。
# num_workers : 定义了多少个工作进程用于数据的加载。这可以加快数据加载的速度,但需要注意内存的使用情况。
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
1
2
3
4
5
6
7
8
9
10
11
# 查看Dataloader数据
# 为了了解Dataloader中的数据, 我们可以使用以下方法来查看:
# 使用 Python 的 len() 函数 : 我们可以直接通过 len() 函数获取 Dataloader 的长度, 即数据集中数据块的数量
# 使用 torch.utils.data.DataLoader.len() 方法 : 这个方法也会返回Dataloader的长度。
# 使用 iter() 函数:Dataloader是一个可迭代对象,我们可以直接通过iter()函数对其进行迭代,以获取每个批次的数据。
# 使用torchvision.utils.save_image()函数 : 如果我们正在处理的是图像数据集,那么可以使用这个函数来保存Dataloader中的图像数据。
len(train_loader) # 401
len(valid_loader) # 100
images, labels = next(iter(train_loader))
print(images)
print(labels)
  1. 构建模型
    • 构建模型是比较重要的一部分, 一般来说做好数据集之后, 最重要的事情就是修改模型, 通过训练结果改进模型, 判断自己的模型的正确性, 这里就是整个你要用到的神经网络的部分 , 需要注意的是 , 这里指定什么输入 , 推理的时候就要指定什么输入
    • 简单用几个符号说明一下就是: $^{Train} model (inputX, inputY, …)$ → $^{Valid} model (inputX, inputY, …)$
      • 如何确定输入是什么: 看 forward() 的输入是啥模型的输入就是啥
    • 我下面展现了我复现的 ResNet50 , 用这种方式可以顺便教你如何复现网络结构
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch.nn as nn
from torch.nn import functional as F

# 这里是对 ResNet50 的实现, 请对照论文来进行对照阅读
# 定义 ResNet50Basic类, 这里并不是完整的模型, 而是模型的一个部分
class ResNet50BasicBlock(nn.Module):
def __init__(self, in_channel, outs, kernerl_size, stride, padding):
# super(ResNet50BasicBlock, self).__init__() 这里是干什么的?
# 1. 首先找到 ResNet50BasicBlock 的父类, 这里是 nn.Module
# 2. 把类 ResNet50BasicBlock 的对象self转换为 nn.Module 的对象
# 3. "被转换"的 nn.Module 对象调用自己的 init 函数
# 简单理解一下就是 : 子类把父类的 __init__ 放到自己的 __init__ 当中, 这样子类就有了父类的 __init__ 的那些东西
super(ResNet50BasicBlock, self).__init__()
# 这里只是定义部分, 在这里的定义并不一定会在推理过程中使用
self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernerl_size[0], stride=stride[0], padding=padding[0])
self.bn1 = nn.BatchNorm2d(outs[0])
self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernerl_size[1], stride=stride[0], padding=padding[1])
self.bn2 = nn.BatchNorm2d(outs[1])
self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernerl_size[2], stride=stride[0], padding=padding[2])
self.bn3 = nn.BatchNorm2d(outs[2])

# 输入是啥看 forward(), 例如这里是 forward(self, x), 则表示输入是 x, 也就是一个
def forward(self, x):
# nn.Conv2d 是卷积层, 请了解[1]什么是卷积层, 以及[2]卷积层是干啥用的, [3]卷积后会变成什么
# 卷积运算的目的是提取输入的不同特征, 第一层卷积层可能只能提取一些低级的特征如边缘、线条和角等层级, 更多层的网路能从低级特征中迭代提取更复杂的特征
out = self.conv1(x)
# [*] 什么是 ReLU, ReLU是激活函数, 请了解 [1]什么是激活函数, [2]为什么要使用激活函数
# [*] 什么是 Batch Normalization层, BN 层是批次归一化层
out = F.relu(self.bn1(out))
out = self.conv2(out)
out = F.relu(self.bn2(out))
out = self.conv3(out)
out = self.bn3(out)
return F.relu(out + x)


# 定义 ResNet50DownBlock类, 这里并不是完整的模型, 而是模型的一个部分
class ResNet50DownBlock(nn.Module):
def __init__(self, in_channel, outs, kernel_size, stride, padding):
super(ResNet50DownBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channel, outs[0], kernel_size=kernel_size[0], stride=stride[0], padding=padding[0])
self.bn1 = nn.BatchNorm2d(outs[0])
self.conv2 = nn.Conv2d(outs[0], outs[1], kernel_size=kernel_size[1], stride=stride[1], padding=padding[1])
self.bn2 = nn.BatchNorm2d(outs[1])
self.conv3 = nn.Conv2d(outs[1], outs[2], kernel_size=kernel_size[2], stride=stride[2], padding=padding[2])
self.bn3 = nn.BatchNorm2d(outs[2])

self.extra = nn.Sequential(
nn.Conv2d(in_channel, outs[2], kernel_size=1, stride=stride[3], padding=0),
nn.BatchNorm2d(outs[2])
)

def forward(self, x):
x_shortcut = self.extra(x)
out = self.conv1(x)
out = self.bn1(out)
out = F.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = F.relu(out)

out = self.conv3(out)
out = self.bn3(out)
return F.relu(x_shortcut + out)


class ResNet50(nn.Module):
def __init__(self):
super(ResNet50, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

# Sequential 类是 torch.nn 模块中的一个容器, 可以将多个层封装在一个对象中, 方便顺序连接
self.layer1 = nn.Sequential(
ResNet50DownBlock(64, outs=[64, 64, 256], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50BasicBlock(256, outs=[64, 64, 256], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
)

self.layer2 = nn.Sequential(
ResNet50DownBlock(256, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50BasicBlock(512, outs=[128, 128, 512], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50DownBlock(512, outs=[128, 128, 512], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
)

self.layer3 = nn.Sequential(
ResNet50DownBlock(512, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50BasicBlock(1024, outs=[256, 256, 1024], kernerl_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50DownBlock(1024, outs=[256, 256, 1024], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
)

self.layer4 = nn.Sequential(
ResNet50DownBlock(1024, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 2, 1, 2], padding=[0, 1, 0]),
ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0]),
ResNet50DownBlock(2048, outs=[512, 512, 2048], kernel_size=[1, 3, 1], stride=[1, 1, 1, 1], padding=[0, 1, 0])
)

self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1, ceil_mode=False)
# self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

self.fc = nn.Linear(2048, 10)
# 使用卷积代替全连接
self.conv11 = nn.Conv2d(2048, 10, kernel_size=1, stride=1, padding=0)

def forward(self, x):
out = self.conv1(x)
out = self.maxpool(out)
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
# avgpool 平均池化层, 了解什么是平均池化层
out = self.avgpool(out)
out = self.conv11(out)
out = out.reshape(x.shape[0], -1)
# out = self.fc(out)
return out

# 这里展现了对 ResNet 的一个具体的应用
# x = torch.randn(1, 3, 224, 224) # 这个是我们 ResNet50 期待的输入样子, 可以看到他是 [1] 个 [3] 通道, 宽度为[224], 高度为 [224]的张量
image_path = './dataset/test/fake/test_fake_1.png'
image = Image.open(image_path).convert('RGB') # 图片加载
transform = transforms.ToTensor() # 将图片转化为张量, 此时的 张量的形状为[3, 1024, 1024]
# 当输入数据的维度不足时, 我们可以通过 unsqueeze() 添加维度, 这个东西简单理解一下就是, 在某个维度外面加括号[], 即可拓展出更高的维度
img_tensor = transform(image).unsqueeze(dim=0)

# print(x.shape) 我们可以使用 shape 来查看一个张量的形状
# print(img_tensor.shape)

# 这里加载我们的网络架构
net = ResNet50()

# 这里进行输入, 输入 img_tensor, 进入 forword() 部分, 然后得到最终输出的结果
out = net(img_tensor)
print(out)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torchvision import models

# 这里为了方便, 我们直接加载 PyTorch 预训练好的 ResNet50 的模型
# PyTorch 已经为我们提供了不少已经预训练好的模型, 我们只需要加载他们与训练好的模型即可
# 但是我还是希望你可以掌握上面这种自定义模型的方法, 这样遇到 PyTorch 未提供的模型, 我们也可以尝试自己实现该模型
model = models.resnet50(pretrained=True)

# 冻结参数 : 即不更新模型的参数
# 可以看到下面的代码, 这里表示冻结了所有层
for param in model.parameters():
param.requires_grad = False

# 但是我们可以通过替换层来接触某些层的冻结
num_ftrs = model.fc.in_features # 这里是获取 ResNet50 的 fc 层的输入特征数
model.fc = torch.nn.Linear(num_ftrs, 2) # 这里是对 fc 层进行修改, Linear(input_feather_num, output_feather_num)
# 这里输入特征数是 num_ftrs, 输出特征数为 2

# 这一行很重要, 指定了模型的位置, cuda 可以理解为 GPU 设备, cuda: 0 表示使用编号为 0 的GPU进行训练
# 当有多块 GPU 时, 可以用其他的方式指定 GPU
# model = torch.nn.DataParallel(model, device_ids=[0, 1, 2]), 当然向我们这种小白(穷B), 当然还是单卡为主
# 为了避免出现多卡的情况, 我在下面放入两篇博客, 有兴趣可以参考这两篇文章进行多卡训练
# https://zhuanlan.zhihu.com/p/102697821
# https://blog.csdn.net/qq_34243930/article/details/106695877
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
  1. 训练模型 + 验证模型
    • 这里需要直接对模型进行训练 , 一般来说 , 在训练的过程中我们会加入 tqdm 库使得训练过程可视化 , 有时我们还会在训练过程中保存更好的训练结果 , 并且设置断点训练等操作 , 我只使用最简单的方式进行预测
    • train 部分的代码因人而异, 基本上每个人的写法都可能不同, 没有固定的写法
    • 对于训练完的模型我们需要对其进行评价, 一般来说, 训练和验证都是放在一起的, 不可分开的
    • 记得保存一下训练后的模型, 使用如下代码保存/加载整个模型
      1
      2
      3
      4
      5
      6
      # 保存模型
      model_path = "xxxx.pth" # xxxx 表示一个你喜欢的名字
      torch.save(model, model_path) # 使用 torch.save(model, model_path) 保存模型

      # 加载模型
      model = torch.load(model_path) # 使用 torch.load(model_path) 即可加载模型

完整的”训练模型 + 验证模型”代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score

# 定义损失函数和优化器
# 这里包含了 PyTorch 的 19 种损失函数 https://blog.csdn.net/qq_35988224/article/details/112911110
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

# 计算 F1 值和 准确率
def evaluate(loader, model):
preds = []
targets = []
loop = tqdm(loader, total=len(loader), leave=True)
for images, labels in loop:
images, labels = images.to(device), labels.to(device)
with torch.no_grad():
outputs = model(images)
_, predicted = torch.max(outputs, 1)
preds.extend(predicted.cpu().numpy())
targets.extend(labels.cpu().numpy())

# Update the progress bar
loop.set_description("Evaluating")
return f1_score(targets, preds), accuracy_score(targets, preds)

# 训练循环
best_f1 = 0.0
loss_values = []
num_epochs = 10 # 定义训练的轮次
for epoch in range(num_epochs):
model.train() # 将模型设置为训练模式
loop = tqdm(train_loader, total=len(train_loader), leave=True)
print(loop)
for images, labels in loop:
images, labels = images.to(device), labels.to(device)

# 前向推理
outputs = model(images)
loss = criterion(outputs, labels)

# 反向传播及优化
# 在用 PyTorch训练模型时, 通常会在遍历 Epochs 的过程中依次用到
# optimizer.zero_grad() : 先将梯度归零
# loss.backward() : 反向传播计算得到每个参数的梯度值
# optimizer.step() : 通过梯度下降执行一步参数更新
# 对于这三个函数, 这篇博客写的很好 : https://blog.csdn.net/PanYHHH/article/details/107361827
# 可以简单阅读一遍
optimizer.zero_grad()
loss.backward()
optimizer.step()

# 保存该批次的损失
loss_values.append(loss.item())

# 更新进度条
loop.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
loop.set_postfix(loss=loss.item())

# 在每轮之后验证模型
model.eval() # 将模型设置为推理模式, 此时模型中的参数不会进行更新, 即完全用于推理/验证
f1_value, accuracy = evaluate(valid_loader, model)
print(f'F1 score: {f1_value:.4f}, Accuracy: {accuracy:.4f}')

# 保存 F1 值最高的模型
if f1_value > best_f1:
best_f1 = f1_value
# 这里和上面 Markdown 的保存方式不同, model.state_dict(), 表示模型的参数, 简单来说呢我们仅仅保存了模型的参数, 但是我们并没有保存模型的结构
# 上面 Markdown 的保存方式是即保存了整个模型的结构, 也保存了模型的参数
torch.save(model.state_dict(), 'best_model.pth')
print('训练结束')

当然我们也可以使用绘图函数,来展示过程中的相关数据。

1
2
3
4
5
6
7
8
9
10
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 8))
plt.plot(loss_values, label='Train Loss')
plt.title('Loss values over epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
  1. 推理模型
    • 很高兴, 如果你到这一步, 你的水平肯定已经有了质的飞跃, 这里已经是最后一步了, 结束这个部分, 你就要开始自己的探索之路了
    • 推理模型很简单, 我在上面说过, 构造模型时指定什么输入 , 推理的时候就要指定什么输入, 这里就是对应的部分了
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torchvision.transforms import ToTensor, Resize, Normalize

# predict_by_file 表示推理一个文件, 我们需要传入文件路径以及模型
def predict_by_file(file_path, model):
#
image = Image.open(file_path).convert('RGB')

# 这里的 transform 有与没有都无所谓, 纯看心情
transform = transforms.Compose([
Resize((256, 256)),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = transform(image)

# 这里和上面一样, 表示在最外面加一层括号, 使 [3, 256, 256] 变为 [1, 3, 256, 256]
image = image.unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
outputs = model(image) # 模型推理
# torch.max(...)
# input (Tensor) – 输入张量
# dim (int) – 指定的维度
_, predicted = torch.max(outputs, 1) # 返回指定维度的最大值, 其实这里只有一维
print(outputs) # tensor([[0.7360, 0.2668]], device='cuda:0')
print(outputs.shape) # torch.Size([1, 2])
return "Fake" if predicted.item() == 0 else "Real"

path = './dataset/test/real/test_real_7.jpg'
print(predict_by_file(path, model))