Files
python/Pytorch/Project/DCGAN1/main.py
e2hang 7673cc9279 DCGAN
2025-10-28 22:29:54 +08:00

245 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
import torch.nn as nn
from torchvision import transforms, utils
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torch.multiprocessing
from PIL import Image # 自定义数据集需要 import PIL
import matplotlib.pyplot as plt
# Discriminator
class Detector(nn.Module):
    """DCGAN discriminator for 3 x 64 x 64 images.

    A stack of stride-2 4x4 convolutions halves the spatial size five times
    (64 -> 32 -> 16 -> 8 -> 4 -> 2) while widening the channels up to 1024;
    a final 2x2 convolution plus sigmoid collapses the map to a single
    score of shape (N, 1, 1, 1) in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # 3 x 64 x 64 -> 64 x 32 x 32 (the first stage has no batch norm)
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # 64 x 32 x 32 -> 128 x 16 x 16 -> 256 x 8 x 8
        #             -> 512 x 4 x 4   -> 1024 x 2 x 2
        # The modules are appended in the same order as a hand-written
        # Sequential, so the positional state_dict keys stay the same.
        in_channels = 64
        for out_channels in (128, 256, 512, 1024):
            layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_channels = out_channels

        # 1024 x 2 x 2 -> 1 x 1 x 1
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=2, stride=1, padding=0, bias=False))
        layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-image score tensor produced by the layer stack."""
        scores = self.model(x)
        return scores
# Generator
class Generator(nn.Module):
    """DCGAN generator mapping a latent vector to a 3 x 64 x 64 image.

    Args:
        z_dim: Number of channels of the latent input, which is expected
            with shape (N, z_dim, 1, 1).

    The output passes through ``tanh``, so pixel values lie in [-1, 1].
    """

    def __init__(self, z_dim=100):
        super().__init__()

        # z -> 1024 x 4 x 4
        stages = [
            nn.ConvTranspose2d(z_dim, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
        ]

        # 1024 x 4 x 4 -> 512 x 8 x 8 -> 256 x 16 x 16 -> 128 x 32 x 32
        # Module order (and therefore the state_dict keys) matches the
        # explicit layer-by-layer listing.
        for c_in, c_out in ((1024, 512), (512, 256), (256, 128)):
            stages.extend(
                [
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(True),
                ]
            )

        # 128 x 32 x 32 -> 3 x 64 x 64
        stages.extend(
            [
                nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1, bias=False),
                nn.Tanh(),
            ]
        )

        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return the images generated from the latent batch ``z``."""
        images = self.model(z)
        return images
# Dataset loading (uses the custom path ./data/images)
class FlatImageDataset(Dataset):
    """Dataset over the image files located directly inside one directory.

    Sub-directories are not searched. Files are selected by extension
    (case-insensitive) and indexed in sorted filename order so that a given
    index always refers to the same file.

    Args:
        root_dir: Directory containing the images.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
        ValueError: If ``root_dir`` contains no supported image file.
    """

    # File extensions (compared in lower case) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Dataset directory '{root_dir}' not found. Please create it and add images.")
        # os.listdir returns entries in arbitrary order, so sort them; also
        # skip directories that merely happen to carry an image extension.
        self.image_files = sorted(
            name
            for name in os.listdir(root_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(root_dir, name))
        )
        if not self.image_files:
            raise ValueError(f"No valid image files found in '{root_dir}'. Supported: .png, .jpg, .jpeg, .bmp")
        self.image_paths = [os.path.join(root_dir, name) for name in self.image_files]

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for the file at position ``idx``.

        The image is converted to 3-channel RGB and passed through
        ``transform`` when one was given. The label is a constant
        placeholder; GAN training does not use it.
        """
        img_path = self.image_paths[idx]
        # Image.open keeps the underlying file open until the image object
        # is closed; the context manager releases the handle as soon as the
        # pixels have been copied out by convert().
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0
# Use the custom dataset
# Preprocessing pipeline: resize, centre-crop to 64 x 64, convert to a tensor
# and normalise every channel with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load the custom dataset.
dataset = FlatImageDataset('./data/images', transform=transform)
print(f"Loaded {len(dataset)} images from ./data/images")  # debug: dataset size
# num_workers is kept small to avoid problems on Windows.
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

# Hyperparameters.
z_dim = 100
num_epochs = 100  # original author's note: "6.2w" (meaning not evident from this file)
lr_d = 0.001
lr_g = 0.002
n_g_steps = 2  # generator updates per batch; standard DCGAN uses 1
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Models and optimisers.
d = Detector().to(device)
g = Generator(z_dim=z_dim).to(device)
criterion = nn.BCELoss()
optimizer_d = torch.optim.Adam(
    d.parameters(),
    lr=lr_d,
    betas=(0.5, 0.999),
)
optimizer_g = torch.optim.Adam(
    g.parameters(),
    lr=lr_g,
    betas=(0.5, 0.999),
)

# Fixed noise, reused to observe how the generator's samples evolve.
z_fixed = torch.randn(64, z_dim, 1, 1, device=device)

