Introduction
A diffusion model, in general terms, is a type of generative deep learning model that creates data through a learned denoising process. There are many variations of diffusion models, with the most popular usually being text-conditional models that can generate an image based on a prompt. Some diffusion models (e.g., ControlNet) can even blend images with certain artistic styles. Here is an example below:
If you don't know what's so special about the image, try moving farther away from the screen or squinting your eyes to see the secret hidden in the image.
There are many different applications and types of diffusion models, but in this tutorial we are going to build the foundational unconditional diffusion model, DDPM (Denoising Diffusion Probabilistic Models) [1]. We will start by looking at how the algorithm works intuitively under the hood, and then we will build it from scratch in PyTorch. This tutorial focuses primarily on the intuitive idea behind the algorithm and the specific implementation details. For the mathematical derivations and background, this book [2] is a great reference.
Final Notes: This implementation was built for workflows with a single CUDA-compatible GPU. In addition, the entire code repository can be found here: https://github.com/nickd16/Diffusion-Models-from-Scratch
How it Works -> The Forward and Reverse Process
The diffusion process consists of a forward and a reverse process. The forward process is a predetermined Markov chain based on a noise schedule. The noise schedule is a set of variances β1, β2, …, βT that govern the conditional normal distributions that make up the Markov chain.
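Written in the notation of the DDPM paper [1], the forward process is the Markov chain

q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}), \qquad q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\big)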
This formula is the mathematical representation of the forward process, but intuitively we can understand it as a sequence where we gradually map our data examples X to pure noise. Our first term in the forward process is just our initial data example. At an intermediate time step t, we have a noised version of X, and at our final time step T, we arrive at pure noise that is approximately governed by a standard normal distribution. When we build a diffusion model, we choose our noise schedule. In DDPM, for example, the noise schedule features 1000 time steps of linearly increasing variances starting at 1e-4 and ending at 0.02. It is also important to note that the forward process is static, meaning we choose the noise schedule as a hyperparameter of our diffusion model, and we do not train the forward process since it is already defined explicitly.
The final key detail we have to understand about the forward process is that because the distributions are normal, we can mathematically derive a distribution known as the "diffusion kernel", which is the distribution of any intermediate value in the forward process given our initial data point. This allows us to bypass all the intermediate steps of iteratively adding t-1 levels of noise in the forward process to get an image with t noise, which will come in handy later when we train our model.
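In the same notation, the diffusion kernel from the DDPM paper [1] is written as

q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\,\mathbf{I}\big)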
where alpha-bar at time t is defined as the cumulative product of (1 − β) from our initial time step to our current time step.
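In code, that cumulative product is a one-liner. This is the same computation the DDPM_Scheduler class performs later in this article, where the resulting tensor is simply named alpha:

import torch

beta = torch.linspace(1e-4, 0.02, 1000)     # DDPM's linear variance schedule
alpha_bar = torch.cumprod(1 - beta, dim=0)  # alpha-bar_t for every timestep t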
The reverse process is the key to a diffusion model. The reverse process is essentially the undoing of the forward process: we gradually remove amounts of noise from a purely noisy image to generate new images. We do this by starting at purely noised data, and for each time step t we subtract the amount of noise that would theoretically have been added by the forward process at that time step. We keep removing noise until eventually we have something that resembles our original data distribution. The bulk of our work is training a model to carefully approximate the forward process in order to estimate a reverse process that can generate new samples.
The Algorithm and Training Objective
To train such a model to estimate the reverse diffusion process, we can follow the algorithm outlined in the image below:
- Take a randomly sampled data point from our training dataset
- Select a random timestep on our noise (variance) schedule
- Add the noise from that time step to our data, simulating the forward diffusion process through the "diffusion kernel"
- Pass our diffused image into our model to predict the noise we added
- Compute the mean squared error between the predicted noise and the actual noise, and optimize our model's parameters through that objective function
- And repeat!
Mathematically, the exact formula in the algorithm might look a little strange at first without seeing the full derivation, but intuitively it is a reparameterization of the diffusion kernel based on the alpha values of our noise schedule, and it is simply the squared difference between the predicted noise and the actual noise we added to an image.
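Concretely, the training objective from the DDPM paper [1] uses the diffusion kernel to noise a clean image x_0 in a single step and regresses the noise that was added:

L(\theta) = \mathbb{E}_{x_0,\ \epsilon \sim \mathcal{N}(0,\mathbf{I}),\ t}\Big[\big\lVert\, \epsilon - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\big) \big\rVert^2\Big]

where ε is the sampled noise and ε_θ is the noise our model predicts for the noised image at timestep t.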
If our model can successfully predict the amount of noise added at a specific time step of the forward process, we can iteratively start from noise at time step T and gradually remove noise at each time step until we recover data that resembles a generated sample from our original data distribution.
The sampling algorithm is summarized in the following:
1. Generate random noise from a standard normal distribution
2. Starting from our last timestep and moving backwards, update Z by estimating the reverse process distribution, with mean parameterized by Z from the previous step and variance parameterized by the noise our model estimates at that timestep
3. Add a small amount of the noise back for stability (explanation below)
4. Repeat until we arrive at time step 0, our recovered image!
The algorithm to then sample and generate images might look mathematically complicated, but it intuitively boils down to an iterative process where we start with pure noise, estimate the noise that theoretically was added at time step t, and subtract it. We do this until we arrive at our generated sample. The one small detail we need to be aware of is that when we subtract the estimated noise, we add a small amount of it back to keep the process stable. For example, estimating and subtracting the total amount of noise all at once at the beginning of the iterative process leads to very incoherent samples, so in practice adding a bit of the noise back and iterating through every time step has empirically been shown to generate better samples.
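For reference, the per-timestep update from the DDPM paper [1] (the same formula that appears in the inference code later in this article) is

x_{t-1} = \frac{1}{\sqrt{1-\beta_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\Big) + \sqrt{\beta_t}\,z, \qquad z \sim \mathcal{N}(0, \mathbf{I})

where the term \sqrt{\beta_t}\,z is the small amount of noise added back for stability, and it is omitted at the final step.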
The UNET
The authors of the DDPM paper used the UNET architecture, originally designed for medical image segmentation, to build a model that predicts the noise for the diffusion reverse process. The model we are going to use in this tutorial is meant for 32×32 images, perfect for datasets such as MNIST, but the model can be scaled to also handle data of much higher resolutions. There are many variations of the UNET, but an overview of the model architecture we will build is in the image below.
The UNET for DDPM is similar to the classic UNET in that it contains both a downsampling stream and an upsampling stream that lighten the computational burden of the network, while also having skip connections between the two streams to merge information from both the shallow and deep features of the model.
The main differences between the DDPM UNET and the classic UNET are that the DDPM UNET features attention in the 16×16 dimensional layers and sinusoidal transformer embeddings in each residual block. The purpose of the sinusoidal embeddings is to tell the model which time step we are trying to predict the noise for. This helps the model predict the noise at each time step by injecting positional information about where the model is on our noise schedule. For example, if the noise schedule has a lot of noise at certain time steps, knowing which time step it has to predict can help the model's prediction of the noise for that corresponding time step. More general information on attention and embeddings can be found here [3] for those not already familiar with them from the transformer architecture.
In our implementation of the model, we will start by defining our imports (possible pip install commands commented for reference) and coding our sinusoidal timestep embeddings. Intuitively, the sinusoidal embeddings are different sin and cos frequencies that can be added directly to our inputs to give the model additional positional/sequential understanding. As you can see from the image below, each sinusoidal wave is unique, which gives the model awareness of its location on our noise schedule.
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange  # pip install einops
from typing import List
import random
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from timm.utils import ModelEmaV3  # pip install timm
from tqdm import tqdm  # pip install tqdm
import matplotlib.pyplot as plt  # pip install matplotlib
import torch.optim as optim
import numpy as np

class SinusoidalEmbeddings(nn.Module):
    def __init__(self, time_steps: int, embed_dim: int):
        super().__init__()
        position = torch.arange(time_steps).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        embeddings = torch.zeros(time_steps, embed_dim, requires_grad=False)
        embeddings[:, 0::2] = torch.sin(position * div)
        embeddings[:, 1::2] = torch.cos(position * div)
        self.embeddings = embeddings

    def forward(self, x, t):
        embeds = self.embeddings[t].to(x.device)
        return embeds[:, :, None, None]
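As a quick sanity check of the shapes (the tensors below are hypothetical examples, not part of the training code), the module maps one timestep index per image to an embedding vector that broadcasts over the spatial dimensions of a feature map:

# Hypothetical usage: a batch of 4 feature maps and 4 sampled timesteps
emb = SinusoidalEmbeddings(time_steps=1000, embed_dim=512)
x = torch.randn(4, 64, 32, 32)    # (batch, channels, height, width)
t = torch.randint(0, 1000, (4,))  # one timestep index per image
print(emb(x, t).shape)            # torch.Size([4, 512, 1, 1]), added to feature maps inside each ResBlock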
The residual blocks in each layer of the UNET will be equivalent to the ones used in the original DDPM paper. Each residual block has a sequence of group-norm, the ReLU activation, a 3×3 "same" convolution, dropout, and a skip connection.
# Residual Blocks
class ResBlock(nn.Module):
    def __init__(self, C: int, num_groups: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.gnorm1 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.gnorm2 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_prob, inplace=True)

    def forward(self, x, embeddings):
        # add the timestep embedding (sliced to this block's channel count) to the input
        x = x + embeddings[:, :x.shape[1], :, :]
        r = self.conv1(self.relu(self.gnorm1(x)))
        r = self.dropout(r)
        r = self.conv2(self.relu(self.gnorm2(r)))
        return r + x
In DDPM, the authors used 2 residual blocks per layer (resolution scale) of the UNET, and for the 16×16 dimension layers, we include the classic transformer attention mechanism between the 2 residual blocks. We will now implement the attention mechanism for the UNET:
class Attention(nn.Module):
    def __init__(self, C: int, num_heads: int, dropout_prob: float):
        super().__init__()
        self.proj1 = nn.Linear(C, C*3)
        self.proj2 = nn.Linear(C, C)
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob

    def forward(self, x):
        h, w = x.shape[2:]
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.proj1(x)
        x = rearrange(x, 'b L (C H K) -> K b H L C', K=3, H=self.num_heads)
        q, k, v = x[0], x[1], x[2]
        x = F.scaled_dot_product_attention(q, k, v, is_causal=False, dropout_p=self.dropout_prob)
        x = rearrange(x, 'b H (h w) C -> b h w (C H)', h=h, w=w)
        x = self.proj2(x)
        return rearrange(x, 'b h w C -> b C h w')
The attention implementation is straightforward. We reshape our data so that the h*w dimensions are combined into a "sequence" dimension, similar to the classic input for a transformer model, and the channel dimension becomes the embedding feature dimension. In this implementation we use torch.nn.functional.scaled_dot_product_attention because it includes flash attention, an optimized version of attention that is still mathematically equivalent to classic transformer attention. For more information on flash attention, you can refer to these papers: [4], [5].
Finally, at this point, we can define a complete layer of the UNET:
class UnetLayer(nn.Module):
    def __init__(self,
                 upscale: bool,
                 attention: bool,
                 num_groups: int,
                 dropout_prob: float,
                 num_heads: int,
                 C: int):
        super().__init__()
        self.ResBlock1 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)
        self.ResBlock2 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)
        if upscale:
            self.conv = nn.ConvTranspose2d(C, C//2, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(C, C*2, kernel_size=3, stride=2, padding=1)
        if attention:
            self.attention_layer = Attention(C, num_heads=num_heads, dropout_prob=dropout_prob)

    def forward(self, x, embeddings):
        x = self.ResBlock1(x, embeddings)
        if hasattr(self, 'attention_layer'):
            x = self.attention_layer(x)
        x = self.ResBlock2(x, embeddings)
        return self.conv(x), x
Each layer in DDPM, as previously discussed, has 2 residual blocks and may contain an attention mechanism, and we additionally pass our embeddings into each residual block. We also return both the downsampled or upsampled value as well as the value prior to it, which we will store and use for our concatenated residual skip connections.
Finally, we can finish the UNET class:
class UNET(nn.Module):
    def __init__(self,
                 Channels: List = [64, 128, 256, 512, 512, 384],
                 Attentions: List = [False, True, False, False, False, True],
                 Upscales: List = [False, False, False, True, True, True],
                 num_groups: int = 32,
                 dropout_prob: float = 0.1,
                 num_heads: int = 8,
                 input_channels: int = 1,
                 output_channels: int = 1,
                 time_steps: int = 1000):
        super().__init__()
        self.num_layers = len(Channels)
        self.shallow_conv = nn.Conv2d(input_channels, Channels[0], kernel_size=3, padding=1)
        out_channels = (Channels[-1]//2) + Channels[0]
        self.late_conv = nn.Conv2d(out_channels, out_channels//2, kernel_size=3, padding=1)
        self.output_conv = nn.Conv2d(out_channels//2, output_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.embeddings = SinusoidalEmbeddings(time_steps=time_steps, embed_dim=max(Channels))
        for i in range(self.num_layers):
            layer = UnetLayer(
                upscale=Upscales[i],
                attention=Attentions[i],
                num_groups=num_groups,
                dropout_prob=dropout_prob,
                C=Channels[i],
                num_heads=num_heads
            )
            setattr(self, f'Layer{i+1}', layer)

    def forward(self, x, t):
        x = self.shallow_conv(x)
        residuals = []
        for i in range(self.num_layers//2):
            layer = getattr(self, f'Layer{i+1}')
            embeddings = self.embeddings(x, t)
            x, r = layer(x, embeddings)
            residuals.append(r)
        for i in range(self.num_layers//2, self.num_layers):
            layer = getattr(self, f'Layer{i+1}')
            # concatenate the skip connection from the matching down-sampling layer
            x = torch.concat((layer(x, embeddings)[0], residuals[self.num_layers-i-1]), dim=1)
        return self.output_conv(self.relu(self.late_conv(x)))
The implementation is straightforward given the classes we have already created. The only difference in this implementation is that the channels for the up-sampling stream are slightly larger than the typical UNET channels. I found that this architecture trained more efficiently on a single GPU with 16GB of VRAM.
The Scheduler
Coding the noise/variance scheduler for DDPM is also very straightforward. In DDPM, our schedule starts, as previously mentioned, at 1e-4, ends at 0.02, and increases linearly.
class DDPM_Scheduler(nn.Module):
    def __init__(self, num_time_steps: int=1000):
        super().__init__()
        self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)
        alpha = 1 - self.beta
        # cumulative product alpha-bar_t, stored here as self.alpha
        self.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)

    def forward(self, t):
        return self.beta[t], self.alpha[t]
We return both the beta (variance) values and the alpha values, since the formulas for training and sampling use both, based on their mathematical derivations.
def set_seed(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
Additionally (not required), this function sets a training seed. This means that if you want to reproduce a specific training run, you can use a fixed seed so that the random weight and optimizer initializations are the same each time you use that seed.
Training
For our implementation, we will create a model to generate MNIST data (handwritten digits). Since these images are 28×28 by default in PyTorch, we pad the images to 32×32 to follow the original paper, which was trained on 32×32 images.
For optimization, we use Adam with an initial learning rate of 2e-5. We also use EMA (Exponential Moving Average) to aid generation quality. EMA is a weighted average of the model's parameters that at inference time can produce smoother, less noisy samples. For this implementation I use the timm library's ModelEmaV3 out-of-the-box implementation with decay 0.9999, as used in the DDPM paper.
To summarize our training, we simply follow the pseudo-code above. We pick random time steps for our batch, noise the data in the batch according to our schedule at those time steps, and input that batch of noised images into the UNET along with the time steps themselves to drive the sinusoidal embeddings. We use the formula in the pseudo-code based on the "diffusion kernel" to noise the images. We then take our model's prediction of how much noise we added, compare it to the actual noise we added, and optimize the mean squared error of the noise. We also implemented basic checkpointing to pause and resume training across different epochs.
def train(batch_size: int=64,
          num_time_steps: int=1000,
          num_epochs: int=15,
          seed: int=-1,
          ema_decay: float=0.9999,
          lr=2e-5,
          checkpoint_path: str=None):
    set_seed(random.randint(0, 2**32-1)) if seed == -1 else set_seed(seed)

    train_dataset = datasets.MNIST(root='./data', train=True, download=False, transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)
    model = UNET().cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    ema = ModelEmaV3(model, decay=ema_decay)
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['weights'])
        ema.load_state_dict(checkpoint['ema'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    criterion = nn.MSELoss(reduction='mean')

    for i in range(num_epochs):
        total_loss = 0
        for bidx, (x, _) in enumerate(tqdm(train_loader, desc=f"Epoch {i+1}/{num_epochs}")):
            x = x.cuda()
            x = F.pad(x, (2, 2, 2, 2))  # pad 28x28 MNIST images to 32x32
            t = torch.randint(0, num_time_steps, (batch_size,))
            e = torch.randn_like(x, requires_grad=False)
            a = scheduler.alpha[t].view(batch_size, 1, 1, 1).cuda()
            x = (torch.sqrt(a)*x) + (torch.sqrt(1-a)*e)  # diffusion kernel: noise the batch in one step
            output = model(x, t)
            optimizer.zero_grad()
            loss = criterion(output, e)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            ema.update(model)
        print(f'Epoch {i+1} | Loss {total_loss / (60000/batch_size):.5f}')

    checkpoint = {
        'weights': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'ema': ema.state_dict()
    }
    torch.save(checkpoint, 'checkpoints/ddpm_checkpoint')
For inference, we again follow exactly the other part of the pseudo-code. Intuitively, we are just reversing the forward process. We start from pure noise, and our now-trained model predicts the estimated noise at each time step, allowing it to generate brand new samples iteratively. From each different starting point of noise, we can generate a different, unique sample that is similar to our original data distribution. The formulas for inference were not derived in this article, but the reference linked at the beginning can help guide readers who want a deeper understanding.
Also note that I included a helper function to view the intermediate images of the reverse process, so you can visualize how well the model learned it.
def display_reverse(images: List):
    fig, axes = plt.subplots(1, 10, figsize=(10, 1))
    for i, ax in enumerate(axes.flat):
        x = images[i].squeeze(0)
        x = rearrange(x, 'c h w -> h w c')
        x = x.numpy()
        ax.imshow(x)
        ax.axis('off')
    plt.show()

def inference(checkpoint_path: str=None,
              num_time_steps: int=1000,
              ema_decay: float=0.9999):
    checkpoint = torch.load(checkpoint_path)
    model = UNET().cuda()
    model.load_state_dict(checkpoint['weights'])
    ema = ModelEmaV3(model, decay=ema_decay)
    ema.load_state_dict(checkpoint['ema'])
    scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)
    times = [0, 15, 50, 100, 200, 300, 400, 550, 700, 999]
    images = []

    with torch.no_grad():
        model = ema.module.eval()
        for i in range(10):
            z = torch.randn(1, 1, 32, 32)
            for t in reversed(range(1, num_time_steps)):
                t = [t]
                temp = (scheduler.beta[t] / ((torch.sqrt(1 - scheduler.alpha[t])) * (torch.sqrt(1 - scheduler.beta[t]))))
                z = (1 / (torch.sqrt(1 - scheduler.beta[t]))) * z - (temp * model(z.cuda(), t).cpu())
                if t[0] in times:
                    images.append(z)
                e = torch.randn(1, 1, 32, 32)
                z = z + (e * torch.sqrt(scheduler.beta[t]))
            # final step (t=0): no noise is added back
            temp = scheduler.beta[0] / ((torch.sqrt(1 - scheduler.alpha[0])) * (torch.sqrt(1 - scheduler.beta[0])))
            x = (1 / (torch.sqrt(1 - scheduler.beta[0]))) * z - (temp * model(z.cuda(), [0]).cpu())
            images.append(x)
            x = rearrange(x.squeeze(0), 'c h w -> h w c').detach()
            x = x.numpy()
            plt.imshow(x)
            plt.show()
            display_reverse(images)
            images = []
def main():
    train(checkpoint_path='checkpoints/ddpm_checkpoint', lr=2e-5, num_epochs=75)
    inference('checkpoints/ddpm_checkpoint')

if __name__ == '__main__':
    main()
After training for 75 epochs with the experimental details listed above, we obtain these results:
At this point, we have just coded DDPM from scratch in PyTorch!
Thanks for reading!
References
[1] Denoising Diffusion Probabilistic Models (DDPM) https://arxiv.org/abs/2006.11239
[2] Understanding Deep Learning https://udlbook.github.io/udlbook/
[3] Attention Is All You Need https://arxiv.org/abs/1706.03762
[4] FlashAttention https://arxiv.org/abs/2205.14135
[5] FlashAttention-2 https://arxiv.org/abs/2307.08691