生成对抗网络——GAN(代码+理解)

生成对抗网络——GAN(代码+理解)一 GAN 模型介绍 GAN 全称 GenerativeAd 即生成对抗网络 是一种基于对抗学习的深度生成模型

大家好,欢迎来到IT知识分享网。

目录

一、GAN模型介绍

二、GAN模型的训练过程

1. 初始化网络:

2. 训练判别器:

3. 训练生成器:

4. 重复步骤 2和步骤 3:

三、GAN实现

1. 模型结构

(1)生成器(Generator)

(2)判别器(Discriminator)

2. 代码

3. 运行结果展示

四、GAN模型的应用领域

五、学习中产生的疑问,及文心一言回答

1. 生成器(Generator)模型理解

2. 为什么要使用 block 函数?

3. 函数解释

4. 为什么要将像素值从[0, 255]缩放到[0.0, 1.0] ?

5. 详细解释一下是怎样对Tensor进行标准化的,以及为什么要这么做?


一、GAN模型介绍

        GAN,全称 Generative Adversarial Network,即生成对抗网络,是一种基于 对抗学习的深度生成模型。该模型由Ian Goodfellow在 2014年 首次提出,并迅速成为 学术界研究的热点,推动了生成模型领域的发展。

        GAN模型主要由两部分组成:生成器(Generator)和判别器(Discriminator)

   1. 生成器:生成器模型可以是 任意结构的神经网络,其 输入是 随机噪声torch.randn,输出则是 生成的样本。生成器的 目标是使生成的样本尽可能接近真实样本的分布,以欺骗判别器

    2. 判别器:判别器模型同样可以是任意结构的神经网络,其 输入是真实样本或生成器生成的样本,输出是一个 概率值,表示 输入样本是真实样本的概率。判别器的 目标是尽可能准确地判断输入样本是真实样本还是生成样本

        这两个网络 在训练过程中 相互对抗、相互优化,形成了一种 零和博弈

生成对抗网络——GAN(代码+理解)

二、GAN模型的训练过程

1. 初始化网络

        生成器(Generator)和判别器(Discriminator)的参数需要随机初始化。这两个网络都是神经网络,通常使用随机权重开始训练。

# 实例化 generator = Generator() discriminator = Discriminator()

2. 训练判别器

1从真实数据集中 随机选择一批数据,将其输入到 判别器中进行训练。

for i, (img, _) in enumerate(dataloader): # 内层迭代次数为 10000 // 64 = 157次,每次 64个数据

2同时,从生成器的当前状态生成一批假数据(也称为生成数据),也将这些数据输入到判别器中进行训练。

# 假数据的生成 fake_img = torch.randn(size, 100)

3在训练判别器时,需要固定生成器的参数.detach函数),只更新判别器的参数。

output_fake = generator(fake_img) fake_socre = discriminator(output_fake.detach()) 

