你觉得这篇文章怎么样? 帮助我们为您提供更好的内容。
Thank you! Your feedback has been received.
There was a problem submitting your feedback, please try again later.
你觉得这篇文章怎么样?
作者 |
嘉钧 |
难度 |
理论困难,实作普通 |
材料表 |
|
遇到生成的问题一定是要找现在最流行的GAN,那今天我们除了要让大家了解GAN是什么之外,我们也要来挑战Jetson Nano的极限,上次已经用Nano跑猫狗分类发现对于它来说已经相当的艰辛了,这次要给它更艰难的任务,我们要让计算机学习如何生成手写数字的图片!
GAN 基本观念
GAN 是生成对抗网络 (generative adversarial network, GAN) 是一个相当有名的神经网络模型也是一个人见人怕的模型,原因是因为它非常的难训练、耗费的时间也很久,它常常被用于深度学习中的生成任务,不管是图像还是声音都可以。
Deep Fake |
Image Source : https://aiacademy.tw/what-is-deepfake/ |
CycleGAN |
Image Source : https://github.com/junyanz/CycleGAN |
核心观念可以这样想象,GAN就像「收藏家」与「画假画的人」,G是画假画的,D是收藏家;一开始画假画的人技术还不成熟,所以一眼就被收藏家发现问题,所以G就回去苦练画功慢慢开始能骗过D,这时候D被告知买的都是假画所以他也开始进步越来越能发现假画的瑕疵,就这样反复的交手成长,等到最后互相僵持不下的时候就达到我们的目的了 - 获得一个很会画假画的G,或者称为很会生成图片的G。
在GAN当中,我们可以输入一组 噪声 (Noise) 或称 潜在空间 (Latent Space),然后生成器 (Generator) 会将那组噪声转换成图片,再经由鉴别器 (Discriminator) 分辨是否是真实的图片,鉴别器将会输出0或1,其中 0 是 fake,1 是 real 。
DCGAN 架构介绍
一般的GAN都是 FC ( linear ),需要将图片reshape成一维做训练,以 mnist 来看的话就是 1*28*28 变成 1*784,而今天我们要介绍的是DCGAN,是利用卷积的方式来建构生成器与鉴别器,它跟一般的CNN有些许不同,作者有列出几个要改变的点:
- 取消所有pooling层。G使用转置卷积(transposed convolutional layer)进行上取样而D用加入stride的卷积代替pooling。
- 在 D 和 G 中均使用 batch normalization。
- 去掉 FC 层,使网络变为 全卷积网络 (Conv2D)。
- G使用 ReLU 作为激励函数,最后一层须使用 tanh。
- D使用 LeakyReLU 作为激励函数。
等等实作程序的时候可以再回来观察是不是都符合需求了,那基础观念的部分先带到这我们直接开始写code吧!
准备数据集
今天我们要尝试生成手写数字,最常使用的就是MNIST,由于它太常用了可以直接在torch中就下载到,在torchvision.datasets里面,由于它已经是数据集了所以可以先宣告transform套入,接着就可以打包进DataLoader里面进行批次训练
可以注意到图片已经是灰阶图并且大小为 ( 1, 28, 28)
建构鉴别器
鉴别器的架构就是CNN只差要取消全连阶层的部分,所以要想办法将28*28的图片卷积成 1*1输出
最后因为输出的label是介于 [ 0 , 1 ] 所以最后要透过Sigmoid来收敛。这边直接使用官方提供的架构可以发现跟以往不同的地方是利用nn.Sequential将所有层在initial的时候就连接起来,你会发现更以前相比简洁很多;此外我还使用了torchsummary的函式库可视化网络架构:
使用torchsummary之前需要先安装套件
!pip3 install torchsummary
然后记得要先导入
from torchsummary import summary
建构生成器
生成器一样要全用卷积的方式,但是正如前面所说我们要将一组噪声转换成一张图片,用一般的卷积方式只会越来越小,所以我们必须使用反卷积 ( ConvTranspose2d ),可以想象就是把卷积反过来操作就好,大小一样要自己算记得最后输出应该跟输入图片一样大,公式如下:
建置生成器的架构与CNN颠倒,Kernel的数量要从大到小,然后我们不使用MaxPooling改用ConvTranspose,每层都要有BatchNorm并且除了最后一层的激励函数是Tanh之外其他都是ReLU。
我们一样使用 torchsummary 来可视化网络架构:
开始训练GAN
其实训练GAN的方法不难,只要想象成是训练两个神经网络即可。GAN的训练方式较常见的会先训练鉴别器,让鉴别器具有一定的判断能力后再开始训练生成器;首先一样先建立基本的参数。
由于GAN的训练较复杂,需要的迭代次数也越高,这次我们直接训练个100次看成效如何,此外这次的损失函数都用BCELoss 因为主要是二元分类问题,对于想要了解更深入的可以去看台大李弘毅教授GAN的教学影片。
训练Discriminator
训练鉴别器的目的是让自己能分辨真实图片跟造假图片,首先,先将优化器的梯度清除避免重复计算,接者将真实图片丢进GPU并给予标签 ( 1 ),将其丢进D去预测结果并计算Loss,最后丢进倒传递中;假的图片也是一样的部分,差别只在于假图片需要由生成器生成出来,所以要先定义一组噪声并丢入生成器产生图片,因为要让鉴别器知道这是假的所以要给予标签 ( 0 ),接着一样将假图片丢进鉴别器并计算Loss。经过反复的训练鉴别器就越来越能判断真假照片了,但是同时生成器也会训练,所以当鉴别器越强的时候,生成器产生的图片也会越好!
这个部分要注意的地方是经过鉴别器训练出来的答案维度是 [ batch ,1, 1, 1],所以需要过一个view(-1)来将维度便形成 [batch, ] 一维大小。
训练生成器
如果已经看懂鉴别器的训练方式了,那生成器对你来说就不是个问题!我们一样要先将梯度归零,随机数生成噪声并且透过生成器产生图片,这边要注意的是因为生成器的目标是要骗过鉴别器,所以我们要给予真实的标签 ( 1 ),这样做的目的是,生成器一开始产生的图片很差所以鉴别器得出的结果都接近于0,这时候如果标签是1的话,计算Loss数值将会很大,神经网络就会知道这样不是我们想要的,它会再想办法生成更好的图片让Loss越来越小。
所以完整的训练如下,先训练D再训练G,有的人会让D先训练个几次再训练G,听说效率比较高;而我为了测试jetson Nano的速度所以每一个epoch都有纪录时间,就连我自己的GPU( 1080 ) 都会爆显存更不用说用Jetson Nano来跑了!如果遇到显存爆炸的问题可以尝试先将batch size调小。
成果
你可以注意到它慢慢能转换成数字了,转换的速度其实很快但要更细节的纹路就需要更多时间来训练。
训练时间比较
使用GPU 1080 训练,每一个epoch约耗时310秒左右;而JetsonNano开启cuda来跑大概每一个epoch约耗时1030秒左右。所以其实要在Nano上面运行GAN也是可行的,速度还算可以。
桌面计算机 |
Jetson Nano |
结语
这篇教大家如何建构DGAN,接下来我们将会在Jetson Nano上尝试更多GAN相关的训练,下一篇将让计算机玩填色游戏~