DCGAN
This commit is contained in:
176
Pytorch/Project/DCGAN1/test.py
Normal file
176
Pytorch/Project/DCGAN1/test.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import transforms, utils
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import ImageFolder
|
||||
from tqdm import tqdm
|
||||
import torch.multiprocessing
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 1. 判别器
|
||||
# ----------------------------
|
||||
class Detector(nn.Module):
|
||||
def __init__(self):
|
||||
super(Detector, self).__init__()
|
||||
self.model = nn.Sequential(
|
||||
# 3 x 64 x 64 -> 64 x 32 x 32
|
||||
nn.Conv2d(3, 64, 4, 2, 1, bias=False),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
# 64 x 32 x 32 -> 128 x 16 x 16
|
||||
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
# 128 x 16 x 16 -> 256 x 8 x 8
|
||||
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
# 256 x 8 x 8 -> 512 x 4 x 4
|
||||
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
# 512 x 4 x 4 -> 1024 x 2 x 2
|
||||
nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
# 1024 x 2 x 2 -> 1 x 1 x 1
|
||||
nn.Conv2d(1024, 1, 2, 1, 0, bias=False),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 2. 生成器
|
||||
# ----------------------------
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, z_dim=100):
|
||||
super(Generator, self).__init__()
|
||||
self.model = nn.Sequential(
|
||||
# z -> 1024 x 4 x 4
|
||||
nn.ConvTranspose2d(z_dim, 1024, 4, 1, 0, bias=False),
|
||||
nn.BatchNorm2d(1024),
|
||||
nn.ReLU(True),
|
||||
|
||||
# 1024 x 4 x 4 -> 512 x 8 x 8
|
||||
nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(512),
|
||||
nn.ReLU(True),
|
||||
|
||||
# 512 x 8 x 8 -> 256 x 16 x 16
|
||||
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(256),
|
||||
nn.ReLU(True),
|
||||
|
||||
# 256 x 16 x 16 -> 128 x 32 x 32
|
||||
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(True),
|
||||
|
||||
# 128 x 32 x 32 -> 3 x 64 x 64
|
||||
nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, z):
|
||||
return self.model(z)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# 3. 数据集加载
|
||||
# ----------------------------
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(64),
|
||||
transforms.CenterCrop(64),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
dataset = ImageFolder(root='./data', transform=transform)
|
||||
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
|
||||
|
||||
# ----------------------------
|
||||
# 4. 参数与设备
|
||||
# ----------------------------
|
||||
z_dim = 100
|
||||
num_epochs = 100
|
||||
lr = 0.002
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# ----------------------------
|
||||
# 5. 模型与优化器
|
||||
# ----------------------------
|
||||
d = Detector().to(device)
|
||||
g = Generator(z_dim=z_dim).to(device)
|
||||
criterion = nn.BCELoss()
|
||||
optimizer_d = torch.optim.Adam(d.parameters(), lr=lr, betas=(0.5, 0.999))
|
||||
optimizer_g = torch.optim.Adam(g.parameters(), lr=lr, betas=(0.5, 0.999))
|
||||
|
||||
# 固定噪声用于观察训练过程
|
||||
z_fixed = torch.randn(64, z_dim, 1, 1, device=device)
|
||||
|
||||
# 创建保存目录
|
||||
os.makedirs("results", exist_ok=True)
|
||||
|
||||
# ----------------------------
|
||||
# 6. 训练循环
|
||||
# ----------------------------
|
||||
def train():
|
||||
for epoch in range(1, num_epochs + 1):
|
||||
loss_d_total, loss_g_total = 0, 0
|
||||
for real_images, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False):
|
||||
real_images = real_images.to(device)
|
||||
B = real_images.size(0)
|
||||
|
||||
# 标签平滑(真实 = 0.9, 假 = 0.0)
|
||||
real_labels = torch.full((B, 1, 1, 1), 0.9, device=device)
|
||||
fake_labels = torch.zeros((B, 1, 1, 1), device=device)
|
||||
|
||||
# 生成假图像
|
||||
z = torch.randn(B, z_dim, 1, 1, device=device)
|
||||
fake_images = g(z)
|
||||
|
||||
# 判别器训练
|
||||
output_real = d(real_images)
|
||||
output_fake = d(fake_images.detach())
|
||||
|
||||
loss_real = criterion(output_real, real_labels)
|
||||
loss_fake = criterion(output_fake, fake_labels)
|
||||
loss_d = loss_real + loss_fake
|
||||
|
||||
optimizer_d.zero_grad()
|
||||
loss_d.backward()
|
||||
optimizer_d.step()
|
||||
|
||||
# 生成器训练
|
||||
output = d(fake_images)
|
||||
loss_g = criterion(output, real_labels) # 欺骗D:希望D输出真
|
||||
optimizer_g.zero_grad()
|
||||
loss_g.backward()
|
||||
optimizer_g.step()
|
||||
|
||||
loss_d_total += loss_d.item()
|
||||
loss_g_total += loss_g.item()
|
||||
|
||||
# 平均损失
|
||||
avg_loss_d = loss_d_total / len(dataloader)
|
||||
avg_loss_g = loss_g_total / len(dataloader)
|
||||
print(f"Epoch [{epoch}/{num_epochs}] Loss_D: {avg_loss_d:.4f} Loss_G: {avg_loss_g:.4f}")
|
||||
|
||||
# 每5轮保存一次生成图像
|
||||
if epoch % 5 == 0:
|
||||
with torch.no_grad():
|
||||
fake = g(z_fixed).detach().cpu()
|
||||
utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.multiprocessing.freeze_support() # ✅ Windows 兼容性
|
||||
train() # 把训练过程封装成 main() 函数
|
||||
Reference in New Issue
Block a user