【DCGAN】失败的对抗生成网络训练(上)

By e2hang at 4 天前 • 0人收藏 • 44人看过

注:本文章的图片可能令人感到不适,请谨慎观看

一、对抗生成网络(Generative Adversarial Network)

    整个网络类似一个制作假币集团和警察的对抗,假币集团希望警察不能发现他们的假币,警察希望正确地甄别真币与假币。这里需要引入一些博弈论的知识:Minimax-极大极小博弈理论。

·极大极小平衡点(Minimax Equilibrium)

    G(Generator,生成器)输入的是一个随机噪声,输出的一个伪造样本;

    D(Detector,检测器)输入的是一个样本,输出的是"D认为这个样本是真的的概率"


    对于检测器D,我们希望它对于正确的样本输出1,对于不正确的输出0,那么就应该以这个为目标输出设计梯度下降函数;

    LD=− Ex∼pdata[logD(x)] − Ez∼pz[log(1−D(G(z)))]

    同理对于生成器G,我们希望它能够欺骗检测器D,那么还是以检测器D对于本样本的输出为基础,我们希望它输出1,那么我们就以1为目标输出设计梯度下降函数

    LG=− Ez∼pz[logD(G(z))]


理论存在,实践开始,下面上代码;请注意关注代码中的参数,这对训练来讲至关重要

重要参数:lr_d, lr_g, n_g_step, real_labels, fake_labels


二、对抗生成网络代码

import os
import torch
import torch.nn as nn
from torchvision import transforms, utils
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torch.multiprocessing
from PIL import Image  # 自定义数据集需要 import PIL
import matplotlib.pyplot as plt


# 判别器
class Detector(nn.Module):
    def __init__(self):
        super(Detector, self).__init__()
        self.model = nn.Sequential(
            # 3 x 64 x 64 -> 64 x 32 x 32
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # 64 x 32 x 32 -> 128 x 16 x 16
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # 128 x 16 x 16 -> 256 x 8 x 8
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # 256 x 8 x 8 -> 512 x 4 x 4
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # 512 x 4 x 4 -> 1024 x 2 x 2
            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),

            # 1024 x 2 x 2 -> 1 x 1 x 1
            nn.Conv2d(1024, 1, 2, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


# 生成器
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            # z -> 1024 x 4 x 4
            nn.ConvTranspose2d(z_dim, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),

            # 1024 x 4 x 4 -> 512 x 8 x 8
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            # 512 x 8 x 8 -> 256 x 16 x 16
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # 256 x 16 x 16 -> 128 x 32 x 32
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # 128 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)


# 数据集加载(使用自定义路径 ./data/images)
class FlatImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Dataset directory '{root_dir}' not found. Please create it and add images.")
        self.image_files = [f for f in os.listdir(root_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))]
        if len(self.image_files) == 0:
            raise ValueError(f"No valid image files found in '{root_dir}'. Supported: .png, .jpg, .jpeg, .bmp")
        self.image_paths = [os.path.join(root_dir, f) for f in self.image_files]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # 确保转为 RGB(3 通道)
        if self.transform:
            image = self.transform(image)
        return image, 0  # 返回图像和虚拟标签(GAN 不使用)


# 使用自定义数据集
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载自定义数据集
dataset = FlatImageDataset(root_dir='./data/images', transform=transform)
print(f"Loaded {len(dataset)} images from ./data/images")  # 调试:打印数据集大小
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2)  # 减小 num_workers 避免 Windows 问题


# 参数
z_dim = 100
num_epochs = 100  # 6.2w
lr_d = 0.002
lr_g = 0.004
n_g_steps = 2  # 标准 DCGAN是1步G
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 模型与优化器
d = Detector().to(device)
g = Generator(z_dim=z_dim).to(device)
criterion = nn.BCELoss()
optimizer_d = torch.optim.Adam(d.parameters(), lr=lr_d, betas=(0.5, 0.999))
optimizer_g = torch.optim.Adam(g.parameters(), lr=lr_g, betas=(0.5, 0.999))