(4判别器的目标是将真实数据和假数据区分开来,因此其损失函数通常定义为二元交叉熵损失函数.BCELoss函数

# 损失函数 criterion = torch.nn.BCELoss() # 对应 Sigmoid,计算二元交叉墒损失

(5使用反向传播算法更新判别器的参数,以最小化损失函数。

3. 训练生成器

1生成器的目标是生成与真实数据相似的假数据,使得判别器无法区分真实数据和假数据。

2生成器的 损失函数通常定义为 判别器对 假数据的 输出结果的 交叉熵损失函数的 相反数。换句话说,生成器希望判别器 对假数据的判断结果 尽可能接近真实数据

fake_G_socre = discriminator(output_fake) G_fake_loss = criterion(fake_G_socre, torch.ones_like(fake_G_socre))

(3同样使用反向传播算法更新生成器的参数,以最小化其损失函数。

4. 重复步骤 2和步骤 3

1在每一轮训练中,先 训练判别器,然后 训练生成器。这样可以确保两个网络都能得到足够的优化。

2重复这个过程,直到达到预设的迭代次数或满足某种收敛条件(如生成器生成的假数据与真实数据的差距达到一定程度本案例没有设计)。

三、GAN实现

1. 模型结构

(1)生成器(Generator)

生成对抗网络——GAN(代码+理解)

(2)判别器(Discriminator)

生成对抗网络——GAN(代码+理解)

2. 代码

import torch import torch.nn as nn import torchvision.transforms as transforms from torchvision import datasets from torch.utils.data import DataLoader import numpy as np import matplotlib.pyplot as plt mnist = datasets.MNIST( root='./others/', train=False, download=False, transform=transforms.Compose([ transforms.Resize((28, 28)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]) ]) ) dataloader = DataLoader( dataset=mnist, batch_size=64, shuffle=True ) def gen_img_plot(model, epoch, text_input): prediction = np.squeeze(model(text_input).detach().cpu().numpy()[:16]) plt.figure(figsize=(4, 4)) for i in range(16): plt.subplot(4, 4, i + 1) plt.imshow((prediction[i] + 1) / 2) plt.axis('off') plt.show() # 生成器定义 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2)) return layers self.mean = nn.Sequential( *block(100, 256, normalize=False), *block(256, 512), *block(512, 1024), nn.Linear(1024, 28 * 28), nn.Tanh() ) def forward(self, x): imgs = self.mean(x) imgs = imgs.view(-1, 1, 28, 28) return imgs class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.mean = nn.Sequential( nn.Linear(28 * 28, 512), nn.LeakyReLU(0.2), nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1), nn.Sigmoid() ) def forward(self, x): x = x.view(-1, 28 * 28) img = self.mean(x) # 对 64条数据的每一条都进行模型运算 return img # 实例化 generator = Generator() discriminator = Discriminator() # 定义优化器 G_Apim = torch.optim.Adam(generator.parameters(), lr=0.0001) D_Apim = torch.optim.Adam(discriminator.parameters(), lr=0.0002) # 损失函数 criterion = torch.nn.BCELoss() # 对应 Sigmoid,计算二元交叉墒损失 epoch_num = 100 G_loss_save = [] D_loss_save = [] for epoch in range(epoch_num): # 将 10000 条数据迭代了两遍 G_epoch_loss = 0 D_epoch_loss = 0 count = len(dataloader) for i, (img, _) in enumerate(dataloader): # 内层迭代次数为 10000 // 64 = 157次,每次 64个数据 # 训练 Discriminator # 判断出假的 size = img.size(0) # 0 维有多少个数据 fake_img = torch.randn(size, 100) output_fake = generator(fake_img) fake_socre = discriminator(output_fake.detach()) # .detach() 返回一个关闭梯度的 output_fake,这样前向传播不会修改 generater 的 grad D_fake_loss = criterion(fake_socre, torch.zeros_like(fake_socre)) # 判断出真的 real_socre = discriminator(img) D_real_loss = criterion(real_socre, torch.ones_like(real_socre)) D_loss = D_fake_loss + D_real_loss D_Apim.zero_grad() D_loss.backward() D_Apim.step() # 训练 Generater # G_fake_img = torch.randn(size, 100) # G_output_fake = generator(G_fake_img) # fake_G_socre = discriminator(G_output_fake) fake_G_socre = discriminator(output_fake) G_fake_loss = criterion(fake_G_socre, torch.ones_like(fake_G_socre)) G_Apim.zero_grad() G_fake_loss.backward() G_Apim.step() with torch.no_grad(): # 其中所有的 requires_grad 都被默认设置为 False G_epoch_loss += G_fake_loss D_epoch_loss += D_loss with torch.no_grad(): G_epoch_loss /= count D_epoch_loss /= count G_loss_save.append(G_epoch_loss.item()) D_loss_save.append(D_epoch_loss.item()) print('Epoch: [%d/%d] | G_loss: %.3f | D_loss: %.3f' % (epoch, epoch_num, G_epoch_loss, D_epoch_loss)) text_input = torch.randn(64, 100) gen_img_plot(generator, epoch, text_input) x = [epoch + 1 for epoch in range(epoch_num)] plt.figure() plt.plot(x, G_loss_save, 'r') plt.plot(x, D_loss_save, 'b') plt.ylabel('loss') plt.xlabel('epoch') plt.legend(['G_loss','D_loss']) plt.show()

3. 运行结果展示

生成对抗网络——GAN(代码+理解)

生成对抗网络——GAN(代码+理解)

四、GAN模型的应用领域

        GAN模型 在图像生成、视频生成、文本生成等领域都有 广泛的应用。例如,在图像生成领域,GAN模型可以生成高质量的图像、进行图像修复 和 超分辨率重建 等任务;在视频生成领域,GAN模型可以生成连贯的视频序列;在文本生成领域,GAN模型可以生成逼真的文本内容等。此外,GAN模型还在 AI绘画领域 发挥着 重要作用,成为AI绘画工作流中的 关键辅助模型。

五、学习中产生的疑问,及文心一言回答

1. 生成器(Generator)模型理解

生成对抗网络——GAN(代码+理解)

2. 为什么要使用 block 函数?

生成对抗网络——GAN(代码+理解)

3. 函数解释

生成对抗网络——GAN(代码+理解)

4. 为什么要将像素值从[0, 255]缩放到[0.0, 1.0] ?

生成对抗网络——GAN(代码+理解)

5. 详细解释一下是怎样对Tensor进行标准化的,以及为什么要这么做?

生成对抗网络——GAN(代码+理解)


                                                后续更新GAN的其他模型结构。

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/129095.html

(0)
上一篇 2025-08-26 19:20
下一篇 2025-08-26 19:26

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信