AI算法Python实现:BGAN(Boundary Seeking GAN)

AI算法Python实现:BGAN(Boundary Seeking GAN)文章介绍了生成对抗网络 GAN 的基本原理 包括生成器和判别器的相互作用

大家好,欢迎来到IT知识分享网。

目录

一、原理

1.1 GAN简单介绍

1.2 Boundary Seeking原理

1.2 BGAN原理

二、算法实现

三、参考资料


一、原理

1.1 GAN简单介绍

GAN(Generative Adversarial Network)主要是作为一种生成模型被广泛使用,它其实包含了两个模型,一个是生成模型(Generative Model),一个是判别模型(Discriminative Model),即生成器和判别器。GAN利用两者相互竞争来学习目标(数据)的分布,生成器会尝试欺骗判别器,让它认为生成的样本是真实的;判别器会尝试区分真实的样本和生成的样本。具体流程如下所示:

AI算法Python实现:BGAN(Boundary Seeking GAN)

GAN的目标函数如下:

 \underset{G}{min}\underset{D}{max}V\left ( D, G \right )=E_{x\sim p_{data}\left ( x \right )}\left [ log D\left ( x \right ) \right ]+E_{z\sim p_{z}\left ( z \right )}\left [ log\left ( 1-D\left ( G\left ( z \right ) \right ) \right ) \right ]

训练过程中固定一方,更新另一个网络的参数,交替迭代。但是原始的GAN训练有着两个显著的缺陷:难以训练离散数据以及训练困难。而BGAN可以很好的解决这两点。

1.2 Boundary Seeking原理

Boundary Seeking是一种训练GAN的方式,它让生成器不直接依赖于判别器的输出,而是去寻找一个目标分布的边界,这个目标分布在理想情况下会和数据分布一致。这样做有两个好处:一是可以处理离散数据,比如文本或图像;二是可以避免GAN训练过程中出现的不稳定性或模式崩溃。

我们可以把目标分布的边界想象成一个圆形的围栏,里面有很多真实数据,比如二进制序列。生成器要尽量产生一些靠近围栏的样本,也就是说和真实数据很相似的样本。这样判别器就很难发现生成器产生的样本和真实数据之间的区别。如果生成器产生一些远离围栏的样本,比如非二进制序列,那么判别器就很容易识别出来,并给出一个很低的得分。这个得分就是生成器要优化的目标函数。

生成的数据如果在围栏的中心,也是和真实的数据很相似,但是这样的话,生成器就没有办法探索更多的可能性。因为在围栏的中心,生成器产生的样本和真实数据之间的距离都很小,判别器给出的得分都很高,生成器就没有梯度来更新参数。而如果生成器产生一些靠近边界的样本,那么判别器给出的得分就会有一定的变化,生成器就可以根据这个变化来调整参数。这样生成器就可以学习到更多的数据特征,并且避免了模式崩溃(mode collapse)。

1.2 BGAN原理

BGAN采用Boundary Seeking的方法对GAN进行训练,引入策略梯度(Policy Gradient)来解决离散值导致价值函数不是处处可微的问题。引入策略梯度后GAN不再直接根据是否骗过判别网络调整生成网络,而是间接基于判别网络的评价计算目标,可以提高训练的稳定度。

原始GAN论文中表示,最优的判别器为:

D_{G}^{*}(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)}

因此,如果我们知道每个生成器对应的最优判别器就可以重新整理上面的方程,最终变成下面这样:

p_{data}(x)=p_{g}(x)\frac{D_{G}^{*}(x)}{1-D_{G}^{*}(x)}

从这个方程我们可以看出,即使我们没有得到最优的生成器G,仍然可以通过调整p_{g}(x)、生成器的分布、生成器与判别器的比例,得到真实数据的分布。虽然我们很难得到最优的判别器,但是,我们可以通过不断地训练D(x)来迫近它,我们的训练效果也将越来越好。

如果我们训练出来的生成器足够完美,那么p_{g}(x)将无限接近于p_{data}(x),判别器将无法判断生成样本和真实样本之间的区别,即D(x)=0.5。因此最优的生成器就是能使判别器处处都为0.5的那个。这个D(x)=0.5便是我们要找的决策边界,也就是上面提到的基于判别网络的评价计算目标。这样的话,我们可以调整生成器的目标函数,使得判别器的输出都为0.5。新的生成器目标函数如下:

\underset{G}{min}E_{x\sim p_{G}(x)}[0.5(log(D(x)-log(1-D(x))))^{2}]

 其目标函数的目的是减少D(x)1-D(x)之间的距离,即使D(x)=0.5

二、算法实现

  • models
    • BGAN.py
    • __init__.py
  • data
    • mnist
  • train.py

BGAN.py

import numpy as np import torch.nn as nn class Generator(nn.Module): def __init__(self, latent_dim, image): super(Generator, self).__init__() self.latent_dim = latent_dim self.image = image def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *block(self.latent_dim, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(self.image))), nn.Tanh() ) def forward(self, z): img = self.model(z) img = img.view(img.shape[0], self.image[0], self.image[1], self.image[2]) return img class Discriminator(nn.Module): def __init__(self, image): super(Discriminator, self).__init__() self.image = image self.model = nn.Sequential( nn.Linear(int(np.prod(self.image)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid(), ) def forward(self, img): img_flat = img.view(img.shape[0], -1) validity = self.model(img_flat) return validity 

train.py

import os import argparse import torch import numpy as np import torchvision.transforms as transforms from torchvision.utils import save_image from torch.utils.data import DataLoader from torchvision import datasets from torch.autograd import Variable from models.BGAN import Generator, Discriminator os.makedirs("images", exist_ok=True) def parser_args(): parser = argparse.ArgumentParser() parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training") parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation") parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension") parser.add_argument("--channels", type=int, default=1, help="number of image channels") parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples") args = parser.parse_args() return args def boundary_seeking_loss(y_pred): """ Boundary seeking loss. """ return 0.5 * torch.mean((torch.log(y_pred) - torch.log(1 - y_pred)) 2) def train(gen, disc, disc_loss, device, dataloader, optim_G, optim_D, n_epochs, latent_dim, sample_interval): gen.to(device) disc.to(device) disc_loss.to(device) tensor = torch.cuda.FloatTensor for epoch in range(n_epochs): for i, (img, _) in enumerate(dataloader): # Adversarial ground truths valid = Variable(tensor(img.shape[0], 1).fill_(1.0), requires_grad=False) fake = Variable(tensor(img.shape[0], 1).fill_(0.0), requires_grad=False) # Configure input real_img = Variable(img.type(tensor)) # ----------------- # Train Generator # ----------------- optim_G.zero_grad() # Sample noise as generator input z = Variable(tensor(np.random.normal(0, 1, (img.shape[0], latent_dim)))) # Generate a batch of images gen_img = gen(z) # Loss measures generator's ability to fool the discriminator g_loss = boundary_seeking_loss(disc(gen_img)) g_loss.backward() optim_G.step() # --------------------- # Train Discriminator # --------------------- optim_D.zero_grad() # Measure discriminator's ability to classify real from generated samples real_loss = disc_loss(disc(real_img), valid) fake_loss = disc_loss(disc(gen_img.detach()), fake) d_loss = (real_loss + fake_loss) / 2 d_loss.backward() optim_D.step() if i % 100 == 0: print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()) ) batches_done = epoch * len(dataloader) + i if batches_done % sample_interval == 0: save_image(gen_img.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True) def main(): args = parser_args() img_shape = (args.channels, args.img_size, args.img_size) device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # Initialize generator and discriminator gen = Generator(args.latent_dim, img_shape) disc = Discriminator(img_shape) disc_loss = torch.nn.BCELoss() # Configure data loader os.makedirs("./data/mnist", exist_ok=True) dataloader = torch.utils.data.DataLoader( datasets.MNIST( "./data/mnist", train=True, download=True, transform=transforms.Compose( [transforms.Resize(args.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])] ), ), batch_size=args.batch_size, shuffle=True, ) # Optimizers optim_g = torch.optim.Adam(gen.parameters(), lr=args.lr, betas=(args.b1, args.b2)) optim_d = torch.optim.Adam(disc.parameters(), lr=args.lr, betas=(args.b1, args.b2)) train(gen, disc, disc_loss, device, dataloader, optim_g, optim_d, args.n_epochs, args.latent_dim, args.sample_interval) return if __name__ == '__main__': main()

训练结果:

AI算法Python实现:BGAN(Boundary Seeking GAN)

三、参考资料

1. 原论文地址

2. BGAN

3. 生成对抗网络(GAN)

4. PyTorch Implementation of Boundary Seeking GAN

5. BGAN:支持离散值、提升训练稳定性的新GAN训练方法

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/132019.html

(0)
上一篇 2025-08-03 18:20
下一篇 2025-08-03 18:26

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信