From 33b72751ab11c6794bfbe204fdb56ca47013f48c Mon Sep 17 00:00:00 2001 From: e2hang <2099307493@qq.com> Date: Tue, 28 Oct 2025 22:27:20 +0800 Subject: [PATCH] DCGAN --- Pytorch/Project/DCGAN/.idea/DCGAN.iml | 8 + .../inspectionProfiles/Project_Default.xml | 266 ++++++++++++++++++ .../inspectionProfiles/profiles_settings.xml | 6 + Pytorch/Project/DCGAN/.idea/misc.xml | 7 + Pytorch/Project/DCGAN/.idea/modules.xml | 8 + Pytorch/Project/DCGAN/.idea/workspace.xml | 53 ++++ Pytorch/Project/DCGAN/main.py | 245 ++++++++++++++++ Pytorch/Project/DCGAN/test.py | 176 ++++++++++++ 8 files changed, 769 insertions(+) create mode 100644 Pytorch/Project/DCGAN/.idea/DCGAN.iml create mode 100644 Pytorch/Project/DCGAN/.idea/inspectionProfiles/Project_Default.xml create mode 100644 Pytorch/Project/DCGAN/.idea/inspectionProfiles/profiles_settings.xml create mode 100644 Pytorch/Project/DCGAN/.idea/misc.xml create mode 100644 Pytorch/Project/DCGAN/.idea/modules.xml create mode 100644 Pytorch/Project/DCGAN/.idea/workspace.xml create mode 100644 Pytorch/Project/DCGAN/main.py create mode 100644 Pytorch/Project/DCGAN/test.py diff --git a/Pytorch/Project/DCGAN/.idea/DCGAN.iml b/Pytorch/Project/DCGAN/.idea/DCGAN.iml new file mode 100644 index 0000000..b910500 --- /dev/null +++ b/Pytorch/Project/DCGAN/.idea/DCGAN.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/.idea/inspectionProfiles/Project_Default.xml b/Pytorch/Project/DCGAN/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..92645b2 --- /dev/null +++ b/Pytorch/Project/DCGAN/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,266 @@ + + + + \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/.idea/inspectionProfiles/profiles_settings.xml b/Pytorch/Project/DCGAN/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/Pytorch/Project/DCGAN/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/.idea/misc.xml b/Pytorch/Project/DCGAN/.idea/misc.xml new file mode 100644 index 0000000..15d27bc --- /dev/null +++ b/Pytorch/Project/DCGAN/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/.idea/modules.xml b/Pytorch/Project/DCGAN/.idea/modules.xml new file mode 100644 index 0000000..10aba47 --- /dev/null +++ b/Pytorch/Project/DCGAN/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/.idea/workspace.xml b/Pytorch/Project/DCGAN/.idea/workspace.xml new file mode 100644 index 0000000..0e3fefa --- /dev/null +++ b/Pytorch/Project/DCGAN/.idea/workspace.xml @@ -0,0 +1,53 @@ + + + + + + + + + + + { + "associatedIndex": 3 +} + + + + + + + + + + + + + + 1761570889588 + + + + \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/main.py b/Pytorch/Project/DCGAN/main.py new file mode 100644 index 0000000..412145b --- /dev/null +++ b/Pytorch/Project/DCGAN/main.py @@ -0,0 +1,245 @@ +import os +import torch +import torch.nn as nn +from torchvision import transforms, utils +from torch.utils.data import DataLoader, Dataset +from tqdm import tqdm +import torch.multiprocessing +from PIL import Image # 自定义数据集需要 import PIL +import matplotlib.pyplot as plt + + +# 判别器 +class Detector(nn.Module): + def __init__(self): + super(Detector, self).__init__() + self.model = nn.Sequential( + # 3 x 64 x 64 -> 64 x 32 x 32 + nn.Conv2d(3, 64, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + + # 64 x 32 x 32 -> 128 x 16 x 16 + nn.Conv2d(64, 128, 4, 2, 1, bias=False), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2, inplace=True), + + # 128 x 16 x 16 -> 256 x 8 x 8 + nn.Conv2d(128, 256, 4, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2, inplace=True), + + # 256 x 8 x 8 -> 512 x 4 x 4 + nn.Conv2d(256, 512, 4, 2, 1, bias=False), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2, inplace=True), + + # 512 x 4 x 4 -> 1024 x 2 x 2 + nn.Conv2d(512, 1024, 4, 2, 1, bias=False), + nn.BatchNorm2d(1024), + nn.LeakyReLU(0.2, inplace=True), + + # 1024 x 2 x 2 -> 1 x 1 x 1 + nn.Conv2d(1024, 1, 2, 1, 0, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + return self.model(x) + + +# 生成器 +class Generator(nn.Module): + def __init__(self, z_dim=100): + super(Generator, self).__init__() + self.model = nn.Sequential( + # z -> 1024 x 4 x 4 + nn.ConvTranspose2d(z_dim, 1024, 4, 1, 0, bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(True), + + # 1024 x 4 x 4 -> 512 x 8 x 8 + nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(True), + + # 512 x 8 x 8 -> 256 x 16 x 16 + nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(True), + + # 256 x 16 x 16 -> 128 x 32 x 32 + nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(True), + + # 128 x 32 x 32 -> 3 x 64 x 64 + nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False), + nn.Tanh() + ) + + def forward(self, z): + return self.model(z) + + +# 数据集加载(使用自定义路径 ./data/images) +class FlatImageDataset(Dataset): + def __init__(self, root_dir, transform=None): + self.root_dir = root_dir + self.transform = transform + if not os.path.exists(root_dir): + raise FileNotFoundError(f"Dataset directory '{root_dir}' not found. Please create it and add images.") + self.image_files = [f for f in os.listdir(root_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))] + if len(self.image_files) == 0: + raise ValueError(f"No valid image files found in '{root_dir}'. Supported: .png, .jpg, .jpeg, .bmp") + self.image_paths = [os.path.join(root_dir, f) for f in self.image_files] + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + img_path = self.image_paths[idx] + image = Image.open(img_path).convert('RGB') # 确保转为 RGB(3 通道) + if self.transform: + image = self.transform(image) + return image, 0 # 返回图像和虚拟标签(GAN 不使用) + + +# 使用自定义数据集 +transform = transforms.Compose([ + transforms.Resize(64), + transforms.CenterCrop(64), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) +]) + +# 加载自定义数据集 +dataset = FlatImageDataset(root_dir='./data/images', transform=transform) +print(f"Loaded {len(dataset)} images from ./data/images") # 调试:打印数据集大小 +dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=2) # 减小 num_workers 避免 Windows 问题 + + +# 参数 +z_dim = 100 +num_epochs = 100 # 6.2w +lr_d = 0.001 +lr_g = 0.002 +n_g_steps = 2 # 标准 DCGAN是1步G +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + +# 模型与优化器 +d = Detector().to(device) +g = Generator(z_dim=z_dim).to(device) +criterion = nn.BCELoss() +optimizer_d = torch.optim.Adam(d.parameters(), lr=lr_d, betas=(0.5, 0.999)) +optimizer_g = torch.optim.Adam(g.parameters(), lr=lr_g, betas=(0.5, 0.999)) + +# 固定噪声用于观察训练过程 +z_fixed = torch.randn(64, z_dim, 1, 1, device=device) + +# 创建保存目录 +os.makedirs("results", exist_ok=True) + + +# 根据保存的 G dict 生成图片的函数 +def generate_from_g_dict(model_path, z_dim=100, num_images=64, output_path='generated.png'): + """ + 从保存的生成器 state_dict 文件加载模型,并生成图片保存。 + """ + if not os.path.exists(model_path): + print(f"Model path '{model_path}' not found. Skipping generation.") + return + # 加载生成器并恢复权重 + g_loaded = Generator(z_dim=z_dim).to(device) + g_loaded.load_state_dict(torch.load(model_path, map_location=device)) + g_loaded.eval() + + # 生成假图像 + with torch.no_grad(): + z = torch.randn(num_images, z_dim, 1, 1, device=device) + fake_images = g_loaded(z).detach().cpu() + + # 保存图片 + utils.save_image(fake_images, output_path, normalize=True, nrow=8) + print(f"Generated images saved to {output_path}") + + +# 训练循环 +def train(): + for epoch in range(1, num_epochs + 1): + loss_d_total, loss_g_total = 0, 0 + for real_images, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False): + real_images = real_images.to(device) + B = real_images.size(0) + + # 标签平滑(真实 = 0.9, 假 = 0.0) + real_labels = torch.full((B, 1, 1, 1), 0.9, device=device) + fake_labels = torch.full((B, 1, 1, 1), 0.0, device=device) + + # 生成假图像 + z = torch.randn(B, z_dim, 1, 1, device=device) + fake_images = g(z) + + # 判别器训练 + output_real = d(real_images) + output_fake = d(fake_images.detach()) + + loss_real = criterion(output_real, real_labels) + loss_fake = criterion(output_fake, fake_labels) + loss_d = loss_real + loss_fake + + optimizer_d.zero_grad() + loss_d.backward() + optimizer_d.step() + + # 生成器训练(标准 BCE) + for _ in range(n_g_steps): + #每次生成器更新前重新生成假图像 + z = torch.randn(B, z_dim, 1, 1, device=device) + fake_images = g(z) + + output = d(fake_images) + loss_g = criterion(output, real_labels) # 欺骗 D:希望 D 输出真 + + optimizer_g.zero_grad() + loss_g.backward() + optimizer_g.step() + loss_g_total += loss_g.item() + + loss_d_total += loss_d.item() + + # 平均损失(G 损失已累加 n_g_steps 次) + avg_loss_d = loss_d_total / len(dataloader) + avg_loss_g = loss_g_total / (len(dataloader) * n_g_steps) # 修复:除以总 G 步数 + print(f"Epoch [{epoch}/{num_epochs}] Loss_D: {avg_loss_d:.4f} Loss_G: {avg_loss_g:.4f}") + + loss_history = {"D": [], "G": []} + + # 每轮结束时: + loss_history["D"].append(avg_loss_d) + loss_history["G"].append(avg_loss_g) + + # 最后画图: + plt.plot(loss_history["D"], label="Loss_D") + plt.plot(loss_history["G"], label="Loss_G") + plt.legend() + plt.savefig("results/loss_curve.png") + + # 每2轮保存一次生成图像和 G 的 state_dict + if epoch % 2 == 0: + with torch.no_grad(): + fake = g(z_fixed).detach().cpu() + utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8) + + # 保存 G 的 state_dict(dict 形式) + g_state_dict_path = f"results/g_epoch_{epoch}.pth" + torch.save(g.state_dict(), g_state_dict_path) + print(f"Generator state_dict saved to {g_state_dict_path}") + + +if __name__ == "__main__": + torch.multiprocessing.freeze_support() + #train() + for i in range(100): + text = 'results/generated_after_train' + str(i) + '.png' + generate_from_g_dict('results/g_epoch_85.pth', output_path=text) \ No newline at end of file diff --git a/Pytorch/Project/DCGAN/test.py b/Pytorch/Project/DCGAN/test.py new file mode 100644 index 0000000..deeff7f --- /dev/null +++ b/Pytorch/Project/DCGAN/test.py @@ -0,0 +1,176 @@ +import os +import torch +import torch.nn as nn +from torchvision import transforms, utils +from torch.utils.data import DataLoader +from torchvision.datasets import ImageFolder +from tqdm import tqdm +import torch.multiprocessing + + +# ---------------------------- +# 1. 判别器 +# ---------------------------- +class Detector(nn.Module): + def __init__(self): + super(Detector, self).__init__() + self.model = nn.Sequential( + # 3 x 64 x 64 -> 64 x 32 x 32 + nn.Conv2d(3, 64, 4, 2, 1, bias=False), + nn.LeakyReLU(0.2, inplace=True), + + # 64 x 32 x 32 -> 128 x 16 x 16 + nn.Conv2d(64, 128, 4, 2, 1, bias=False), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2, inplace=True), + + # 128 x 16 x 16 -> 256 x 8 x 8 + nn.Conv2d(128, 256, 4, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.LeakyReLU(0.2, inplace=True), + + # 256 x 8 x 8 -> 512 x 4 x 4 + nn.Conv2d(256, 512, 4, 2, 1, bias=False), + nn.BatchNorm2d(512), + nn.LeakyReLU(0.2, inplace=True), + + # 512 x 4 x 4 -> 1024 x 2 x 2 + nn.Conv2d(512, 1024, 4, 2, 1, bias=False), + nn.BatchNorm2d(1024), + nn.LeakyReLU(0.2, inplace=True), + + # 1024 x 2 x 2 -> 1 x 1 x 1 + nn.Conv2d(1024, 1, 2, 1, 0, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + return self.model(x) + + +# ---------------------------- +# 2. 生成器 +# ---------------------------- +class Generator(nn.Module): + def __init__(self, z_dim=100): + super(Generator, self).__init__() + self.model = nn.Sequential( + # z -> 1024 x 4 x 4 + nn.ConvTranspose2d(z_dim, 1024, 4, 1, 0, bias=False), + nn.BatchNorm2d(1024), + nn.ReLU(True), + + # 1024 x 4 x 4 -> 512 x 8 x 8 + nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(True), + + # 512 x 8 x 8 -> 256 x 16 x 16 + nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(True), + + # 256 x 16 x 16 -> 128 x 32 x 32 + nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(True), + + # 128 x 32 x 32 -> 3 x 64 x 64 + nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False), + nn.Tanh() + ) + + def forward(self, z): + return self.model(z) + + +# ---------------------------- +# 3. 数据集加载 +# ---------------------------- +transform = transforms.Compose([ + transforms.Resize(64), + transforms.CenterCrop(64), + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) +]) + +dataset = ImageFolder(root='./data', transform=transform) +dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4) + +# ---------------------------- +# 4. 参数与设备 +# ---------------------------- +z_dim = 100 +num_epochs = 100 +lr = 0.002 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# ---------------------------- +# 5. 模型与优化器 +# ---------------------------- +d = Detector().to(device) +g = Generator(z_dim=z_dim).to(device) +criterion = nn.BCELoss() +optimizer_d = torch.optim.Adam(d.parameters(), lr=lr, betas=(0.5, 0.999)) +optimizer_g = torch.optim.Adam(g.parameters(), lr=lr, betas=(0.5, 0.999)) + +# 固定噪声用于观察训练过程 +z_fixed = torch.randn(64, z_dim, 1, 1, device=device) + +# 创建保存目录 +os.makedirs("results", exist_ok=True) + +# ---------------------------- +# 6. 训练循环 +# ---------------------------- +def train(): + for epoch in range(1, num_epochs + 1): + loss_d_total, loss_g_total = 0, 0 + for real_images, _ in tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", leave=False): + real_images = real_images.to(device) + B = real_images.size(0) + + # 标签平滑(真实 = 0.9, 假 = 0.0) + real_labels = torch.full((B, 1, 1, 1), 0.9, device=device) + fake_labels = torch.zeros((B, 1, 1, 1), device=device) + + # 生成假图像 + z = torch.randn(B, z_dim, 1, 1, device=device) + fake_images = g(z) + + # 判别器训练 + output_real = d(real_images) + output_fake = d(fake_images.detach()) + + loss_real = criterion(output_real, real_labels) + loss_fake = criterion(output_fake, fake_labels) + loss_d = loss_real + loss_fake + + optimizer_d.zero_grad() + loss_d.backward() + optimizer_d.step() + + # 生成器训练 + output = d(fake_images) + loss_g = criterion(output, real_labels) # 欺骗D:希望D输出真 + optimizer_g.zero_grad() + loss_g.backward() + optimizer_g.step() + + loss_d_total += loss_d.item() + loss_g_total += loss_g.item() + + # 平均损失 + avg_loss_d = loss_d_total / len(dataloader) + avg_loss_g = loss_g_total / len(dataloader) + print(f"Epoch [{epoch}/{num_epochs}] Loss_D: {avg_loss_d:.4f} Loss_G: {avg_loss_g:.4f}") + + # 每5轮保存一次生成图像 + if epoch % 5 == 0: + with torch.no_grad(): + fake = g(z_fixed).detach().cpu() + utils.save_image(fake, f"results/epoch_{epoch}.png", normalize=True, nrow=8) + +if __name__ == "__main__": + torch.multiprocessing.freeze_support() # ✅ Windows 兼容性 + train() # 把训练过程封装成 main() 函数