# 固定噪声用于观察训练过程
z_fixed = torch.randn(64, z_dim, 1, 1, device=device)

# 创建保存目录
os.makedirs("results", exist_ok=True)


# 根据保存的 G dict 生成图片的函数
def generate_from_g_dict(model_path, z_dim=100, num_images=64, output_path='generated.png'):
    """
    从保存的生成器 state_dict 文件加载模型,并生成图片保存。
    """
    if not os.path.exists(model_path):
        print(f"Model path '{model_path}' not found. Skipping generation.")
        return
    # 加载生成器并恢复权重
    g_loaded = Generator(z_dim=z_dim).to(device)
    g_loaded.load_state_dict(torch.load(model_path, map_location=device))
    g_loaded.eval()

    # 生成假图像
    with torch.no_grad():
        z = torch.randn(num_images, z_dim, 1, 1, device=device)
        fake_images = g_loaded(z).detach().cpu()

    # 保存图片
    utils.save_image(fake_images, output_path, normalize=True, nrow=8)
    print(f"Generated images saved to {output_path}")


# 训练循环
def train():
    for epoch in range(1, num_epochs + 1):
        loss_d_total, loss_g_total = 0, 0
        for real_images, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False):
            real_images = real_images.to(device)
            B = real_images.size(0)

            # 标签平滑(真实 = 0.9, 假 = 0.0)
            real_labels = torch.full((B, 1, 1, 1), 0.9, device=device)
            fake_labels = torch.full((B, 1, 1, 1), 0.1, device=device)

            # 生成假图像
            z = torch.randn(B, z_dim, 1, 1, device=device)
            fake_images = g(z)

            # 判别器训练
            output_real = d(real_images)
            output_fake = d(fake_images.detach())

            loss_real = criterion(output_real, real_labels)
            loss_fake = criterion(output_fake, fake_labels)
            loss_d = loss_real + loss_fake

            optimizer_d.zero_grad()
            loss_d.backward()
            optimizer_d.step()

            # 生成器训练(标准 BCE)
            for _ in range(n_g_steps):
                #每次生成器更新前重新生成假图像
                z = torch.randn(B, z_dim, 1, 1, device=device)
                fake_images = g(z)

                output = d(fake_images)
                loss_g = criterion(output, real_labels)  # 欺骗 D:希望 D 输出真

                optimizer_g.zero_grad()
                loss_g.backward()
                optimizer_g.step()
                loss_g_total += loss_g.item()

            loss_d_total += loss_d.item()

        # 平均损失(G 损失已累加 n_g_steps 次)
        avg_loss_d = loss_d_total / len(dataloader)
        avg_loss_g = loss_g_total / (len(dataloader) * n_g_steps)  # 修复:除以总 G 步数
        print(f"Epoch [{epoch}/{num_epochs}]  Loss_D: {avg_loss_d:.4f}  Loss_G: {avg_loss_g:.4f}")

        loss_history = {"D": [], "G": []}

        # 每轮结束时:
        loss_history["D"].append(avg_loss_d)
        loss_history["G"].append(avg_loss_g)

        # 最后画图:
        plt.plot(loss_history["D"], label="Loss_D")
        plt.plot(loss_history["G"], label="Loss_G")
        plt.legend()
        plt.savefig("results/loss_curve.png")

        # 每2轮保存一次生成图像和 G 的 state_dict
        if epoch % 2 == 0:
            with torch.no_grad():
                fake = g(z_fixed).detach().cpu()
                utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8)

            # 保存 G 的 state_dict(dict 形式)
            g_state_dict_path = f"results/g_epoch_{epoch}.pth"
            torch.save(g.state_dict(), g_state_dict_path)
            print(f"Generator state_dict saved to {g_state_dict_path}")


if __name__ == "__main__":
    torch.multiprocessing.freeze_support()
    train()
    #for i in range(100):
        #text = 'results/generated_after_train' + str(i) + '.png'
        #generate_from_g_dict('results/g_epoch_85.pth', output_path=text)

