Files
python/Pytorch/Project/DCGAN1/test.py
e2hang 7673cc9279 DCGAN
2025-10-28 22:29:54 +08:00

177 lines
5.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
import torch.nn as nn
from torchvision import transforms, utils
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import torch.multiprocessing
# ----------------------------
# 1. Discriminator
# ----------------------------
class Detector(nn.Module):
    """DCGAN discriminator for 3 x 64 x 64 images.

    Five stride-2 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4 -> 2) while widening the channels
    (3 -> 64 -> 128 -> 256 -> 512 -> 1024). A final 2 x 2 convolution
    collapses the map to a single value, which a sigmoid squashes into
    a probability-like score of shape (batch, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        # Channel count entering the network followed by the output width
        # of each downsampling stage.
        widths = [3, 64, 128, 256, 512, 1024]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)
            )
            # The first stage has no batch norm; every later stage does.
            if stage > 0:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # 1024 x 2 x 2 -> 1 x 1 x 1, then squash to (0, 1).
        layers.append(
            nn.Conv2d(widths[-1], 1, kernel_size=2, stride=1, padding=0, bias=False)
        )
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the discriminator score for the image batch ``x``."""
        return self.model(x)
# ----------------------------
# 2. Generator
# ----------------------------
class Generator(nn.Module):
    """DCGAN generator mapping a latent vector to a 3 x 64 x 64 image.

    The input is expected as (batch, z_dim, 1, 1). A first transposed
    convolution projects it to 1024 x 4 x 4; each following stride-2
    transposed convolution doubles the spatial size (4 -> 8 -> 16 -> 32 -> 64)
    while narrowing the channels (1024 -> 512 -> 256 -> 128 -> 3). The final
    tanh bounds the output to [-1, 1].

    Args:
        z_dim: Number of channels of the latent input.
    """

    def __init__(self, z_dim=100):
        super().__init__()
        # z -> 1024 x 4 x 4
        stages = [
            nn.ConvTranspose2d(z_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        ]
        # Three identical upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32.
        for c_in, c_out in ((1024, 512), (512, 256), (256, 128)):
            stages += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # 128 x 32 x 32 -> 3 x 64 x 64; no batch norm on the output layer.
        stages += [
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return the image batch generated from the latent batch ``z``."""
        return self.model(z)
# ----------------------------
# 3. Dataset loading
# ----------------------------
# Side length (pixels) of the square images fed to the networks.
_IMAGE_SIZE = 64
# Per-channel mean and std passed to Normalize; (x - 0.5) / 0.5 per channel.
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# Preprocessing: resize, take the central 64 x 64 crop, convert to a
# tensor, then normalise each of the three channels.
transform = transforms.Compose(
    [
        transforms.Resize(_IMAGE_SIZE),
        transforms.CenterCrop(_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Images are read from ./data using ImageFolder's directory layout; the
# class labels it yields are ignored by the training loop.
dataset = ImageFolder(root='./data', transform=transform)

# NOTE(review): with num_workers > 0 on platforms that spawn worker
# processes (e.g. Windows), each worker re-imports this module, so
# module-level setup like this may run again in every worker. Consider
# moving it behind the __main__ guard -- confirm against the PyTorch
# DataLoader documentation.
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)
# ----------------------------
# 4. Hyperparameters and device
# ----------------------------
# Number of channels of the latent noise fed to the generator.
z_dim = 100
# Number of full passes over the dataset.
num_epochs = 100
# Adam step size shared by both optimizers. This was 0.002, ten times the
# 0.0002 recommended for DCGAN together with beta1 = 0.5 (which this script
# already uses); the larger value is prone to destabilising GAN training.
lr = 0.0002
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ----------------------------
# 5. Models and optimizers
# ----------------------------
# Build the discriminator first and the generator second, then move each to
# the selected device.
d = Detector().to(device)
g = Generator(z_dim=z_dim).to(device)

# Binary cross-entropy, applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Both networks are optimised with Adam using identical settings.
_adam_settings = {"lr": lr, "betas": (0.5, 0.999)}
optimizer_d = torch.optim.Adam(d.parameters(), **_adam_settings)
optimizer_g = torch.optim.Adam(g.parameters(), **_adam_settings)

# A fixed batch of 64 latent vectors, reused for every saved sample grid so
# that images from different epochs are directly comparable.
z_fixed = torch.randn(64, z_dim, 1, 1, device=device)

# Directory that receives the generated sample grids.
os.makedirs("results", exist_ok=True)
# ----------------------------
# 6. Training loop
# ----------------------------
def train():
    """Run adversarial training of the module-level generator and discriminator.

    Uses the module-level ``dataloader``, ``d``, ``g``, ``criterion``,
    ``optimizer_d``, ``optimizer_g``, ``z_fixed``, ``z_dim``, ``num_epochs``
    and ``device``. For each batch the discriminator is updated first, then
    the generator. After every epoch the mean discriminator and generator
    losses are printed, and every fifth epoch a grid of samples generated
    from ``z_fixed`` is written to ``results/epoch_<epoch>.png``.
    """
    # Batches per epoch; used to turn the summed losses into per-batch means.
    num_batches = len(dataloader)

    for epoch in range(1, num_epochs + 1):
        loss_d_total, loss_g_total = 0.0, 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False)

        for real_images, _ in progress:
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            label_shape = (batch_size, 1, 1, 1)

            # One-sided label smoothing: only the discriminator's target for
            # real images is softened (0.9); fakes stay at 0.0.
            smoothed_real_labels = torch.full(label_shape, 0.9, device=device)
            fake_labels = torch.zeros(label_shape, device=device)
            # The generator is trained against the unsmoothed "real" target.
            # Reusing the 0.9 labels here would penalise the generator as
            # soon as the discriminator scores its output above 0.9.
            generator_targets = torch.ones(label_shape, device=device)

            # Sample a fresh latent batch and generate fake images.
            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_images = g(z)

            # --- Discriminator step ---
            # detach() keeps this step's gradients out of the generator.
            output_real = d(real_images)
            output_fake = d(fake_images.detach())
            loss_d = criterion(output_real, smoothed_real_labels) + criterion(
                output_fake, fake_labels
            )
            optimizer_d.zero_grad()
            loss_d.backward()
            optimizer_d.step()

            # --- Generator step ---
            # Re-score the fakes with the just-updated discriminator; the
            # generator wants them classified as real.
            output = d(fake_images)
            loss_g = criterion(output, generator_targets)
            optimizer_g.zero_grad()
            loss_g.backward()
            optimizer_g.step()

            loss_d_total += loss_d.item()
            loss_g_total += loss_g.item()

        # Mean loss per batch for this epoch.
        avg_loss_d = loss_d_total / num_batches
        avg_loss_g = loss_g_total / num_batches
        print(f"Epoch [{epoch}/{num_epochs}] Loss_D: {avg_loss_d:.4f} Loss_G: {avg_loss_g:.4f}")

        # Save a grid of generated samples every 5 epochs.
        if epoch % 5 == 0:
            with torch.no_grad():
                fake = g(z_fixed).cpu()
            utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8)
def _main():
    """Script entry point: prepare multiprocessing, then run training."""
    # Called for Windows compatibility before any training work starts.
    torch.multiprocessing.freeze_support()
    train()


if __name__ == "__main__":
    _main()