"""DCGAN training script for 64x64 RGB images loaded from ``./data``.

Defines a discriminator (``Detector``) and a ``Generator``, builds the data
pipeline, models and optimizers at module level, and trains both networks
with binary cross-entropy. A grid of samples generated from a fixed noise
batch is written to ``results/`` every 5 epochs.
"""

import os

import torch
import torch.multiprocessing
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from torchvision.datasets import ImageFolder
from tqdm import tqdm


# ----------------------------
# 1. Discriminator
# ----------------------------
class Detector(nn.Module):
    """Convolutional discriminator mapping a 3x64x64 image to one probability.

    Each strided convolution (kernel 4, stride 2, padding 1) halves the
    spatial size; the final 2x2 convolution plus sigmoid yields an output of
    shape ``(batch, 1, 1, 1)`` with values in ``(0, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 3 x 64 x 64 -> 64 x 32 x 32 (no batch norm on the first layer)
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 32 x 32 -> 128 x 16 x 16
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 16 x 16 -> 256 x 8 x 8
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 8 x 8 -> 512 x 4 x 4
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 4 x 4 -> 1024 x 2 x 2
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            # 1024 x 2 x 2 -> 1 x 1 x 1
            nn.Conv2d(1024, 1, kernel_size=2, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the per-image "real" probability for the batch ``x``."""
        return self.model(x)


# ----------------------------
# 2. Generator
# ----------------------------
class Generator(nn.Module):
    """Transposed-convolution generator mapping noise to a 3x64x64 image.

    Args:
        z_dim: Number of channels of the latent input, which is expected to
            have shape ``(batch, z_dim, 1, 1)``.

    The final ``Tanh`` keeps outputs in ``[-1, 1]``, the same range the
    dataset transform below normalizes real images to.
    """

    def __init__(self, z_dim=100):
        super().__init__()
        self.model = nn.Sequential(
            # z -> 1024 x 4 x 4
            nn.ConvTranspose2d(z_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # 1024 x 4 x 4 -> 512 x 8 x 8
            nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 512 x 8 x 8 -> 256 x 16 x 16
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 256 x 16 x 16 -> 128 x 32 x 32
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 128 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return generated images for the latent batch ``z``."""
        return self.model(z)


# ----------------------------
# 3. Dataset loading
# ----------------------------
# Resize + center-crop to 64x64, then map each channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = ImageFolder(root='./data', transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

# ----------------------------
# 4. Hyper-parameters and device
# ----------------------------
z_dim = 100
num_epochs = 100
# NOTE(review): DCGAN setups commonly use 0.0002; confirm 0.002 is intended.
lr = 0.002
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ----------------------------
# 5. Models and optimizers
# ----------------------------
d = Detector().to(device)
g = Generator(z_dim=z_dim).to(device)

criterion = nn.BCELoss()
optimizer_d = torch.optim.Adam(d.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_g = torch.optim.Adam(g.parameters(), lr=lr, betas=(0.5, 0.999))

# Fixed noise batch so sample grids are comparable across epochs.
z_fixed = torch.randn(64, z_dim, 1, 1, device=device)

# Create the output directory for sample grids.
os.makedirs("results", exist_ok=True)


# ----------------------------
# 6. Training loop
# ----------------------------
def train():
    """Train the discriminator and generator for ``num_epochs`` epochs.

    Uses the module-level ``dataloader``, models, optimizers and
    ``criterion``. For every batch the discriminator is updated first (on
    real images and detached fakes), then the generator is updated through
    the freshly stepped discriminator. Per-epoch mean losses are printed,
    and every 5 epochs an 8-per-row grid generated from ``z_fixed`` is saved
    under ``results/``.
    """
    num_batches = len(dataloader)

    for epoch in range(1, num_epochs + 1):
        loss_d_total, loss_g_total = 0.0, 0.0

        progress = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False)
        for real_images, _ in progress:
            real_images = real_images.to(device)
            # The last batch may be smaller than the configured batch size.
            batch_size = real_images.size(0)

            # Label smoothing: real = 0.9, fake = 0.0.
            real_labels = torch.full((batch_size, 1, 1, 1), 0.9, device=device)
            fake_labels = torch.zeros((batch_size, 1, 1, 1), device=device)

            # Generate fake images.
            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_images = g(z)

            # --- Discriminator step ---
            output_real = d(real_images)
            # detach() keeps this loss from back-propagating into the generator.
            output_fake = d(fake_images.detach())
            loss_real = criterion(output_real, real_labels)
            loss_fake = criterion(output_fake, fake_labels)
            loss_d = loss_real + loss_fake

            optimizer_d.zero_grad()
            loss_d.backward()
            optimizer_d.step()

            # --- Generator step ---
            output = d(fake_images)
            # Fool D: the generator wants D to label its fakes as real.
            loss_g = criterion(output, real_labels)

            optimizer_g.zero_grad()
            loss_g.backward()
            optimizer_g.step()

            loss_d_total += loss_d.item()
            loss_g_total += loss_g.item()

        # Mean loss per batch for this epoch.
        avg_loss_d = loss_d_total / num_batches
        avg_loss_g = loss_g_total / num_batches
        print(f"Epoch [{epoch}/{num_epochs}]  Loss_D: {avg_loss_d:.4f}  Loss_G: {avg_loss_g:.4f}")

        # Save a grid of generated images every 5 epochs.
        if epoch % 5 == 0:
            with torch.no_grad():
                fake = g(z_fixed).cpu()
            utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8)


if __name__ == "__main__":
    torch.multiprocessing.freeze_support()  # Windows compatibility
    # Training is wrapped in a function and only started under the main guard.
    train()