
# PyTorch GAN

A Generative Adversarial Network (GAN) is one of the most creative architectures in deep learning. It trains two neural networks against each other, and through this competition the generator learns to produce highly realistic data. GANs are widely used for image generation, style transfer, and data augmentation.

* * *

## 1. Core Principles of GAN

The core idea of a GAN comes from the zero-sum game in game theory. A GAN contains two competing networks:

* Generator: learns to produce fake data, with the goal of making the discriminator unable to tell generated data from real data
* Discriminator: learns to distinguish real data from generated data, with the goal of classifying as accurately as possible

During training, each network improves by competing with the other. In theory, with enough capacity, this process reaches a Nash equilibrium in which the generator's distribution matches the real data distribution and the discriminator can do no better than output 1/2 for every input.

### 1.1 GAN Objective Function

The training objective of a GAN can be written as the following minimax game:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
$$

Where:

* \(G\) is the generator network
* \(D\) is the discriminator network; \(D(x)\) is its estimated probability that \(x\) is real
* \(x\) is a real data sample drawn from \(p_{\text{data}}\)
* \(z\) is a random noise vector drawn from the prior \(p_z\) (usually a standard normal distribution)
* \(G(z)\) is the fake sample the generator produces from \(z\)

### 1.2 Interpreting the Training Process

Each training iteration alternates between two stages.

Stage 1: Train the Discriminator

Hold the generator fixed and update the discriminator to better separate real data from generated data:

$$
\max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

Stage 2: Train the Generator

Hold the discriminator fixed and update the generator to better fool it:

$$
\min_G \; \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]
$$

In practice, \(\log(1 - D(G(z)))\) gives very small gradients early in training, when the discriminator easily rejects generated samples. Implementations therefore commonly train the generator to maximize \(\log D(G(z))\) instead (the "non-saturating" loss). The code in Section 2.2 does this by labeling generated samples as real in the BCE loss.

> The discriminator can be updated for \(k\) steps for every generator step to keep the two networks balanced; \(k = 1\) is a common choice, and the example below uses \(k = 1\).

* * *

## 2. Basic GAN Implementation

Below we implement one of the simplest possible GANs: one that generates 2D data points.

### 2.1 Define the Generator and Discriminator

```python
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Set the random seed for reproducibility
torch.manual_seed(42)

# ── Generator network ──────────────────────────────────────
class Generator(nn.Module):
    """
    Generator: produces data from random noise.
    Input:  noise vector (batch_size, noise_dim)
    Output: generated data (batch_size, data_dim)
    """

    def __init__(self, noise_dim, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
            # No output activation: coordinates are unbounded and training shapes their distribution
        )

    def forward(self, x):
        return self.net(x)


# ── Discriminator network ──────────────────────────────────
class Discriminator(nn.Module):
    """
    Discriminator: distinguishes real data from generated data.
    Input:  data points (batch_size, data_dim)
    Output: probability that each point is real (batch_size, 1)
    """

    def __init__(self, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.LeakyReLU(0.2),  # LeakyReLU keeps a nonzero gradient for negative inputs
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # Output a probability in (0, 1)
        )

    def forward(self, x):
        return self.net(x)


# Hyperparameters
NOISE_DIM = 16
DATA_DIM = 2
HIDDEN_DIM = 64
BATCH_SIZE = 128

# Create the networks
generator = Generator(NOISE_DIM, DATA_DIM, HIDDEN_DIM)
discriminator = Discriminator(DATA_DIM, HIDDEN_DIM)

print(f"Generator parameter count: {sum(p.numel() for p in generator.parameters()):,}")
print(f"Discriminator parameter count: {sum(p.numel() for p in discriminator.parameters()):,}")
```

### 2.2 Training Loop

```python
# ── Optimizers ──────────────────────────────────────
lr = 0.001
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

# Loss function: binary cross-entropy
criterion = nn.BCELoss()

# ── Training data: ring distribution ──────────────────────────
def generate_real_data(batch_size):
    """Sample real data from a ring-shaped distribution."""
    angles = torch.rand(batch_size) * 2 * torch.pi
    radius = 1.0 + torch.randn(batch_size) * 0.1  # Radius close to 1
    x = radius * torch.cos(angles)
    y = radius * torch.sin(angles)
    return torch.stack([x, y], dim=1)

# ── Training loop ──────────────────────────────────────
NUM_EPOCHS = 1000  # Each "epoch" here is one update on a single fresh batch
d_losses = []
g_losses = []

for epoch in range(NUM_EPOCHS):
    # 1. Train the discriminator
    # Generate fake data; detach() keeps this update from backpropagating into the generator
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)
    fake_data = generator(noise).detach()

    # Sample real data
    real_data = generate_real_data(BATCH_SIZE)

    # Discriminator loss: real samples -> label 1, fake samples -> label 0
    real_pred = discriminator(real_data)
    fake_pred = discriminator(fake_data)
    d_loss = (criterion(real_pred, torch.ones_like(real_pred))
              + criterion(fake_pred, torch.zeros_like(fake_pred)))

    # Update the discriminator
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # 2. Train the generator
    # Generate a new batch of fake data (not detached this time)
    noise = torch.randn(BATCH_SIZE, NOISE_DIM)
    fake_data = generator(noise)

    # Generator loss: label fakes as real so the generator learns to fool the discriminator
    fake_pred = discriminator(fake_data)
    g_loss = criterion(fake_pred, torch.ones_like(fake_pred))

    # Update the generator
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    # Record the losses
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1:4d} | D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")

print("Training complete!")
```

### 2.3 Visualize the Generated Results

```python
# Generate samples and plot them
def visualize_results(generator, num_samples=1000):
    noise = torch.randn(num_samples, NOISE_DIM)
    generated_data = generator(noise).detach().numpy()

    plt.figure(figsize=(6, 6))
    plt.scatter(generated_data[:, 0], generated_data[:, 1],
                alpha=0.5, s=10, c='blue', label='Generated')
    plt.xlim(-2, 2)
    plt.ylim(-2, 2)
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('GAN Generated Data')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Inspect the generated samples
visualize_results(generator)
```

* * *

## 3. DCGAN: Deep Convolutional GAN

DCGAN is a classic architecture that brings convolutional neural networks into GANs and substantially improves the quality of generated images compared with fully connected GANs.

### 3.1 DCGAN Architecture Key Points

* The generator uses transposed convolutions to upsample noise into an image
* The discriminator uses strided convolutions to downsample the image (instead of pooling layers)
* Both networks use BatchNorm, except the generator's output layer and the discriminator's input layer
* The generator uses ReLU (with Tanh at the output); the discriminator uses LeakyReLU

### 3.2 DCGAN Implementation

```python
import torch
import torch.nn as nn

# ── DCGAN generator ─────────────────────────────────
class DCGenerator(nn.Module):
    """
    DCGAN generator: upsamples with transposed convolutions.
    """

    def __init__(self, noise_dim=100, channels=3,
```
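To make the 3.1 guidelines concrete, here is a minimal sketch of a DCGAN-style generator and discriminator. It is an illustration under stated assumptions, not the tutorial's own `DCGenerator`: the class names `TinyDCGenerator` and `TinyDCDiscriminator`, the 32×32 image size, and the channel width `base=64` are all hypothetical choices.

```python
import torch
import torch.nn as nn


class TinyDCGenerator(nn.Module):
    """Hypothetical DCGAN-style generator: noise (N, noise_dim) -> image (N, channels, 32, 32)."""

    def __init__(self, noise_dim=100, channels=3, base=64):
        super().__init__()
        self.net = nn.Sequential(
            # Treat the noise vector as a 1x1 feature map and upsample 1 -> 4
            nn.ConvTranspose2d(noise_dim, base * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.ReLU(True),
            # 4 -> 8
            nn.ConvTranspose2d(base * 4, base * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 2),
            nn.ReLU(True),
            # 8 -> 16
            nn.ConvTranspose2d(base * 2, base, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base),
            nn.ReLU(True),
            # 16 -> 32; no BatchNorm on the output layer, Tanh maps pixels to [-1, 1]
            nn.ConvTranspose2d(base, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        # Reshape (N, noise_dim) to (N, noise_dim, 1, 1) for the transposed convolutions
        return self.net(z.view(z.size(0), -1, 1, 1))


class TinyDCDiscriminator(nn.Module):
    """Hypothetical DCGAN-style discriminator: image (N, channels, 32, 32) -> probability (N, 1)."""

    def __init__(self, channels=3, base=64):
        super().__init__()
        self.net = nn.Sequential(
            # 32 -> 16; no BatchNorm on the input layer
            nn.Conv2d(channels, base, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 16 -> 8
            nn.Conv2d(base, base * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 8 -> 4
            nn.Conv2d(base * 2, base * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4 -> 1: a single score per image, squashed to a probability
            nn.Conv2d(base * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)


# Quick shape check with a batch of 8 noise vectors
G = TinyDCGenerator()
D = TinyDCDiscriminator()
fake = G(torch.randn(8, 100))
print(fake.shape)     # torch.Size([8, 3, 32, 32])
print(D(fake).shape)  # torch.Size([8, 1])
```

With `kernel_size=4, stride=2, padding=1`, each transposed convolution doubles the spatial size and each convolution halves it, which is why four layers take a 1×1 noise map to a 32×32 image and back down to a single score. Real images fed to this discriminator would need to be scaled to [-1, 1] to match the generator's Tanh output.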