-
生成对抗网络(generative adversarial nets, GAN)[1]是受零和博弈的思想启发而提出的一种新颖的生成模型框。它由一个生成网络与一个判别网络组成,通过让两个神经网络相互博弈、相互对抗的方式进行学习,最终达到纳什均衡。GAN通常用于生成以假乱真的图片[2]、影片、音频[3]、3D模型、文本[4]等等,在诸多领域都取得了显著的成效。
尽管如此,GAN还有一些突出的问题有待研究,比如模式崩溃问题。普遍的看法是因为数据的支持和生成的分布是不相交的或位于低维流形中[5]。根据Monge-Ampere方程的正则性理论,如果目标度量的支持是断开的或只是非凸的,则最优转换映射是不连续的。而通用DNN只能近似连续映射,这种内在冲突导致了模式崩溃[6]。对于这个问题,前期也有很多研究,如通过优化使得网络具有更加优异的学习能力,从而能够学习具有一般性的特征,而不是集中于某种特异性特征;或是通过控制损失函数,逼迫模型学习更多类型的特征。本文从多生成器博弈的角度出发,通过对现有的多生成器模型的一系列改进来促使不同的生成器生成不同的模式数据,达到有效解决模式崩溃问题的目的。现有的多生成器模型的基本思路都是使用多个生成器的联合分布去模拟样本的真实分布,多个生成器网络参数共享或者不进行共享,通过引入分类器来最大化各个生成器生成数据的JS散度,强制不同生成器去捕获不同的模式,取得了较好的效果。Multi-generator GAN (MGAN)是其中效果较好的网络,但也存在一些问题。因为MGAN的损失函数是在原GAN的损失函数基础上添加一个最大化生成器样本差异的正则项,该正则项主要是对多个生成器系统整体进行约束,而从单个生成器的角度出发,模式崩溃问题还是存在的。并且,由于GAN的损失函数的缺陷[7]也会造成生成样本的质量在达到一定水平后,继续训练生成质量反而会下降的不稳定现象。
针对MGAN上述的问题,本文主要优化思路如下:
1) 使用Wasserstein距离作为多个生成器与判别器间的博弈损失函数,改善训练过程中的梯度消失、训练不稳定等问题。
2)引入一个正则惩罚项使得损失函数可以更好地满足Lipschitz连续,从而使得梯度可以向着更快和更好的角度前进,同时也在一定程度上避免了梯度消失和过拟合带来的影响。
3)引入一个超参数来平衡多角度损失函数带来的差异性,避免过度偏向其中某一种梯度方向。
4)提出了一种多生成器参数共享策略,减少训练代价的情况下同时提高了网络的性能,方便各个生成器独立处理图像的高维特征。
-
本文提出了一种采用多生成器架构的生成对抗网络模型 ( improved-MGAN, IMGAN),网络结构如图1所示。图中z表示随机噪声,Gk为k个生成器。
-
本文的参数共享策略是在保持前置卷积神经网络参数共享的基础上,对网络的最后一层卷积和全连接层进行了独立训练,即除了网络的全连接层和最后一层卷积输出参数外,网络的其他层参数都共享,在减少训练代价的情况下同时提高了网络的性能。
-
在这种网络架构下,多个生成器将输入的随机噪声转化为图片;判别器对接收到的图片进行区分,判断是生成器生成的图片还是样本集中的图片;分类器对多个生成器生成的样本进行区分,判断是由哪个生成器生成,评估不同生成器生成样本的相似性。经过多个生成器、判别器、分类器之间的多方博弈,最终达到纳什均衡。
-
为了优化典型的多生成器网络如MGAN易出现的梯度消失、训练不易收敛等问题[15],引入WGAN-GP的Wasserstein距离的损失函数作为IMGAN模型中多个生成器与判别器间的博弈的损失函数。
引入WGAN-GP的损失函数后,此时判别器的输出结果是样本图片分布与生成的图片分布间的Wasserstein距离的近似,较之原模型的判别器的输出结果在度量上发生了变化,因此引入一个参数项
${\lambda _{{C}}}$ 来平衡判别器与分类器对网络公共部分的影响:$${{{L}}_{{\rm{total}}}} = {L_D} + {\lambda _C}{L_C}$$ 式中,
${L_D}$ 为判别器的损失;${L_C}$ 为分类器的损失;${L_{{\rm{total}}}}$ 为判别/分类网络的损失函数。由于本文是对每个样本独立地施加梯度惩罚,为防止引入同一批次样本间的相互依赖关系,本文也按照WGAN-GP的思路对判别网络的结构进行了调整,去掉了判别器/分类器网络的批量归一化以及判别器最后一层的激活函数Sigmoid。
-
为了避免模式崩溃,希望不同的生成器生成的样本之间有明显的差异,所以分类器的损失函数需要引导不同的生成器生成差异较大的样本,采用交叉熵来衡量不同生成器生成样本的差异,分类器损失函数为:
$${\rm{Los}}{{\rm{s}}_C} = \sum\limits_{k = 1}^K {{\pi _k}{E_{x\sim {P_{{G_k}}}}}[\log {C_k}(x)]} $$ (1) 式中,
${\pi _k}$ 为第k个生成器生成的分布在多个生成器形成的联合分布中的权重;${{P}_{{{G}_k}}}$ 为第k个生成器生成的分布;${{C}_k}(x)$ 为样本来自第k个生成器的概率。由式(1)可知,当各生成器生成的样本差异较大,分类器易于区分时,损失较小;当各生成器生成的样本较为接近,分类器难以区分时,损失较大,由此可以促使不同的生成器生成不同的样本。 -
为了使模型训练过程更加稳定,生成器的损失函数采用WGAN生成器的损失函数:
$${\rm{Los}}{{\rm{s}}_G} = - {E_{x\sim {P_{\operatorname{model} }}}}D(x) + \beta {\operatorname{Loss} _c}$$ 式中,
${{\rm{P}}_{{\rm{model}}}}$ 代表多个生成器生成的联合分布;D(x)为判别器的判别结果,在WGAN的损失函数中不需要取对数。生成器损失函数由两部分组成,前一项为GAN的经典损失,用于促使生成器生成的图片与真实样本更接近,后一项是前面提到的分类器损失函数,用于使生成器生成尽可能差异化的样本,两部分通过参数$\beta $ 进行调节,通过该损失函数来提升生成器生成结果的质量和多样性。 -
为了应对训练过程中出现的梯度消失的问题,本文模型的判别网络部分的损失函数采用了WGAN-GP中的判别器损失函数:
$$\begin{split}&{\rm{Los}}{{\rm{s}}_D} = {{E}_{{{x\sim }}{{\rm{P}}_{{\rm{data}}}}}}D(x) - {{\rm E}_{x\sim {P_{{\rm{model}}}}}}D(x) + \\ &\quad\quad\quad{\lambda _{{\rm{gp}}}}{{ E}_{x\sim {P_{{\rm{gp}}}}}}{(\left\| {{\nabla _x}D(x)} \right\| - 1)^2}\end{split}$$ 式中,
${P_{{\rm{model}}}}$ 是多个生成器生成的联合分布;${\lambda _{{\rm{gp}}}}$ 是梯度惩罚项的参数。最后一项梯度惩罚项使判别器满足Lipschitz约束,能够平滑判别器的参数,有效缓解WGAN收敛困难的问题。 -
本文通过引入残差块[16]将构成MGAN网络的基本单元进行替换,解决原网络中存在的随着训练轮数的增加活性神经元的比例会逐渐下降的问题。在同等网络深度下,残差网络不仅具有更小的参数量,还能够进一步提高模型生成图像的质量。
-
生成网络的结构包含输入、反卷积、激励、输出几层。网络结构中都采用了批量归一化操作来代替池化层以避免一些有用的特征丢失和整体与部分关联关系被忽略的问题。卷积神经网络中的反卷积操作由残差块通过上采样完成。
多个生成器采用了参数共享机制,输入层到第一层全连接层以及最后一层反卷积层参数不共享,其余层参数都共享。各个生成器的数据批量归一化分开进行[17],网络结构如图2所示。
-
判别卷积神经网络/分类卷积神经网络同样采用参数共享,最后一层参数不共享,网络的其余层参数都进行共享。两个网络由卷积、池化、激励和输出几层构成。由于采用了WGAN-GP的损失函数,所以不需要对判别器的数据进行批量归一化,去掉了判别器的最后一层Sigmoid激活函数。判别/分类网络中的卷积操作通过下采样残差块完成,其结构如图3所示。
-
本文实验选取了Cifar10和CelebA两个数据集对本文的模型进行验证。Cifar10数据集提供了60000张大小为32*32像素的彩色图片,分为10类,每类包含6000张图片,是开放的物体识别数据集。CelebA包含了10177个名人的共202599张做了特征标记和属性标记的人脸图片。
-
实验需要对本文提出的IMGAN与典型的多生成器模型MGAN进行对比,首先要排除超参数对实验的影响,受限于实验条件,未寻找模型在某一数据集上的最优值,而是采用了相关文献给出的较优值。相关参数的取值如表1所示。
表 1 实验参数设置
环境项 参数 生成器数量 10 ${\lambda _{gp}}$ 10 $({\pi _1},{\pi _2},\cdots,{\pi _K})$ ${1}/{K}$ 学习率 0.02 由于在生成对抗网络中,损失函数输出的损失值并不能直接代表生成图片的质量,即使通过训练,损失值已经很小了,但实际生成的图片仍然和真实图片相去甚远,所以本文引入了GAN生成质量的常用评价标准(Fréchet inception distance, FID)来对IMGAN模型的生成效果进行评价。FID使用均值和协方差矩阵来计算两个分布之间的距离:
$${\rm{FID(}}x{\rm{,}}g{\rm{) = }} {\mu _x} - {\mu _g} _{\rm{2}}^{\rm{2}} + {\rm{Tr}}\left({{\boldsymbol{\varSigma}} _x} + {{\boldsymbol{\varSigma}} _g} - 2{\left({{\boldsymbol{\varSigma}} _x}{{\boldsymbol{\varSigma}} _g}\right)^{1/2}}\right)$$ 式中,x为真实图片分布;g为生成图片分布;
$\mu $ 为均值;Σ为协方差;Tr为矩阵的迹,即矩阵对角线上元素的总和。FID值越低,两个分布越接近,说明生成图片的质量较高、多样性较好。 -
在Cifar10数据集上,分别测试了:1)单独更改参数共享方案,解绑模型的最后一层参数;2)单独更改损失函数,使用Wasserstein距离;3)在引入1)、2)优化的基础上再更改网络结构,引入残差块这3种场景来验证本文优化方法的效果,计算这3种场景的FID值来进行评估。采用Adam优化器,设置初始学习率为0.02,随训练轮数的增加递减,设置Adam优化器的衰减参数
$\,{\beta _1}{\rm{ = }}0.5$ ,$\,{\beta _2}{\rm{ = }}0.90$ ,设置多样性调节参数$\,\beta =0.05$ 。引入残差块后网络中多个生成器的结构如表2所示,判别/分类网络结构如表3所示。
表 2 Cifar10上多个生成器网络配置
模块 卷积核 采样 激活函数 参数共享 Noise − − − − Linear − − ReLU n Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Conv2D 3×3 − tanh n 实验结果FID值如表4所示。从实验结果可以看出,本文策略确实能够有效降低FID值,性能较MGAN有了明显的提升。
表 3 Cifar10上判别/分类器网络配置
模块 卷积核 采样 激活函数 参数共享 Image − − − − Residual Block [3×3]×2 DownSample ReLU y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 None − y Residual Block [3×3]×2 None − y Average Pooling − − ReLU y Linear − − Softmax/− n 表 4 Cifar10上IMGAN不同优化策略效果
模型 FID MGAN 30.7 更改参数共享 28.57 更改损失函数 26.35 更改参数共享、损失函数、网络结构 22.6 从图4的两种模型生成的图片对比来看,IMGAN较之MGAN生成的图片直观上体现了较大的差异性,也没有出现单个生成器的模式崩溃问题,体现出了更好的生成效果。
-
在CelebA数据集上的实验中,采用FID值来对模型的表现进行评价。网络优化同样采用Adam优化器;设置初始学习率为0.02,随训练轮数增加递减,设置Adam优化器的衰减参数
$\,{\beta _1}{\rm{ = }}0.00$ ,$\,{\beta _2}{\rm{ = }} $ $ 0.90$ ,设置超参数${\lambda _{{C}}} = 0.90$ 。由于CelebA的属性标记比Cifar10更加复杂,将调节模型生成样本多样性的超参数进一步增大,设置$\,\beta {\rm{ = }}0.10$ 。网络中多个生成器的结构如表5所示,判别/分类网络结构如表6所示。
表 5 CelebA上多个生成器网络配置
模块 卷积核 采样 激活函数 参数共享 Noise − − − − Linear − − ReLU n Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Conv2D 3×3 − tanh n 同在Cifar10数据集上的实验一样,分别测试:1)单独更改参数共享方案,解绑模型的最后一层参数;2) 单独更改损失函数,使用Wasserstein距离;3)在引入1)、2)优化的基础上再更改网络结构,引入残差块这3种场景来验证本文优化方法的效果。模型收敛时,FID指标对比如表7所示。
表 6 CelebA上判别/分类器网络配置
模块 卷积核 采样 激活函数 参数共享 Image − − − − Conv2D 3x3 − ReLU y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 DownSample − y Linear − − Softmax/− n 表 7 两模型在CelebA上的对比实验评估指标
模型 FID MGAN 9.263 更改参数共享 9.127 更改损失函数 8.994 更改参数共享、损失函数、网络结构 8.584 模型迭代100000轮后,得到的生成样本对比如图5所示。直观上来观察模型的生成效果,IMGAN生成的人脸更加清晰和真实,FID值也比原模型下降了0.679,这说明本文的模型在CelebA数据集上能够进一步提高生成图片的质量。
A Generation Adversarial Network Based on Multi-Condition Confrontation and Gradient Optimization
-
摘要: 该文针对模式崩溃的问题,从多生成器博弈强迫每个生成器生成不同模式数据的思路出发,提出了一种基于多生成器的生成对抗网络(IMGAN)。IMGAN在多个生成器之间采用参数共享的方式来加速训练,同时采用最后一层独立训练的方式来弱化参数同一性所带来的影响;引入一个正则惩罚项使得损失函数可以更好地满足Lipschitz连续,一定程度上避免了梯度消失带来的影响;引入一个超参数来解决多重损失函数带来的差异性问题,避免过度偏向其中某一种梯度方向。最后,通过在多个数据集上的对比实验验证了该文模型的表现和性能。Abstract: Aiming at the problem of pattern collapse, this paper starts from the idea of forcing each generator to generate different pattern data in a multi-generator game, and proposes a multi-generator-based generation confrontation network, named improved multi-generator generative adversarial nets (IMGAN). IMGAN uses parameter sharing between multiple generators to speed up training, and at the same time uses the last layer of independent training to weaken the impact of parameter identity; introduces a regular penalty term to make the loss function better satisfy Lipschitz continuousness, which avoids the effect of gradient disappearance to a certain extent; and introduces a hyperparameter to solve the disparity problem caused by multiple loss functions and avoid excessive bias toward one of the gradient directions. At last, we verify the performance of our model through comparative experiments on multiple data sets.
-
表 1 实验参数设置
环境项 参数 生成器数量 10 ${\lambda _{gp}}$ 10 $({\pi _1},{\pi _2},\cdots,{\pi _K})$ ${1}/{K}$ 学习率 0.02 表 2 Cifar10上多个生成器网络配置
模块 卷积核 采样 激活函数 参数共享 Noise − − − − Linear − − ReLU n Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Conv2D 3×3 − tanh n 表 3 Cifar10上判别/分类器网络配置
模块 卷积核 采样 激活函数 参数共享 Image − − − − Residual Block [3×3]×2 DownSample ReLU y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 None − y Residual Block [3×3]×2 None − y Average Pooling − − ReLU y Linear − − Softmax/− n 表 4 Cifar10上IMGAN不同优化策略效果
模型 FID MGAN 30.7 更改参数共享 28.57 更改损失函数 26.35 更改参数共享、损失函数、网络结构 22.6 表 5 CelebA上多个生成器网络配置
模块 卷积核 采样 激活函数 参数共享 Noise − − − − Linear − − ReLU n Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Residual Block [3×3]×2 UpSample − y Conv2D 3×3 − tanh n 表 6 CelebA上判别/分类器网络配置
模块 卷积核 采样 激活函数 参数共享 Image − − − − Conv2D 3x3 − ReLU y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 DownSample − y Residual Block [3×3]×2 DownSample − y Linear − − Softmax/− n 表 7 两模型在CelebA上的对比实验评估指标
模型 FID MGAN 9.263 更改参数共享 9.127 更改损失函数 8.994 更改参数共享、损失函数、网络结构 8.584 -
[1] GOODFELLOW I, POUGET-ABADIE J, MIRZA M, et al. Generative adversarial nets[J]. Advances in Neural Information Processing Systems, 2014, 27: 2672-2680. [2] LIU G, REDA F A, SHIH K J, et al. Image inpainting for irregular holes using partial convolutions[EB/OL]. [2020-10-11]. https://arxiv.org/pdf/1804.07723.pdf. [3] DONG H W, HSIAO W Y, YANG L C, et al. Musegan: Multi-track sequential generative adversarial networks for symbolic music generation and accompaniment[EB/OL]. [2020-10-15]. https://arxiv.org/pdf/1709.06298v2.pdf. [4] RAJESWAR S, SUBRAMANIAN S, DUTIL F, et al. Adversarial generation of natural language[EB/OL]. [2020-10-20]. https://arxiv.org/pdf/1705.10929.pdf. [5] NARAYANAN H, MITTER S. Sample complexity of testing the manifold hypothesis[C]//Proceedings of the 23rd International Conference on Neural Information Processing Systems. Vancouver, Canada: [s.n.], 2010, 2: 1786-1794. [6] LEI N, GUO Y, AN D, et al. Mode collapse and regularity of optimal transportation maps[EB/OL]. [2020-10-15]. https://arxiv.org/pdf/1902.02934v1.pdf. [7] ARJOVSKY M, BOTTOU L. Towards principled methods for training generative adversarial networks[EB/OL]. [2020-10-15]. https://arxiv.org/pdf/1701.04862.pdf. [8] TOLSTIKHIN I, GELLY S, BOUSQUET O, et al. Adagan: Boosting generative models[EB/OL]. [2020-10-18]. https://arxiv.org/pdf/1701.02386.pdf. [9] IM D J, KIM C D, JIANG H, et al. Generating images with recurrent adversarial networks[EB/OL]. [2020-10-15]. https://arxiv.org/pdf/1602.05110.pdf. [10] IM D J, MA H, KIM C D, et al. Generative adversarial parallelization[EB/OL]. [2020-10-15]. https://arxiv.org/pdf/1612.04021.pdf. [11] LI D, CHEN D, JIN B, et al. Madgan: Multivariate anomaly detection for time series data with generative adversarial networks[EB/OL]. [2020-10-12]. https://arxiv.org/pdf/1901.04997.pdf. [12] ARJOVSKY M, CHINTALA S, BOTTOU L. Wasserstein generative adversarial networks[C]//International Conference on Machine Learning. [S.l.]: PMLR, 2017: 214-223. [13] GULRAJANI I, AHMED F, ARJOVSKY M, et al. Improved training of wasserstein gans[EB/OL]. [2020-10-15]. arXiv preprint arXiv: 1704.00028, 2017. [14] HOANG Q, NGUYEN T D, LE T, et al. Multi-generator generative adversarial nets[EB/OL]. [2020-10-12]. https://arxiv.org/pdf/1708.02556.pdf. [15] ARJOVSKY M, BOTTOU L. Towards principled methods for training generative adversarial networks[EB/OL]. [2020-10-15]. https://arxiv.org/pdf/1701.04862.pdf. [16] HE K, ZHANG X, REN S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. [S.l.]: IEEE, 2016: 770-778. [17] ZHANG H, GOODFELLOW I, METAXAS D, et al. Self-attention generative adversarial networks[C]//International Conference on Machine Learning. [S.l.]: PMLR, 2019: 7354-7363.