展示一部分训练集(64x64像素),一共6.3w条训练图像

00d130c5b926170312c4711d09b08c97.png

好的,那么代码写出来了,运行起来会是什么样子的呢?


三、训练期间的产出图片(图片可能令人不适)

第5次循环/共100次循环

epoch_5.png

第10次循环/共100次循环

epoch_10.png

第15次循环/共100次循环

epoch_15.png

第20次循环/共100次循环(初见端倪)

epoch_20.png

第25次循环/共100次循环

epoch_25.png

第30次循环/共100次循环

epoch_30.png

第35次循环/共100次循环(???)

epoch_35.png

第40次循环/共100次循环

epoch_40.png

第45次循环/共100次循环(何意味)

epoch_45.png

第50次循环/共100次循环

epoch_50.png

第55次循环/共100次循环(哎呦我)

epoch_55.png

第60次循环/共100次循环(?????)

epoch_60.png

第65次循环/共100次循环

epoch_65.png

第70次循环/共100次循环(此处更改了参数)

lr_d = 0.001, lr_g = 0.002, fake_labels的fill变成了0.0

epoch_70.png

第75次循环/共100次循环

epoch_75.png

第80次循环/共100次循环

epoch_80.png

第85次循环/共100次循环(最正常的一集)

epoch_85.png

第90次循环/共100次循环

epoch_90.png

第95次循环/共100次循环

epoch_95.png

第100次循环/共100次循环

epoch_100.png


后面我用第85次的保存过的模型输出了100张图,请仔细观察下面这张很有代表性的图

generated_after_train23.png

这个图很有代表性,左下角的6个图,只有略微的变化,大的形态上没有本质区别。其他生成的角色呢?也一样,就是这么几个人来回变。

好的,相信我们聪明的读者一定很容易发现问题吧!


四、原因分析

这就是 典型的 GAN 模式崩溃(Mode Collapse)。简单来说, Generator 学到了只生成几个样本的特定模式,哪怕加上一点随机噪声,输出还是那几个脸,只是微调了位置、颜色、亮度等。


1️⃣ 原因

  1. 模式崩溃(Mode Collapse)

  • GAN 的训练是对抗的:G 想骗 D,D 想识别真伪。

  • 如果 D 太强,G 找到“骗过 D 的最快方式” → 就学会重复生成几个样本。

  • 结果就是你看到的:总是几个固定人脸,变化很小。

  • 生成器容量限制 / 初始化问题

    • 如果 G 层数、通道数不够,或者初始化不好,容易陷入局部最优。

  • 训练超参数不合理

    • 学习率太高、n_g_steps>1 或标签过于极端 → 梯度不稳定,G 过快收敛到少数模式。

  • 数据多样性不足

    • 如果训练的数据集本身种类不够丰富,G 更容易崩溃到几个样本。


    2️⃣ 解决方案

  1. 降低学习率 / 调整更新比

  • DCGAN 官方推荐:lr=0.0002,D:G = 1:1。

  • 之前 n_g_steps=2 → 改回 1 会更稳定。

  • 标签平滑

    real_labels = 0.9
    fake_labels = 0.0
    • 避免 D 太快收敛。

  • 非饱和损失

    loss_g = -torch.mean(torch.log(d(fake_images) + 1e-8))
    • 可以缓解梯度消失,让 G 更新更稳定。

  • 增加数据集多样性

    • 数据集越多样,G 学到的模式越丰富。

  • 正则 / 初始化

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1 or classname.find('ConvTranspose') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)
    g.apply(weights_init)
    d.apply(weights_init)
    • 避免竖条和特定模式的偏置。

  • 尝试不同 GAN 结构

    • WGAN-GPLSGANStyleGAN 更适合稳定生成,减少模式崩溃。



    1 个回复 | 最后更新于 4 天前
    4 天前   #1

    炼丹名不虚传

    登录后方可回帖

    登 录
    信息栏
    欢迎来到滑稽社论坛!注册会员即可发帖!

    你好啊

    Loading...