This notebook is an introduction to diffusion models, which are currently state of the art in computer vision and image generation. I build from scratch a diffusion model that generates realistic rotation curves starting from purely random noise.
neural_network
diffusion
jupyter
Author
Lorenzo Posti
Published
January 13, 2023
Introduction to diffusion models
During the end-of-the-year break I decided to redo the awesome fast.ai course just to refresh some ideas and to keep up to date. In particular, I knew that this year they were redoing part 2 of the course, i.e. the deep learning foundations part, which is a fantastic resource to get a profound understanding of deep learning methods.
I was very pleasantly surprised to see that this year they were going from the deep learning foundations to stable diffusion. This got me excited, since I felt this was a great opportunity to learn about the popular diffusion models, which are state of the art in computer vision, and to reproduce and master them myself.
So, in this notebook I cover the first steps to set up a simple diffusion model from scratch! I am going to:
- generate some data (galaxy rotation curves),
- add some random amount of noise to the data,
- train an autoencoder on the noisy dataset, which is effectively a denoising network,
- sample some realistic rotation curves by gradually denoising pure noise.
First of all, let’s generate some data that we can later use for training. I’m going to use the same framework that I’ve been using for the blog posts on autoencoders and VAEs, so it’s entirely skippable if you’ve seen these previous posts.
Basically I’m going to generate a dataset of rotation curves of galaxies of different stellar disc masses, dark matter halo masses, disc scale-lengths and halo concentrations. These 4 parameters are not taken at random, but they are constrained by observed galaxy scaling relations. In this way, only the stellar mass is distributed uniformly, while the other parameters are well constrained.
We now add noise to our rotation curves, since we want the neural network to learn to de-noise a noisy curve. To do this, we construct a simple noise scheduler, which is a function that adds noise to our input dataset corresponding to an amount \(a \in [0,1]\), where if \(a=0\) the curve is noise-free and if \(a=1\) the data is pure Gaussian noise.
The term scheduler refers to the arbitrary interpolating function that we choose between the two regimes \(a=0\) and \(a=1\). For the sake of simplicity, here I chose a linear scheduler.
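A linear scheduler of this kind could look like the minimal sketch below. It assumes one noise amount per curve, which matches the call `add_noise(xtrain, noise)` in the training loop further down, and unit-variance Gaussian noise. The function body is an illustrative assumption, not necessarily the implementation used in this notebook.

```python
import torch

def add_noise(x, amount):
    """Linearly interpolate between each curve and Gaussian noise.

    x      : tensor of shape (N, ...) holding N curves
    amount : tensor of shape (N,) with values in [0, 1];
             0 returns the curve unchanged, 1 returns pure noise
    """
    noise = torch.randn_like(x)
    # reshape to (N, 1, ...) so that each amount broadcasts over its own curve
    amount = amount.view(-1, *([1] * (x.dim() - 1)))
    return x * (1 - amount) + noise * amount

# toy usage: 4 flat "curves" of 50 points with increasing noise amounts
x = torch.ones(4, 50)
amount = torch.tensor([0.0, 0.25, 0.75, 1.0])
x_noisy = add_noise(x, amount)
```

With this choice the first row of `x_noisy` is identical to the clean curve, while the last row contains no trace of it.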
Let’s now set up the neural network that will learn to de-noise our rotation curves. This is done with a slightly modified autoencoder model, where I’ve added skip connections, just to be closer to the U-Net framework which is popular for diffusion models.
import torch
from torch import nn

class AE_net(nn.Module):
    def __init__(self, ninp, **kwargs):
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            nn.Linear(in_features=ninp, out_features=32),
            nn.Linear(in_features=32, out_features=16),
            nn.Linear(in_features=16, out_features=4)
        ])
        self.decoder_layers = nn.ModuleList([
            nn.Linear(in_features=4, out_features=16),
            nn.Linear(in_features=16, out_features=32),
            nn.Linear(in_features=32, out_features=ninp)
        ])
        self.act = nn.SiLU()

    def forward(self, x):
        h = []  # skip connections
        for i, l in enumerate(self.encoder_layers):
            x = self.act(l(x))
            if i < 2:
                h.append(x)  # store for skip connection, for all but final layer
        for i, l in enumerate(self.decoder_layers):
            if i > 0:
                x = x + h.pop()  # get stored output for skip connection, for all but first layer
            x = self.act(l(x)) if i < 2 else l(x)  # final layer without activation
        return x
model = AE_net(len(cm.rad))
And now the training phase. Notice that we first add a random amount of noise to the rotation curve dataset and we pass these noisy curves to the autoencoder. This way, the model will learn to recognize the curves even when noise is added to them.
# Adam and MSE Loss
loss_func = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(2001):
    # draw a random noise amount in [0, 1) for each curve
    noise = torch.rand(xtrain.shape[0])
    # add noise to data
    x_noisy = add_noise(xtrain, noise)
    # prediction
    ymod = model(x_noisy)
    # loss between the prediction and the clean curves
    loss = loss_func(ymod, xtrain)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        # validation loss on the clean validation curves, no gradients needed
        with torch.no_grad():
            valid_loss = loss_func(model(xvalid), xvalid)
        print(epoch, "train L:%1.2e" % loss.item(), " valid L:%1.2e" % valid_loss.item())
How well does the model predict the underlying rotation curves from their noisy version? It depends a lot on the noise amount. This is rather intuitive: for a low amount of noise we expect the model to provide almost perfect predictions, whereas when the data contains more noise than signal it is not surprising to see the model fail.
These results tell us that the model has successfully learned how to denoise high signal-to-noise rotation curves, i.e. data whose noise amount is fairly low, and that it struggles to capture the details of the curve at low signal-to-noise.
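One way to quantify this dependence is to measure the reconstruction error at a few fixed noise amounts. The sketch below assumes the linear scheduler described earlier. The helper `mse_per_noise_level` is a hypothetical name, and `nn.Identity()` with random data is only a stand-in so that the snippet runs on its own; in practice you would pass the trained `model` and the validation curves instead.

```python
import torch
from torch import nn

def mse_per_noise_level(model, x, levels):
    """Mean squared error between clean curves and model output, per noise amount."""
    errors = []
    with torch.no_grad():
        for a in levels:
            # assumed linear scheduler: a=0 is the clean curve, a=1 is pure noise
            x_noisy = x * (1 - a) + torch.randn_like(x) * a
            errors.append(torch.mean((model(x_noisy) - x) ** 2).item())
    return errors

# stand-in model and data (hypothetical, for illustration only)
toy_model = nn.Identity()
torch.manual_seed(0)
x_toy = torch.randn(256, 50)

levels = [0.0, 0.25, 0.5, 0.75, 1.0]
errors = mse_per_noise_level(toy_model, x_toy, levels)
for a, err in zip(levels, errors):
    print(f"noise amount {a:.2f}: MSE {err:.3f}")
```

For the identity stand-in the error simply grows with the noise amount. A trained denoiser should stay well below that baseline at low noise and approach it as the noise starts to dominate.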
Of course, the autoencoder model that I used is quite simple and it could be further improved by making the model more sophisticated and the noise scheduler more efficient.
Sampling
Finally, let’s have a look at how we can use the model that I’ve just trained to generate new rotation curve data starting from random noise. We could, in principle, simply feed some random noise into the model, since it will give us a rotation curve as output. However, we saw before that the model becomes quite unreliable when the noise is dominant over the signal. So, how do we deal with this?
A smart, but simple, solution is to borrow from the way a diffusion differential equation is usually solved, that is, applying the denoising iteratively in many steps instead of all in one step. The idea is to start from random noise, apply just a small step of the denoising network, then take this output and apply another small denoising step to it, and so on until we have fully denoised the dataset. In this way, the denoising is done gradually, which allows complex features to arise in the rotation curve, instead of always just predicting an average curve.
Let’s apply 30 steps of denoising to 5 initial completely random datasets.
nsteps = 31
x_ini = torch.randn_like(x)
inputs = []
outputs = []
for i in range(nsteps):
    with torch.no_grad():
        ymod = model(x_ini)
    if i % 10 == 0:
        inputs.append(x_ini)
        outputs.append(ymod)
    mix_factor = 1 / (nsteps - i)  # how much denoising we apply
    x_ini = x_ini * (1 - mix_factor) + ymod * mix_factor  # mix noisy input with model prediction
inputs = torch.stack(inputs)
outputs = torch.stack(outputs)
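To see what the `mix_factor` schedule in the loop above does, it helps to track how much of the starting noise is still directly present in `x_ini` at each step. Every update multiplies that contribution by `1 - mix_factor`, so after k steps it is the product of (n - i - 1)/(n - i) for i < k, which telescopes to (n - k)/n. The short check below uses exact fractions to confirm this; `initial_noise_weight` is a hypothetical helper written only for this illustration.

```python
from fractions import Fraction

def initial_noise_weight(nsteps):
    """Direct weight of the starting noise in x_ini after 0, 1, ..., nsteps updates."""
    weights = [Fraction(1)]
    for i in range(nsteps):
        mix_factor = Fraction(1, nsteps - i)  # same schedule as in the sampling loop
        weights.append(weights[-1] * (1 - mix_factor))
    return weights

nsteps = 31
weights = initial_noise_weight(nsteps)
for k in (0, 10, 20, 30, 31):
    print(f"after {k:2d} updates: weight of initial noise = {weights[k]}")
```

The weight decreases linearly with the step number and reaches exactly zero after the last update, when `mix_factor` equals 1 and `x_ini` is replaced entirely by the model prediction. This only tracks the explicit mixing coefficient: the predictions themselves still depend on the noise they were computed from.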
Let’s now plot a few of these timesteps for the 5 noise inputs:
# y-axis limits for each of the 5 noise inputs (one per row)
ylims = {0: (-10, 350), 1: (0, 250), 2: (-100, 300), 3: (-50, 300), 4: (-100, 600)}

fig, ax = plt.subplots(figsize=(16,16), nrows=inputs.shape[1], ncols=inputs.shape[0],
                       gridspec_kw={'hspace':0.4})
for i in range(inputs.shape[1]):      # rows: noise inputs
    for j in range(inputs.shape[0]):  # columns: stored timesteps
        ax[i,j].plot(cm.rad, datascale(inputs[j,i], xmean, xstd), '.')
        ax[i,j].plot(cm.rad, datascale(outputs[j,i], xmean, xstd), '-', lw=2)
        if i in ylims:
            ax[i,j].set_ylim(*ylims[i])
        if j == 0:
            ax[i,j].set_ylabel('velocity / km/s')
        if i == inputs.shape[1]-1:
            ax[i,j].set_xlabel('radius / kpc')
        if i == 0:
            ax[i,j].set_title('step=%d' % (j*10), fontsize=14)
As we can see, with this procedure the model is able to generate new rotation curves with significantly different shapes starting from fully random noise! This happens because, once the autoencoder has been trained to denoise a database of real rotation curves, the slow denoising pipeline based on small timesteps allows the model to enhance some peculiar, non-average feature of the curve that it has inferred by chance from the random noise. In effect, the model slowly convinces itself that some chance accumulation of points actually hides a signal, which is then cleaned and enhanced at each timestep.
Next we will see how to guide the model to see some particular feature in the random noise that we want in our rotation curve output.