# Create the output directory.
os.makedirs("results", exist_ok=True)
# Generate images from a saved generator (G) state_dict
def generate_from_g_dict(model_path, z_dim=100, num_images=64, output_path='generated.png'):
    """Load a saved generator state_dict and write a grid of sampled images.

    Args:
        model_path: Path to a file holding a ``Generator`` state_dict.
        z_dim: Latent dimension the generator is built with.
        num_images: Number of images to sample.
        output_path: Destination of the image grid (8 images per row).

    Returns:
        None. When ``model_path`` does not exist, a message is printed and
        nothing is generated.
    """
    checkpoint_missing = not os.path.exists(model_path)
    if checkpoint_missing:
        print(f"Model path '{model_path}' not found. Skipping generation.")
        return

    # Rebuild the generator and restore its weights.
    generator = Generator(z_dim=z_dim).to(device)
    weights = torch.load(model_path, map_location=device)
    generator.load_state_dict(weights)
    generator.eval()

    # Sample fake images without tracking gradients.
    with torch.no_grad():
        noise = torch.randn(num_images, z_dim, 1, 1, device=device)
        samples = generator(noise).detach().cpu()

    # Save the grid to disk.
    utils.save_image(samples, output_path, normalize=True, nrow=8)
    print(f"Generated images saved to {output_path}")
# Training loop
def _save_loss_curve(loss_history, path="results/loss_curve.png"):
    """Plot the recorded per-epoch D/G losses on a fresh figure and save it."""
    # A new figure per call keeps earlier plots from piling up in the image;
    # closing it afterwards releases the figure's memory.
    fig = plt.figure()
    plt.plot(loss_history["D"], label="Loss_D")
    plt.plot(loss_history["G"], label="Loss_G")
    plt.legend()
    plt.savefig(path)
    plt.close(fig)


def train():
    """Run adversarial training using the module-level models and data.

    For every batch the discriminator is updated once (real labels smoothed
    to 0.9, fake labels 0.0) and the generator ``n_g_steps`` times on freshly
    sampled noise. After each epoch the average losses are printed, appended
    to the history, and the loss curve image is rewritten. Every second
    epoch a sample grid from ``z_fixed`` and the generator's state_dict are
    saved under ``results/``.

    Returns:
        dict: ``{"D": [...], "G": [...]}`` with one average loss per epoch.
    """
    # Created once, outside the epoch loop, so that it accumulates every epoch.
    loss_history = {"D": [], "G": []}

    for epoch in range(1, num_epochs + 1):
        loss_d_total, loss_g_total = 0.0, 0.0
        for real_images, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False):
            real_images = real_images.to(device)
            B = real_images.size(0)

            # Label smoothing: real = 0.9, fake = 0.0.
            real_labels = torch.full((B, 1, 1, 1), 0.9, device=device)
            fake_labels = torch.full((B, 1, 1, 1), 0.0, device=device)

            # Generate fake images for the discriminator step.
            z = torch.randn(B, z_dim, 1, 1, device=device)
            fake_images = g(z)

            # Discriminator update; detach() keeps gradients out of G here.
            output_real = d(real_images)
            output_fake = d(fake_images.detach())
            loss_real = criterion(output_real, real_labels)
            loss_fake = criterion(output_fake, fake_labels)
            loss_d = loss_real + loss_fake
            optimizer_d.zero_grad()
            loss_d.backward()
            optimizer_d.step()

            # Generator updates (standard BCE objective).
            for _ in range(n_g_steps):
                # Sample new fake images before every generator update.
                z = torch.randn(B, z_dim, 1, 1, device=device)
                fake_images = g(z)
                output = d(fake_images)
                # Fool D: G wants D to label the fakes as real.
                loss_g = criterion(output, real_labels)
                optimizer_g.zero_grad()
                loss_g.backward()
                optimizer_g.step()
                loss_g_total += loss_g.item()
            loss_d_total += loss_d.item()

        # Average losses; the G total was accumulated n_g_steps times per batch.
        avg_loss_d = loss_d_total / len(dataloader)
        avg_loss_g = loss_g_total / (len(dataloader) * n_g_steps)
        print(f"Epoch [{epoch}/{num_epochs}] Loss_D: {avg_loss_d:.4f} Loss_G: {avg_loss_g:.4f}")

        # Record this epoch and refresh the loss curve on disk.
        loss_history["D"].append(avg_loss_d)
        loss_history["G"].append(avg_loss_g)
        _save_loss_curve(loss_history)

        # Every 2 epochs, save a sample grid and the generator's state_dict.
        if epoch % 2 == 0:
            with torch.no_grad():
                fake = g(z_fixed).detach().cpu()
            utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8)
            g_state_dict_path = f"results/g_epoch_{epoch}.pth"
            torch.save(g.state_dict(), g_state_dict_path)
            print(f"Generator state_dict saved to {g_state_dict_path}")

    return loss_history
if __name__ == "__main__":
    torch.multiprocessing.freeze_support()
    # Training is currently disabled; uncomment the next line to retrain.
    #train()
    # Write 100 independently sampled image grids from one saved generator.
    # NOTE(review): train() in this file only saves checkpoints on even
    # epochs, so g_epoch_85.pth must come from another run - confirm the path.
    checkpoint = 'results/g_epoch_85.pth'
    for index in range(100):
        generate_from_g_dict(checkpoint, output_path=f'results/generated_after_train{index}.png')