Commit 92edbf29 authored by Saswat's avatar Saswat

Add conditional ddpm; Add comment.

parent 8d0daf49
import numpy as np
'''
sin_train = np.load('./data/3d_sin_5_5_train.npy')
sin_test = np.load('./data/3d_sin_5_5_test.npy')
helix_train = np.load('./data/helix_3D_train.npy')
helix_test = np.load('./data/helix_3D_test.npy')
sin_bounds = np.load('./data/3d_sin_5_5_bounds.npy')
helix_bounds = np.load('./data/helix_3D_bounds.npy')
merged_train = np.concatenate((sin_train, helix_train), axis=0)
merged_test = np.concatenate((sin_test, helix_test), axis=0)
merged_bounds = np.concatenate((sin_bounds, helix_bounds), axis=0)
np.save('./data/3d_sin_helix_train.npy', merged_train)
np.save('./data/3d_sin_helix_test.npy', merged_test)
np.save('./data/3d_sin_helix_bounds.npy', merged_bounds)
'''
from generate import *
# Build the 'train' and 'test' arrays with `get_cylinder_helix` (expected to come
# from the star-import of `generate` above — NOTE(review): confirm that module
# still exports it) and report their shapes.
train_data = get_cylinder_helix('train')
print(train_data.shape)
test_data = get_cylinder_helix('test')
print(test_data.shape)
# Persist both splits as .npy files under ./data.
np.save('./data/3d_cylinder_helix_train.npy',train_data)
np.save('./data/3d_cylinder_helix_test.npy',test_data)
# Load the stored helix splits and add 2 to column index 2 of every row, in place.
helix_3d_train = np.load('./data/helix_3D_train.npy')
helix_3d_test = np.load('./data/helix_3D_test.npy')
helix_3d_train[:,2] += 2
helix_3d_test[:,2] += 2
# Save the shifted copies under new file names; the source files are not overwritten.
np.save('./data/shifted_helix_3D_train.npy',helix_3d_train)
np.save('./data/shifted_helix_3D_test.npy',helix_3d_test)
\ No newline at end of file
......@@ -3,6 +3,9 @@ import torch
import pytorch_lightning as pl
from finetune_model import FineTuneLitDiffusionModel
from dataset import ThreeDSinDataset
"""
Python file to perform the finetuning on the diffusion model.
"""
parser = argparse.ArgumentParser()
......
......@@ -18,6 +18,9 @@ def expand_alphas(batch, alpha, t):
class FineTuneLitDiffusionModel(pl.LightningModule):
"""
Class for fine-tuning the diffusion model.
"""
def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear",model_chkpt=None,model_hparams=None):
super().__init__()
"""
......@@ -162,59 +165,11 @@ class FineTuneLitDiffusionModel(pl.LightningModule):
frozen_unconditioning_score = self.frozen_model(x,t)
frozen_nc_score = self.frozen_model(nc,t)
#guidance = frozen_unconditioning_score - eta*(frozen_conditioning_score-frozen_unconditioning_score)
#a = frozen_conditioning_score
#b = frozen_unconditioning_score
#guidance =
#guidance = frozen_unconditioning_score + frozen_inv_conditioning_score
#temp = frozen_unconditioning_score - frozen_conditioning_score
#guidance = frozen_unconditioning_score + frozen_conditioning_score
#temp = ((eta)*(frozen_empty_score - frozen_unconditioning_score) - (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score))/(eta-1.0)
# temp = frozen_unconditioning_score - temp
#guidance = frozen_unconditioning_score - frozen_empty_score + frozen_unconditioning_score + (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score)
#guidance = (eta-1.0)*frozen_conditioning_score - (3.0-eta)*(frozen_unconditioning_score) - frozen_empty_score
#guidance = frozen_unconditioning_score - (eta)*(frozen_conditioning_score-frozen_unconditioning_score)
#temp = (eta/2.0)*((frozen_empty_score - frozen_unconditioning_score) - (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score))
#guidance = frozen_unconditioning_score - temp
#a = frozen_conditioning_score
#b = frozen_unconditioning_score
temp = frozen_unconditioning_score - frozen_nc_score
guidance = (frozen_unconditioning_score + (eta)*(temp))/(eta)
a = output
b = guidance - output
# a = frozen_empty_score - frozen_unconditioning_score
# b = frozen_empty_score - frozen_conditioning_score
# temp = a-b
#guidance = frozen_unconditioning_score - temp
#guidance = frozen_unconditioning_score - b
# fig = plt.figure()
# fig.add_subplot(141,projection='3d')
# ax = fig.add_subplot(141, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), a[:,2].detach().numpy(), c=batch_x[:,2], marker='o')
# ax2 = fig.add_subplot(142, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), b[:,2].detach().numpy(), c=batch_x[:,2], marker='o')
# ax3 = fig.add_subplot(143, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), temp[:,2].cpu().numpy(), c=batch_x[:,2], marker='o')
# ax4 = fig.add_subplot(144, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2], marker='o')
# plt.show()
# fig = plt.figure()
# fig.add_subplot(121,projection='3d')
# ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), output[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
# ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2].cpu().numpy(), marker='o')
# plt.show()
# plt.clf()
# guidance = frozen_empty_score + (eta-1.0)*(frozen_conditioning_score - frozen_unconditioning_score)
# fig = plt.figure()
# fig.add_subplot(121,projection='3d')
# ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), output[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
# ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2].cpu().numpy(), marker='o')
# plt.show()
# plt.clf()
# fig = plt.figure()
# fig.add_subplot(121,projection='3d')
# ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), x[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
# ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), c[:,2].cpu().numpy(), c=batch_conditioning[:,2].cpu().numpy(), marker='o')
# plt.show()
# plt.clf()
criteria = torch.nn.MSELoss()
return criteria(output, guidance)
......
......@@ -2,7 +2,9 @@ import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
"""
Contains various methods on generating distributions.
"""
def gaussian(x, mu, sig):
    """Evaluate the unnormalised Gaussian bump exp(-(x - mu)^2 / (2 * sig^2))."""
    squared_deviation = np.power(x - mu, 2.)
    denominator = 2 * np.power(sig, 2.)
    return np.exp(-squared_deviation / denominator)
......
import numpy as np
from torch.utils.data import Dataset
class DistDataset(Dataset):
    """Dataset backed by one `.npy` array, or by two arrays joined column-wise.

    Each item is the first four entries of a row of the loaded array.
    """

    def __init__(self, npy_path, mean=None, std=None):
        """
        Args:
            npy_path: Path (``str`` or ``os.PathLike``) of a single `.npy` file, or
                an indexable pair of paths whose arrays are concatenated along
                axis 1.
            mean: Optional per-column mean; computed from the data when omitted.
            std: Optional per-column standard deviation; computed from the data
                when omitted.
        """
        super().__init__()
        import os  # local import: only needed for the path-type check below

        # `isinstance` also accepts str subclasses and pathlib paths, which the
        # previous exact `type(...) == str` comparison sent to the pair branch.
        if isinstance(npy_path, (str, os.PathLike)):
            self.data = np.load(npy_path)
        else:
            self.data1 = np.load(npy_path[0])
            self.data2 = np.load(npy_path[1])
            self.data = np.concatenate((self.data1, self.data2), axis=1)
        print('Data shape',self.data.shape)

        # Keep the statistics on the instance instead of discarding them.
        # The data itself is intentionally left un-normalised.
        self.mean = np.mean(self.data, axis=0) if mean is None else mean
        self.std = np.std(self.data, axis=0) if std is None else std

    def __len__(self):
        """Return the number of rows in the loaded array."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Return the first four entries of row `index`."""
        return self.data[index][0:4]
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np
def expand_alphas(batch, alpha, t):
    """Look up `alpha[t]` and shape it as a `(batch.size(0), 1)` column.

    Args:
        batch: Tensor whose first dimension gives the number of rows wanted.
        alpha: 1-D tensor of per-step schedule values.
        t: Either a Python ``int`` (one step shared by every row) or an index
            tensor (one step per row; a 0-d or length-1 tensor is broadcast).

    Returns:
        A tensor of shape `(batch.size(0), 1)` on the same device as `alpha`.
    """
    n_rows = batch.size(0)
    if isinstance(t, int):
        return alpha[t].expand(n_rows, 1)
    # Reshape then expand instead of adding to a freshly allocated CPU zeros
    # tensor: no extra allocation, no device mismatch, and 0-d `t` works too.
    return alpha[t].reshape(-1, 1).expand(n_rows, 1)
class DDPM(pl.LightningModule):
    """Label-conditional denoising diffusion model for low-dimensional points.

    A learned time-step embedding (optionally summed with a learned class-label
    embedding) is concatenated to the noisy point and passed through an MLP that
    predicts the noise that was mixed in.
    """

    def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear"):
        """
        Args:
            n_dim: Number of coordinates per data point.
            n_steps: Number of diffusion steps.
            lbeta: Lower bound used by the beta schedule.
            ubeta: Upper bound used by the beta schedule.
            scheduler_type: One of "linear", "sigmoid" or "cosine".

        If you add more hyperparameters, also add them to `argparse` in `train.py`
        and check that they end up in `hparams.yaml`.
        """
        super().__init__()
        self.save_hyperparameters()

        embedding_dim = 8
        num_classes = 2
        # With a scale of 1 the guidance formula in `forward` algebraically
        # reduces to the label-conditioned prediction.
        self.guidance_scale = 1

        # Learned lookup tables that map a time step / class label to a vector.
        self.time_embed = nn.Embedding(n_steps, embedding_dim)
        self.label_embed = nn.Embedding(num_classes, embedding_dim)

        # MLP noise predictor: input is the point concatenated with the embedding.
        # Sized from `n_dim` (previously hard-coded to 3, the default).
        self.model = nn.Sequential(
            nn.Linear(embedding_dim + n_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, n_dim),
        )

        self.n_steps = n_steps
        self.n_dim = n_dim
        # Which noise schedule (linear/sigmoid/cosine) to build.
        self.scheduler_type = scheduler_type
        self.init_alpha_beta_schedule(lbeta, ubeta)

    def forward(self, x, t, label=None):
        """Predict the noise contained in `x` at step `t`.

        Args:
            x: Noisy points, one row per sample.
            t: Step index, either a tensor (one entry per row) or a plain number
                applied to every row.
            label: Optional class-label tensor; `None` gives the unconditional
                prediction.

        Returns:
            The unconditional prediction when `label` is `None`, otherwise
            `uncond + guidance_scale * (cond - uncond)`.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.LongTensor([t]).expand(x.size(0))
        t_embed = self.time_embed(t)
        noise_prediction_uncond = self.model(torch.cat((x, t_embed), dim=1).float())
        if label is None:
            # The guided formula collapses to the unconditional prediction here,
            # so the second (identical) network pass is skipped.
            return noise_prediction_uncond
        label = label.to(torch.int)
        conditioned_embed = t_embed + self.label_embed(label)
        noise_prediction_text = self.model(torch.cat((x, conditioned_embed), dim=1).float())
        return noise_prediction_uncond + self.guidance_scale * (noise_prediction_text - noise_prediction_uncond)

    def init_alpha_beta_schedule(self, lbeta, ubeta):
        """Build the beta schedule and the derived per-step tensors.

        Raises:
            ValueError: If `self.scheduler_type` is not a known schedule.
        """
        if self.scheduler_type == "linear":
            self.beta = torch.linspace(lbeta, ubeta, self.n_steps)
        elif self.scheduler_type == "sigmoid":
            self.beta = torch.linspace(-6, 6, self.n_steps)
            self.beta = torch.sigmoid(self.beta) * (ubeta - lbeta) + lbeta
        elif self.scheduler_type == "cosine":
            self.beta = torch.linspace(0, 1, self.n_steps)
            self.beta = (torch.cos((self.beta)*np.pi/2)+0.002)/1.002 * (ubeta - lbeta) + lbeta
        else:
            raise ValueError(f"Unknown scheduler_type: {self.scheduler_type!r}")

        # Cache the derived quantities once so every step is a table lookup.
        self.alpha = 1 - self.beta
        # NOTE(review): `alpha_cum` already holds the square root of the cumulative
        # product, so `alpha_cum_sqrt` is a second square root of it. The DDPM
        # paper scales the signal by sqrt(cumprod(alpha)). Left unchanged because
        # models trained with this code depend on these exact values — confirm.
        self.alpha_cum = torch.sqrt(torch.cumprod(1 - self.beta, 0))
        self.alpha_cum_sqrt = torch.sqrt(self.alpha_cum)
        self.one_min_alphas_sum_sqrt = torch.sqrt(1 - self.alpha_cum)

    def q_sample(self, x, t):
        """Noise clean points `x` to step `t` (forward process).

        Uses the same signal and noise scales as `training_step`; the noise scale
        was previously `1 - alpha_cum`, which disagreed with the training path.
        """
        signal_scale = expand_alphas(x, self.alpha_cum_sqrt, t)
        noise_scale = expand_alphas(x, self.one_min_alphas_sum_sqrt, t)
        return signal_scale * x + noise_scale * torch.randn_like(x)

    def p_sample(self, x, t, label):
        """Take one reverse step from `x` at step `t`, conditioned on `label`."""
        epsilon_factor = expand_alphas(x, self.beta, t) / expand_alphas(x, self.one_min_alphas_sum_sqrt, t)
        epsilon_theta = self(x, t, label)
        mean = (1 / expand_alphas(x, torch.sqrt(self.alpha), t)) * (x - epsilon_factor * epsilon_theta)
        if isinstance(t, int) and t == 0:
            # Final step: return the mean without fresh noise (DDPM Algorithm 2).
            return mean
        sigma = expand_alphas(x, torch.sqrt(self.beta), t)
        return mean + sigma * torch.randn_like(x)

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction loss for one batch.

        Each row of `batch` holds `n_dim` coordinates followed by a class label.
        A random step is drawn per row, the coordinates are noised to that step,
        and the mean squared error between the true and predicted noise is
        returned so that PyTorch Lightning can run the backward pass.

        References:
            [1]: https://arxiv.org/abs/2006.11239
            [2]: https://pytorch-lightning.readthedocs.io/en/stable/
        """
        label = batch[:, self.n_dim]
        batch = batch[:, :self.n_dim]
        t = torch.randint(0, self.n_steps, size=(batch.shape[0],))
        alpha_sqrt = expand_alphas(batch, self.alpha_cum_sqrt, t)
        one_min_alpha_sqrt = expand_alphas(batch, self.one_min_alphas_sum_sqrt, t)
        noise = torch.randn_like(batch)
        x = batch * alpha_sqrt + noise * one_min_alpha_sqrt
        output = self(x, t, label=label)
        return (noise - output).square().mean()

    def sample(self, n_samples, label, return_intermediate=False):
        """Run the full reverse process starting from Gaussian noise.

        Args:
            n_samples: Number of points to generate.
            label: Class-label tensor forwarded to every step, or `None`.
            return_intermediate: Also return every intermediate state.

        Returns:
            A `(n_samples, n_dim)` tensor; when `return_intermediate` is true, a
            tuple of that tensor and a list holding the initial noise followed by
            the state after each of the `n_steps` reverse steps.
        """
        current = torch.randn((n_samples, self.n_dim))
        trajectory = [current]
        for t in range(self.n_steps - 1, -1, -1):
            current = self.p_sample(current, t, label)
            if return_intermediate:
                trajectory.append(current)
        if return_intermediate:
            return current, trajectory
        return current

    def configure_optimizers(self):
        """Return the optimizer used for training.

        Optimises every parameter of the module. Previously only `self.model` was
        passed in, which left `time_embed` and `label_embed` at their random
        initial values for the whole run.
        """
        return torch.optim.Adam(self.parameters(), lr=1e-3)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np
from ddpm import DDPM
def expand_alphas(batch, alpha, t):
    """Look up `alpha[t]` and shape it as a `(batch.size(0), 1)` column.

    Args:
        batch: Tensor whose first dimension gives the number of rows wanted.
        alpha: 1-D tensor of per-step schedule values.
        t: Either a Python ``int`` (one step shared by every row) or an index
            tensor (one step per row; a 0-d or length-1 tensor is broadcast).

    Returns:
        A tensor of shape `(batch.size(0), 1)` on the same device as `alpha`.
    """
    n_rows = batch.size(0)
    if isinstance(t, int):
        return alpha[t].expand(n_rows, 1)
    # Reshape then expand instead of adding to a freshly allocated CPU zeros
    # tensor: no extra allocation, no device mismatch, and 0-d `t` works too.
    return alpha[t].reshape(-1, 1).expand(n_rows, 1)
class EDDPM(pl.LightningModule):
    """Fine-tunes a pretrained `DDPM` against a frozen copy of itself.

    Two copies of the same checkpoint are loaded: `model` is trained, while
    `frozen_model` only supplies targets. The training target pushes the
    label-conditioned prediction of `model` towards
    `neutral - eta * (positive - neutral)`, where both terms come from the
    frozen copy.
    """

    def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear",model_chkpt=None,model_hparams=None):
        """
        Args:
            n_dim: Number of coordinates per data point.
            n_steps: Number of diffusion steps.
            lbeta: Lower bound used by the beta schedule.
            ubeta: Upper bound used by the beta schedule.
            scheduler_type: One of "linear", "sigmoid" or "cosine".
            model_chkpt: Checkpoint path of the pretrained `DDPM`.
            model_hparams: Hyperparameter reference forwarded when loading it.
        """
        super().__init__()
        self.save_hyperparameters()

        # NOTE(review): Lightning documents `hparams_file=` as the keyword for an
        # hparams YAML path; confirm that `hparams=` has the intended effect.
        self.model = DDPM.load_from_checkpoint(model_chkpt,hparams=model_hparams)
        self.frozen_model = DDPM.load_from_checkpoint(model_chkpt,hparams=model_hparams)
        self.frozen_model.eval()

        self.n_steps = n_steps
        self.n_dim = n_dim
        # Which noise schedule (linear/sigmoid/cosine) to build.
        self.scheduler_type = scheduler_type
        self.init_alpha_beta_schedule(lbeta, ubeta)

    def forward(self, x, t, label):
        """Return the trainable model's noise prediction for `x` at step `t`."""
        if not isinstance(t, torch.Tensor):
            t = torch.LongTensor([t]).expand(x.size(0))
        return self.model(x, t, label)

    def init_alpha_beta_schedule(self, lbeta, ubeta):
        """Build the beta schedule and the derived per-step tensors.

        Raises:
            ValueError: If `self.scheduler_type` is not a known schedule.
        """
        if self.scheduler_type == "linear":
            self.beta = torch.linspace(lbeta, ubeta, self.n_steps)
        elif self.scheduler_type == "sigmoid":
            self.beta = torch.linspace(-6, 6, self.n_steps)
            self.beta = torch.sigmoid(self.beta) * (ubeta - lbeta) + lbeta
        elif self.scheduler_type == "cosine":
            self.beta = torch.linspace(0, 1, self.n_steps)
            self.beta = (torch.cos((self.beta)*np.pi/2)+0.002)/1.002 * (ubeta - lbeta) + lbeta
        else:
            raise ValueError(f"Unknown scheduler_type: {self.scheduler_type!r}")

        # Cache the derived quantities once so every step is a table lookup.
        self.alpha = 1 - self.beta
        # NOTE(review): `alpha_cum` already holds the square root of the cumulative
        # product, so `alpha_cum_sqrt` is a second square root of it. Left
        # unchanged because the loaded checkpoint was presumably trained with the
        # same values — confirm.
        self.alpha_cum = torch.sqrt(torch.cumprod(1 - self.beta, 0))
        self.alpha_cum_sqrt = torch.sqrt(self.alpha_cum)
        self.one_min_alphas_sum_sqrt = torch.sqrt(1 - self.alpha_cum)

    def q_sample(self, x, t):
        """Noise clean points `x` to step `t` (forward process).

        Uses the same signal and noise scales as `training_step`; the noise scale
        was previously `1 - alpha_cum`, which disagreed with the training path.
        """
        signal_scale = expand_alphas(x, self.alpha_cum_sqrt, t)
        noise_scale = expand_alphas(x, self.one_min_alphas_sum_sqrt, t)
        return signal_scale * x + noise_scale * torch.randn_like(x)

    def p_sample(self, x, t, label):
        """Take one reverse step from `x` at step `t`, conditioned on `label`."""
        epsilon_factor = expand_alphas(x, self.beta, t) / expand_alphas(x, self.one_min_alphas_sum_sqrt, t)
        epsilon_theta = self(x, t, label)
        mean = (1 / expand_alphas(x, torch.sqrt(self.alpha), t)) * (x - epsilon_factor * epsilon_theta)
        if isinstance(t, int) and t == 0:
            # Final step: return the mean without fresh noise (DDPM Algorithm 2).
            return mean
        sigma = expand_alphas(x, torch.sqrt(self.beta), t)
        return mean + sigma * torch.randn_like(x)

    def training_step(self, batch, batch_idx):
        """Compute the fine-tuning loss for one batch.

        Each row of `batch` holds `n_dim` coordinates followed by a class label.
        The coordinates are noised at step 0 only; the trainable model's
        conditioned prediction is then regressed (MSE) onto a target built from
        the frozen model's conditioned and unconditioned predictions.
        """
        batch_x = batch[:, :self.n_dim]
        label = batch[:, self.n_dim]
        # Every row uses step 0; random steps are deliberately not sampled here.
        t = torch.zeros((batch_x.shape[0],), dtype=torch.long)
        alpha_sqrt = expand_alphas(batch_x, self.alpha_cum_sqrt, t)
        one_min_alpha_sqrt = expand_alphas(batch_x, self.one_min_alphas_sum_sqrt, t)
        noise = torch.randn_like(batch_x)
        x = batch_x * alpha_sqrt + noise * one_min_alpha_sqrt

        negative_latent = self(x, t, label=label)
        # Targets come from the frozen copy, so no gradients are tracked for them.
        with torch.no_grad():
            positive_latent = self.frozen_model(x, t, label)
            neutral_latent = self.frozen_model(x, t, None)
            eta = 1
            guidance = neutral_latent - eta * (positive_latent - neutral_latent)

        criteria = torch.nn.MSELoss()
        return criteria(negative_latent, guidance)

    def sample(self, n_samples, label, progress=False, return_intermediate=False):
        """Run the full reverse process starting from Gaussian noise.

        Args:
            n_samples: Number of points to generate.
            label: Class-label tensor forwarded to every step, or `None`.
            progress: Accepted for interface compatibility; currently unused.
            return_intermediate: Also return every intermediate state.

        Returns:
            A `(n_samples, n_dim)` tensor; when `return_intermediate` is true, a
            tuple of that tensor and a list holding the initial noise followed by
            the state after each of the `n_steps` reverse steps.
        """
        current = torch.randn((n_samples, self.n_dim))
        trajectory = [current]
        for t in range(self.n_steps - 1, -1, -1):
            # `label` is passed on both paths; the intermediate path used to omit
            # it and failed with a missing-argument TypeError.
            current = self.p_sample(current, t, label)
            if return_intermediate:
                trajectory.append(current)
        if return_intermediate:
            return current, trajectory
        return current

    def configure_optimizers(self):
        """Return the optimizer; only the trainable copy's parameters are updated."""
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)
\ No newline at end of file
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
sample_sz = 10000
def gaussian(x, mu, sig):
    """Evaluate the unnormalised Gaussian bump exp(-(x - mu)^2 / (2 * sig^2))."""
    squared_deviation = np.power(x - mu, 2.)
    denominator = 2 * np.power(sig, 2.)
    return np.exp(-squared_deviation / denominator)
def gaussian_3d(x, y, mu, sig):
    """Evaluate the product of two unnormalised 1-D Gaussian bumps.

    The bump along `x` uses `mu[0]`/`sig[0]`; the bump along `y` uses
    `mu[1]`/`sig[1]`. Any further entries of `mu` and `sig` are ignored.
    """
    bump_x = np.exp(-np.power(x - mu[0], 2.) / (2 * np.power(sig[0], 2.)))
    bump_y = np.exp(-np.power(y - mu[1], 2.) / (2 * np.power(sig[1], 2.)))
    return bump_x * bump_y
def generate_two_gaussian(n_samples=100):
    """Sample a surface with two Gaussian bumps on a regular grid.

    The grid spans [-3, 3] on both axes with `n_samples` points per axis; the
    height is the sum of bumps centred at (-1, -1) and (1, 1), each of width 0.5.

    Returns:
        Array of shape `(n_samples ** 2, 3)` with columns x, y, height.
    """
    axis = np.linspace(-3, 3, n_samples)
    grid_x, grid_y = np.meshgrid(axis, axis)
    widths = [0.5, 0.5, 0.5]
    height = gaussian_3d(grid_x, grid_y, [-1, -1, 0], widths)
    height = height + gaussian_3d(grid_x, grid_y, [1, 1, 0], widths)
    return np.column_stack((grid_x.ravel(), grid_y.ravel(), height.ravel()))
def generate_one_gaussian(n_samples=100):
    """Sample a surface with one Gaussian bump on a regular grid.

    The grid spans [-3, 3] on both axes with `n_samples` points per axis; the
    height is a bump centred at (-1, -1) with width 0.5.

    Returns:
        Array of shape `(n_samples ** 2, 3)` with columns x, y, height.
    """
    axis = np.linspace(-3, 3, n_samples)
    grid_x, grid_y = np.meshgrid(axis, axis)
    height = gaussian_3d(grid_x, grid_y, [-1, -1, 0], [0.5, 0.5, 0.5])
    return np.column_stack((grid_x.ravel(), grid_y.ravel(), height.ravel()))
def gen_gaus_train():
    """Draw two unit-covariance 3-D Gaussian blobs of `sample_sz` points each.

    The first blob is centred at (1.5, 1.5, 1.5), the second at
    (-1.5, -1.5, -1.5); the rows of the first precede those of the second.
    """
    unit_cov = [[1,0,0],[0,1,0],[0,0,1]]
    centres = ([1.5,1.5,1.5], [-1.5,-1.5,-1.5])
    blobs = [np.random.multivariate_normal(centre, unit_cov, sample_sz) for centre in centres]
    return np.concatenate(blobs, axis=0)
def gen_gaus_finetune():
    """Draw `2 * sample_sz` points from a unit-covariance Gaussian at (-1.5, -1.5, -1.5)."""
    centre = [-1.5,-1.5,-1.5]
    unit_cov = [[1,0,0],[0,1,0],[0,0,1]]
    n_points = 2*sample_sz
    return np.random.multivariate_normal(centre, unit_cov, n_points)
def gen_sphere(radius=4, n_samples=None):
    """Draw points on the surface of a sphere centred at the origin.

    Standard-normal vectors are normalised to unit length and then scaled.

    Args:
        radius: Sphere radius; the default reproduces the previous fixed value.
        n_samples: Number of points; defaults to the module-level `sample_sz`.

    Returns:
        Array of shape `(n_samples, 3)`.
    """
    if n_samples is None:
        n_samples = sample_sz
    vec = np.random.randn(n_samples, 3)
    vec /= np.linalg.norm(vec, axis=1, keepdims=True)
    # The unreachable statements that followed this return (and referred to
    # names that were never defined) have been removed.
    return radius * vec
def gen_gaus_test():
    """Draw `2 * sample_sz` points from a unit-covariance Gaussian at (1.5, 1.5, 1.5)."""
    centre = [1.5,1.5,1.5]
    unit_cov = [[1,0,0],[0,1,0],[0,0,1]]
    n_points = 2*sample_sz
    return np.random.multivariate_normal(centre, unit_cov, n_points)
def gen_plane_data():
    """Draw `sample_sz` uniform points on the z = 0 plane within [-4, 4) x [-4, 4).

    Kept as an alias of `gen_plane_data_xy`, whose body this function used to
    duplicate line for line.
    """
    return gen_plane_data_xy()
def gen_plane_data_xy():
    """Draw `sample_sz` uniform points on the z = 0 plane within [-4, 4) x [-4, 4).

    Returns:
        Array of shape `(sample_sz, 3)` whose last column is all zeros.
    """
    flat_axis = np.zeros((sample_sz, 1))
    in_plane = np.random.rand(sample_sz, 2) * 8 - 4
    return np.hstack((in_plane, flat_axis))
def gen_plane_data_yz():
    """Draw `sample_sz` uniform points on the x = 0 plane within [-4, 4) x [-4, 4).

    Returns:
        Array of shape `(sample_sz, 3)` whose first column is all zeros.
    """
    flat_axis = np.zeros((sample_sz, 1))
    in_plane = np.random.rand(sample_sz, 2) * 8 - 4
    return np.hstack((flat_axis, in_plane))
def gen_curve_data():
    """Draw `sample_sz` points on the surface z = x^3 + y^3.

    x and y are uniform in [-4, 4).
    """
    base_xy = np.random.rand(sample_sz, 2) * 8 - 4
    height = np.sum(np.power(base_xy, 3), axis=1, keepdims=True)
    return np.concatenate((base_xy, height), axis=1)
def gen_sphere_test():
    """Draw `sample_sz` points from a standard 3-D Gaussian centred at the origin.

    Despite the name, the points fill a Gaussian blob; they are not projected
    onto a sphere.
    """
    origin = [0,0,0]
    unit_cov = [[1,0,0],[0,1,0],[0,0,1]]
    return np.random.multivariate_normal(origin, unit_cov, sample_sz)
def _attach_label(points, label):
    """Append `label` as an extra trailing column to the `sample_sz` rows of `points`."""
    label_column = np.zeros((sample_sz,1),dtype = int)+label
    return np.concatenate((points, label_column), axis = 1)


# Generators and the class label attached to each generator's output.
train_data = [gen_plane_data_xy, gen_sphere_test]
train_idx = [0, 1]
finetune_data = [gen_plane_data_xy]
test_idx = [0]

# Generate, label and stack every part of the training set, in list order.
train_dataset = np.concatenate(
    [_attach_label(make_points(), tag) for make_points, tag in zip(train_data, train_idx)],
    axis=0)

# Same procedure for the fine-tuning set.
finetune_dataset = np.concatenate(
    [_attach_label(make_points(), tag) for make_points, tag in zip(finetune_data, test_idx)],
    axis=0)

np.save('./data/gaussian_3D_train.npy', train_dataset)
np.save('./data/gaussian_3D_ft.npy', finetune_dataset)

# Show the training set coloured by label, then the fine-tuning set in red.
for cloud, colour in ((train_dataset, train_dataset[:, 3]), (finetune_dataset, "red")):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], c=colour, alpha=0.2)
    plt.show()
'''
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import pandas as pd
from scipy.stats import multivariate_normal
import torch
#3d cylinder data generate
sample_sz = 10000
def cylinder(r=0.25, h=1, n=10000):
#function which generates 3d points on a cylinder
n_samples = n
points = np.zeros((n_samples, 3))
theta = np.random.uniform(0, 2*np.pi, n_samples)
z = np.random.uniform(0, h, n_samples)
points[:, 0] = r * np.cos(theta)
points[:, 1] = r * np.sin(theta)
points[:, 2] = z
return points
def get_cylinder_helix(type_data='train'):
helix_3d = np.load(f'./data/helix_3D_{type_data}.npy')
#shift z axis of helix_3d
helix_3d[:, 2] += 2
n_samples = helix_3d.shape[0]
#print(n_samples)
cylinder_data = cylinder(0.05, 1,n=n_samples)
concat = np.concatenate((cylinder_data, helix_3d), axis=0)
np.random.shuffle(concat)
df = pd.DataFrame(concat, columns=['x','y','z'])
df = df.drop_duplicates(['x','y'])
df = df.drop_duplicates(['x','z'])
df = df.drop_duplicates(['y','z'])
concat = df.values
mean = np.mean(concat, axis=0)
std = np.std(concat, axis=0)
concat = (concat - mean) / std
#print('type',type(concat),'concat',concat.shape)
return concat
cylinder_data = cylinder(0.05, 1,n=10000)
data = get_cylinder_helix('test')
cylinder_data = np.concatenate((cylinder_data, np.ones((sample_sz,1), dtype=int)),axis=1)
data = np.concatenate((data, np.zeros((sample_sz,1), dtype=int)),axis=1)
data = np.concatenate((cylinder_data,data), axis=0)
print(data)
print(cylinder_data)
np.save('./data/gaussian_3D_train.npy', data)
np.save('./data/gaussian_3D_ft.npy', cylinder_data)
#scatter plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
#ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=data[:, 2], cmap=cm.spring, alpha=0.1)
#ax.scatter(cylinder_data[:, 0], cylinder_data[:, 1], cylinder_data[:, 2], c=cylinder_data[:, 2], cmap=cm.spring, alpha=0.1)
#ax.set_zlim(0, 3)
plt.show()
'''
lbeta: 1.0e-05
n_dim: 3
n_steps: 100
scheduler_type: linear
ubeta: 0.0128
lbeta: 1.0e-05
n_dim: 3
n_steps: 100
scheduler_type: linear
ubeta: 0.0128
lbeta: 1.0e-05
model_chkpt: run_ddpm/last.ckpt
model_hparams: run_ddpm/lightning_logs/version_0/hparams.yaml
n_dim: 3
n_steps: 100
scheduler_type: linear
ubeta: 0.0128
import os
import argparse
import torch
import numpy as np
from ddpm import DDPM
from eddpm import EDDPM
import matplotlib.pyplot as plt
# Command line: `-e/--erased` selects the fine-tuned (EDDPM) model instead of the base DDPM.
parser = argparse.ArgumentParser()
parser.add_argument('-e','--erased', action='store_true', help='model type ddpm/eddpm')
args = parser.parse_args()

# Pick the model name and load the matching checkpoint in one place.
if args.erased:
    model_type = "eddpm"
    litmodel = EDDPM.load_from_checkpoint(
        "run_eddpm/last.ckpt",
        hparams_file= "./run_eddpm/lightning_logs/version_0/hparams.yaml"
    )
else:
    model_type = "ddpm"
    litmodel = DDPM.load_from_checkpoint(
        "run_ddpm/last.ckpt",
        hparams_file= "run_ddpm/lightning_logs/version_0/hparams.yaml"
    )
litmodel.eval()

sample_sz = 10000
# Class labels to sample for; `None` requests unconditional samples.
labels = [0, 1, None]

# `savefig` does not create missing directories, so make sure the target exists.
os.makedirs("results", exist_ok=True)

for label in labels:
    print(f"Sampling for Label {label}")
    # Keep the loop variable untouched; the model receives a tensor (or None).
    label_tensor = None if label is None else torch.tensor(label)
    with torch.no_grad():
        gendata = litmodel.sample(sample_sz, label_tensor)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(gendata[:, 0], gendata[:, 1], gendata[:, 2], c=gendata[:, 2], alpha=0.1)
    ax.set_xlim(-4, 4)
    ax.set_ylim(-4, 4)
    ax.set_zlim(-4, 4)
    ax.set_xlabel("x axis")
    ax.set_ylabel("y axis")
    ax.set_zlabel("z axis")
    plt.savefig(f"results/{model_type}_{label}.pdf")
    plt.show()
    # Release the figure once it has been saved and shown.
    plt.close(fig)
"""
label = torch.tensor(1)
print(label)
with torch.no_grad():
gendata = litmodel.sample(10000, label)
#temp_data = np.concatenate((testdata.numpy(), gendata.numpy()), axis=0)
#print(f'gendata.shape = {gendata.shape}')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(gendata[:, 0], gendata[:, 1], gendata[:, 2], c=gendata[:, 2], cmap=cm.spring, alpha=0.1)
#ax.scatter(testdata[:, 0], testdata[:, 1], testdata[:, 2], c=testdata[:, 2], alpha=0.1)
# ax.set_xlim(-3, 3)
# ax.set_ylim(-3, 3)
# ax.set_zlim(0, 1)
plt.show()
label = None
print(label)
with torch.no_grad():
gendata = litmodel.sample(10000, label)
#temp_data = np.concatenate((testdata.numpy(), gendata.numpy()), axis=0)
#print(f'gendata.shape = {gendata.shape}')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(gendata[:, 0], gendata[:, 1], gendata[:, 2], c=gendata[:, 2], cmap=cm.spring, alpha=0.1)
#ax.scatter(testdata[:, 0], testdata[:, 1], testdata[:, 2], c=testdata[:, 2], alpha=0.1)
# ax.set_xlim(-3, 3)
# ax.set_ylim(-3, 3)
# ax.set_zlim(0, 1)
plt.show()"""
\ No newline at end of file
import argparse
import torch
import pytorch_lightning as pl
from ddpm import DDPM
from dataset import DistDataset
# ---------------------------------------------------------------------------
# Command-line interface: model hyper-parameters and training settings.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()

model_args = parser.add_argument_group('model')
model_args.add_argument(
    '--n_dim', type=int, default=3,
    help='Number of dimensions',
)
model_args.add_argument(
    '--n_steps', type=int, default=100,
    help='Number of diffusion steps',
)
model_args.add_argument(
    '--lbeta', type=float, default=1e-5,
    help='Lower bound of beta',
)
model_args.add_argument(
    '--ubeta', type=float, default=1.28e-2,
    help='Upper bound of beta',
)
model_args.add_argument(
    '--scheduler_type', type=str, default="linear",
    help='Variance Scheduling Function',
)

training_args = parser.add_argument_group('training')
training_args.add_argument(
    '--seed', type=int, default=1618,
    help='Random seed for experiments',
)
training_args.add_argument(
    '--n_epochs', type=int, default=100,
    help='Number of training epochs',
)
training_args.add_argument(
    '--batch_size', type=int, default=1024,
    help='Batch size for training dataloader',
)
training_args.add_argument(
    '--train_data_path', type=str, default='./data/gaussian_3D_train.npy',
    help='Path to training data numpy file',
)
training_args.add_argument(
    '--savedir', type=str, default='./run_ddpm/',
    help='Root directory where all checkpoint and logs will be saved',
)

args = parser.parse_args()

# Unpack the model hyper-parameters into module-level names.
n_dim, n_steps = args.n_dim, args.n_steps
lbeta, ubeta = args.lbeta, args.ubeta
scheduler_type = args.scheduler_type

# Seed before the model and the shuffling dataloader are created.
pl.seed_everything(args.seed)

batch_size, n_epochs, savedir = args.batch_size, args.n_epochs, args.savedir

# ---------------------------------------------------------------------------
# Model and data.
# ---------------------------------------------------------------------------
model_kwargs = dict(
    n_dim=n_dim,
    n_steps=n_steps,
    lbeta=lbeta,
    ubeta=ubeta,
    scheduler_type=scheduler_type,
)
litmodel = DDPM(**model_kwargs)

train_dataset = DistDataset(args.train_data_path)
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# NOTE(review): run_name is built but not used anywhere in this script.
run_name = f'n_dim={n_dim},n_steps={n_steps},lbeta={lbeta:.3e},ubeta={ubeta:.3e},scheduler_type={scheduler_type},batch_size={batch_size},n_epochs={n_epochs}'

# ---------------------------------------------------------------------------
# Trainer.
# ---------------------------------------------------------------------------
tb_logger = pl.loggers.TensorBoardLogger(f'{savedir}/')

# Checkpointing: triggered every 10 epochs, keeping only the checkpoint with
# the highest monitored 'epoch' value, and additionally maintaining a
# `last` checkpoint (save_last=True).
latest_checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath=f'{savedir}/',
    filename='{epoch:04d}-{train_loss:.3f}',
    save_top_k=1,
    monitor='epoch',
    mode='max',
    save_last=True,
    every_n_epochs=10,
)

trainer = pl.Trainer(
    deterministic=True,
    logger=tb_logger,
    max_epochs=n_epochs,
    log_every_n_steps=1,
    callbacks=[latest_checkpoint],
)

trainer.fit(model=litmodel, train_dataloaders=train_dataloader)
import argparse
import torch
import pytorch_lightning as pl
from eddpm import EDDPM
from dataset import DistDataset
# Command-line interface: model hyper-parameters and fine-tuning settings.
parser = argparse.ArgumentParser()

model_args = parser.add_argument_group('model')
model_args.add_argument('--n_dim', type=int, default=3, help='Number of dimensions')
model_args.add_argument('--n_steps', type=int, default=100, help='Number of diffusion steps')
model_args.add_argument('--lbeta', type=float, default=1e-5, help='Lower bound of beta')
model_args.add_argument('--ubeta', type=float, default=1.28e-2, help='Upper bound of beta')
model_args.add_argument('--scheduler_type', type=str, default="linear", help='Variance Scheduling Function')

training_args = parser.add_argument_group('training')
training_args.add_argument('--seed', type=int, default=1618, help='Random seed for experiments')
training_args.add_argument('--n_epochs', type=int, default=100, help='Number of training epochs')
training_args.add_argument('--batch_size', type=int, default=1024, help='Batch size for training dataloader')
# NOTE(review): --train_data_path is parsed but never read below; only
# --ft_data_path is used. Kept so existing command lines keep working.
training_args.add_argument('--train_data_path', type=str, default='./data/gaussian_3D_train.npy', help='Path to training data numpy file')
training_args.add_argument('--ft_data_path', type=str, default='./data/gaussian_3D_ft.npy', help='Path to conditioning data numpy file')
training_args.add_argument('--savedir', type=str, default='./run_eddpm/', help='Root directory where all checkpoint and logs will be saved')
# Location of the previously trained model handed to EDDPM. These used to be
# hard-coded; the defaults reproduce the old values exactly.
training_args.add_argument('--model_chkpt', type=str, default='run_ddpm/last.ckpt', help='Path to the pretrained DDPM checkpoint passed to EDDPM')
training_args.add_argument('--model_hparams', type=str, default='run_ddpm/lightning_logs/version_0/hparams.yaml', help='Path to the hparams.yaml of the pretrained DDPM run')

args = parser.parse_args()

n_dim = args.n_dim
n_steps = args.n_steps
lbeta = args.lbeta
ubeta = args.ubeta
scheduler_type = args.scheduler_type

# Seed before the model and the shuffling dataloader are created.
pl.seed_everything(args.seed)

batch_size = args.batch_size
n_epochs = args.n_epochs
savedir = args.savedir

litmodel = EDDPM(
    n_dim=n_dim,
    n_steps=n_steps,
    lbeta=lbeta,
    ubeta=ubeta,
    scheduler_type=scheduler_type,
    model_chkpt=args.model_chkpt,
    model_hparams=args.model_hparams,
)

# Disable gradients for every parameter whose name contains 'frozen_model',
# then print each parameter's name and whether it will be trained.
for name, param in litmodel.named_parameters():
    if 'frozen_model' in name:
        param.requires_grad = False
    print(name, param.requires_grad)

train_dataset = DistDataset(args.ft_data_path)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# NOTE(review): run_name is built but not used anywhere in this script.
run_name = f'n_dim={n_dim},n_steps={n_steps},lbeta={lbeta:.3e},ubeta={ubeta:.3e},scheduler_type={scheduler_type},batch_size={batch_size},n_epochs={n_epochs}'

trainer = pl.Trainer(
    deterministic=True,
    logger=pl.loggers.TensorBoardLogger(f'{savedir}/'),
    max_epochs=n_epochs,
    log_every_n_steps=10,
    callbacks=[
        # Checkpoint after every epoch, keeping only the one with the highest
        # monitored 'epoch' value plus a `last` checkpoint (save_last=True).
        pl.callbacks.ModelCheckpoint(
            dirpath=f'{savedir}/',
            filename='{epoch:04d}-{train_loss:.3f}',
            save_top_k=1,
            monitor='epoch',
            mode='max',
            save_last=True,
            every_n_epochs=1,
        ),
    ]
)

trainer.fit(model=litmodel, train_dataloaders=train_dataloader)
diff --git a/abstract_ddpm/convert.py b/abstract_ddpm/convert.py
deleted file mode 100644
index 84b274f..0000000
--- a/abstract_ddpm/convert.py
+++ /dev/null
@@ -1,36 +0,0 @@
-
-import numpy as np
-'''
-sin_train = np.load('./data/3d_sin_5_5_train.npy')
-sin_test = np.load('./data/3d_sin_5_5_test.npy')
-helix_train = np.load('./data/helix_3D_train.npy')
-helix_test = np.load('./data/helix_3D_test.npy')
-sin_bounds = np.load('./data/3d_sin_5_5_bounds.npy')
-helix_bounds = np.load('./data/helix_3D_bounds.npy')
-
-merged_train = np.concatenate((sin_train, helix_train), axis=0)
-merged_test = np.concatenate((sin_test, helix_test), axis=0)
-merged_bounds = np.concatenate((sin_bounds, helix_bounds), axis=0)
-
-np.save('./data/3d_sin_helix_train.npy', merged_train)
-np.save('./data/3d_sin_helix_test.npy', merged_test)
-np.save('./data/3d_sin_helix_bounds.npy', merged_bounds)
-'''
-from generate import *
-
-train_data = get_cylinder_helix('train')
-print(train_data.shape)
-test_data = get_cylinder_helix('test')
-print(test_data.shape)
-
-np.save('./data/3d_cylinder_helix_train.npy',train_data)
-np.save('./data/3d_cylinder_helix_test.npy',test_data)
-
-helix_3d_train = np.load('./data/helix_3D_train.npy')
-helix_3d_test = np.load('./data/helix_3D_test.npy')
-
-helix_3d_train[:,2] += 2
-helix_3d_test[:,2] += 2
-
-np.save('./data/shifted_helix_3D_train.npy',helix_3d_train)
-np.save('./data/shifted_helix_3D_test.npy',helix_3d_test)
\ No newline at end of file
diff --git a/abstract_ddpm/finetune.py b/abstract_ddpm/finetune.py
index 5fea7c8..932feaa 100644
--- a/abstract_ddpm/finetune.py
+++ b/abstract_ddpm/finetune.py
@@ -3,6 +3,9 @@ import torch
import pytorch_lightning as pl
from finetune_model import FineTuneLitDiffusionModel
from dataset import ThreeDSinDataset
+"""
+Python file to perform the finetuning on the diffusion model.
+"""
parser = argparse.ArgumentParser()
diff --git a/abstract_ddpm/finetune_model.py b/abstract_ddpm/finetune_model.py
index 9c07a0f..39b5e65 100644
--- a/abstract_ddpm/finetune_model.py
+++ b/abstract_ddpm/finetune_model.py
@@ -18,6 +18,9 @@ def expand_alphas(batch, alpha, t):
class FineTuneLitDiffusionModel(pl.LightningModule):
+ """
+ Class for fine-tuning the diffusion model.
+ """
def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear",model_chkpt=None,model_hparams=None):
super().__init__()
"""
@@ -162,59 +165,11 @@ class FineTuneLitDiffusionModel(pl.LightningModule):
frozen_unconditioning_score = self.frozen_model(x,t)
frozen_nc_score = self.frozen_model(nc,t)
- #guidance = frozen_unconditioning_score - eta*(frozen_conditioning_score-frozen_unconditioning_score)
- #a = frozen_conditioning_score
- #b = frozen_unconditioning_score
- #guidance =
- #guidance = frozen_unconditioning_score + frozen_inv_conditioning_score
- #temp = frozen_unconditioning_score - frozen_conditioning_score
- #guidance = frozen_unconditioning_score + frozen_conditioning_score
- #temp = ((eta)*(frozen_empty_score - frozen_unconditioning_score) - (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score))/(eta-1.0)
- # temp = frozen_unconditioning_score - temp
- #guidance = frozen_unconditioning_score - frozen_empty_score + frozen_unconditioning_score + (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score)
- #guidance = (eta-1.0)*frozen_conditioning_score - (3.0-eta)*(frozen_unconditioning_score) - frozen_empty_score
- #guidance = frozen_unconditioning_score - (eta)*(frozen_conditioning_score-frozen_unconditioning_score)
- #temp = (eta/2.0)*((frozen_empty_score - frozen_unconditioning_score) - (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score))
- #guidance = frozen_unconditioning_score - temp
- #a = frozen_conditioning_score
- #b = frozen_unconditioning_score
temp = frozen_unconditioning_score - frozen_nc_score
guidance = (frozen_unconditioning_score + (eta)*(temp))/(eta)
a = output
b = guidance - output
- # a = frozen_empty_score - frozen_unconditioning_score
- # b = frozen_empty_score - frozen_conditioning_score
- # temp = a-b
- #guidance = frozen_unconditioning_score - temp
- #guidance = frozen_unconditioning_score - b
- # fig = plt.figure()
- # fig.add_subplot(141,projection='3d')
- # ax = fig.add_subplot(141, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), a[:,2].detach().numpy(), c=batch_x[:,2], marker='o')
- # ax2 = fig.add_subplot(142, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), b[:,2].detach().numpy(), c=batch_x[:,2], marker='o')
- # ax3 = fig.add_subplot(143, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), temp[:,2].cpu().numpy(), c=batch_x[:,2], marker='o')
- # ax4 = fig.add_subplot(144, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2], marker='o')
- # plt.show()
- # fig = plt.figure()
- # fig.add_subplot(121,projection='3d')
- # ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), output[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
- # ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2].cpu().numpy(), marker='o')
- # plt.show()
- # plt.clf()
- # guidance = frozen_empty_score + (eta-1.0)*(frozen_conditioning_score - frozen_unconditioning_score)
- # fig = plt.figure()
- # fig.add_subplot(121,projection='3d')
- # ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), output[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
- # ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2].cpu().numpy(), marker='o')
- # plt.show()
- # plt.clf()
-
- # fig = plt.figure()
- # fig.add_subplot(121,projection='3d')
- # ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), x[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
- # ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), c[:,2].cpu().numpy(), c=batch_conditioning[:,2].cpu().numpy(), marker='o')
- # plt.show()
- # plt.clf()
criteria = torch.nn.MSELoss()
return criteria(output, guidance)
diff --git a/abstract_ddpm/generate.py b/abstract_ddpm/generate.py
index 2c2e120..6af11c0 100644
--- a/abstract_ddpm/generate.py
+++ b/abstract_ddpm/generate.py
@@ -2,7 +2,9 @@ import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
-
+"""
+Contains various methods on generating distributions.
+"""
def gaussian(x, mu, sig):
return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))
diff --git a/conditional_ddpm/ddpm.py b/conditional_ddpm/ddpm.py
index 8a4a4a4..5eb018e 100644
--- a/conditional_ddpm/ddpm.py
+++ b/conditional_ddpm/ddpm.py
@@ -37,7 +37,7 @@ class DDPM(pl.LightningModule):
embedding_dim = 8
num_classes = 2
-
+ self.guidance_scale = 1
# nn.Embedding is used to embed time in a latend dimention that can be passed to the model
self.time_embed = nn.Embedding(n_steps,embedding_dim)
self.label_embed = nn.Embedding(num_classes, embedding_dim)
@@ -81,10 +81,15 @@ class DDPM(pl.LightningModule):
if not isinstance(t, torch.Tensor):
t = torch.LongTensor([t]).expand(x.size(0))
t_embed = self.time_embed(t)
+
+ noise_prediction_uncond = self.model(torch.cat((x, t_embed), dim=1).float())
if label != None:
label = label.to(torch.int)
- t_embed += self.label_embed(label)
- return self.model(torch.cat((x, t_embed), dim=1).float())
+ noise_prediction_text = self.model(torch.cat((x, t_embed+self.label_embed(label)), dim=1).float())
+ else:
+ noise_prediction_text = self.model(torch.cat((x, t_embed), dim=1).float())
+ noise_prediction = noise_prediction_uncond + self.guidance_scale * (noise_prediction_text - noise_prediction_uncond)
+ return noise_prediction
def init_alpha_beta_schedule(self, lbeta, ubeta):
diff --git a/conditional_ddpm/results/ddpm_0.pdf b/conditional_ddpm/results/ddpm_0.pdf
index ab58b03..45d0600 100644
Binary files a/conditional_ddpm/results/ddpm_0.pdf and b/conditional_ddpm/results/ddpm_0.pdf differ
diff --git a/conditional_ddpm/results/ddpm_1.pdf b/conditional_ddpm/results/ddpm_1.pdf
index 54e41b8..843f831 100644
Binary files a/conditional_ddpm/results/ddpm_1.pdf and b/conditional_ddpm/results/ddpm_1.pdf differ
diff --git a/conditional_ddpm/results/ddpm_None.pdf b/conditional_ddpm/results/ddpm_None.pdf
index 6a4264a..4deaa32 100644
Binary files a/conditional_ddpm/results/ddpm_None.pdf and b/conditional_ddpm/results/ddpm_None.pdf differ
diff --git a/conditional_ddpm/results/eddpm_0.pdf b/conditional_ddpm/results/eddpm_0.pdf
index 4b5028f..9593d2d 100644
Binary files a/conditional_ddpm/results/eddpm_0.pdf and b/conditional_ddpm/results/eddpm_0.pdf differ
diff --git a/conditional_ddpm/results/eddpm_1.pdf b/conditional_ddpm/results/eddpm_1.pdf
index aa73ee3..72a6f3e 100644
Binary files a/conditional_ddpm/results/eddpm_1.pdf and b/conditional_ddpm/results/eddpm_1.pdf differ
diff --git a/conditional_ddpm/results/eddpm_None.pdf b/conditional_ddpm/results/eddpm_None.pdf
index fe62651..5a06aef 100644
Binary files a/conditional_ddpm/results/eddpm_None.pdf and b/conditional_ddpm/results/eddpm_None.pdf differ
diff --git a/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt b/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt
index b78fac6..9db9168 100644
Binary files a/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt and b/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt differ
diff --git a/conditional_ddpm/run_ddpm/last.ckpt b/conditional_ddpm/run_ddpm/last.ckpt
index b78fac6..9db9168 100644
Binary files a/conditional_ddpm/run_ddpm/last.ckpt and b/conditional_ddpm/run_ddpm/last.ckpt differ
diff --git a/conditional_ddpm/run_ddpm/lightning_logs/version_0/events.out.tfevents.1683082330.saswat-HP-Pavilion.67046.0 b/conditional_ddpm/run_ddpm/lightning_logs/version_0/events.out.tfevents.1683082330.saswat-HP-Pavilion.67046.0
deleted file mode 100644
index b1ed0d9..0000000
Binary files a/conditional_ddpm/run_ddpm/lightning_logs/version_0/events.out.tfevents.1683082330.saswat-HP-Pavilion.67046.0 and /dev/null differ
diff --git a/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt b/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt
index 1a11410..b9579aa 100644
Binary files a/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt and b/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt differ
diff --git a/conditional_ddpm/run_eddpm/last.ckpt b/conditional_ddpm/run_eddpm/last.ckpt
index 1a11410..b9579aa 100644
Binary files a/conditional_ddpm/run_eddpm/last.ckpt and b/conditional_ddpm/run_eddpm/last.ckpt differ
diff --git a/conditional_ddpm/run_eddpm/lightning_logs/version_0/events.out.tfevents.1683082525.saswat-HP-Pavilion.67404.0 b/conditional_ddpm/run_eddpm/lightning_logs/version_0/events.out.tfevents.1683082525.saswat-HP-Pavilion.67404.0
deleted file mode 100644
index dd8faf7..0000000
Binary files a/conditional_ddpm/run_eddpm/lightning_logs/version_0/events.out.tfevents.1683082525.saswat-HP-Pavilion.67404.0 and /dev/null differ
diff --git a/stable_df_esd/app.py b/stable_df_esd/app.py
index 9402f6b..3509fa9 100644
--- a/stable_df_esd/app.py
+++ b/stable_df_esd/app.py
@@ -6,6 +6,9 @@ from PIL import Image
from inference import inference
import numpy as np
+"""
+Python file to run a demo on the finetuned model.
+"""
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
SD_PATH = 'CompVis/stable-diffusion-v1-4'
diff --git a/stable_df_esd/inference.py b/stable_df_esd/inference.py
index f7ecf81..9b348b6 100644
--- a/stable_df_esd/inference.py
+++ b/stable_df_esd/inference.py
@@ -6,6 +6,9 @@ from PIL import Image
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
def inference(prompt,sd_path, esd_path):
+ """
+ Given a prompt and a path to ESD generate image after loading ESD.
+ """
ddpm = SD()
ddpm = ddpm.to(DEVICE)
ddpm.eval()
diff --git a/stable_df_esd/models.py b/stable_df_esd/models.py
index 8574174..5e575c5 100644
--- a/stable_df_esd/models.py
+++ b/stable_df_esd/models.py
@@ -17,7 +17,16 @@ DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
class SD(torch.nn.Module):
+ """
+ Implementation of Stable diffusion model.
+ """
def __init__(self):
+ """
+ Initialise the SD model with various pretrained models.
+ AutoencoderKL for dimensionality reduction of the image, before passing it to
+ the UNet model for getting the score.
+ CLIPTokenizer and CLIPTextModel are used for tokenization and encoding of the text prompt.
+ """
super().__init__()
self.encoder = AutoencoderKL.from_pretrained(ENCODER_PATH,subfolder=ENCODER_FOLDER)
self.ddpm = UNet2DConditionModel.from_pretrained(DECODER_PATH,subfolder=DECODER_FOLDER)
@@ -28,7 +37,9 @@ class SD(torch.nn.Module):
@torch.no_grad()
def __call__(self,prompts,pixel_size=512,n_steps=100,batch_size=1,last_itr=None):
-
+ """
+ Given a prompt get output from reverse diffusion process.
+ """
if type(prompts) != list:
prompts = [prompts]
@@ -56,7 +67,9 @@ class SD(torch.nn.Module):
@torch.no_grad()
def reverse_diffusion(self,latents,embeddings,last_itr=1000,first_itr=0,original=False):
latents_steps = []
-
+ """
+ Implementation of reverse diffusion process.
+ """
for itr in tqdm(range(first_itr, last_itr)):
noise_pred = self.predict_noise(itr, latents, embeddings)
@@ -75,6 +88,9 @@ class SD(torch.nn.Module):
def encode_text(self,prompts, count):
+ """
+ Encode the text using the text tokenizer and encoder from CLIP model.
+ """
tokens = self.text_tokenize(prompts)
text_encodings = self.text_encode(tokens)
@@ -120,7 +136,9 @@ class SD(torch.nn.Module):
return pil_images
def predict_noise(self,iteration,latents,text_embeddings,guidance_scale=7.5):
-
+ """
+ The function that predicts noise given a latents, text embedding.
+ """
# Doing double forward pass
latents = torch.cat([latents] * 2)
latents = self.scheduler.scale_model_input(latents, self.scheduler.timesteps[iteration])
diff --git a/stable_df_esd/train.py b/stable_df_esd/train.py
index bb72028..4bb1883 100644
--- a/stable_df_esd/train.py
+++ b/stable_df_esd/train.py
@@ -9,6 +9,9 @@ ddpm = ddpm.to(DEVICE)
ddpm.train()
def train(prompt,epochs=100,eta=1.0,path='./saved_models/esd.pt'):
+ """
+ Method that finetunes the ESD model.
+ """
frozen_ddpm = deepcopy(ddpm)
frozen_ddpm.eval()
......@@ -6,6 +6,9 @@ from PIL import Image
from inference import inference
import numpy as np
"""
Python file to run a demo on the finetuned model.
"""
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
SD_PATH = 'CompVis/stable-diffusion-v1-4'
......
......@@ -6,6 +6,9 @@ from PIL import Image
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
def inference(prompt,sd_path, esd_path):
"""
Given a prompt and a path to ESD generate image after loading ESD.
"""
ddpm = SD()
ddpm = ddpm.to(DEVICE)
ddpm.eval()
......
......@@ -17,7 +17,16 @@ DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
class SD(torch.nn.Module):
"""
Implementation of Stable diffusion model.
"""
def __init__(self):
"""
Initialise the SD model with various pretrained models.
AutoencoderKL for dimensionality reduction of the image, before passing it to
the UNet model for getting the score.
CLIPTokenizer and CLIPTextModel are used for tokenization and encoding of the text prompt.
"""
super().__init__()
self.encoder = AutoencoderKL.from_pretrained(ENCODER_PATH,subfolder=ENCODER_FOLDER)
self.ddpm = UNet2DConditionModel.from_pretrained(DECODER_PATH,subfolder=DECODER_FOLDER)
......@@ -28,7 +37,9 @@ class SD(torch.nn.Module):
@torch.no_grad()
def __call__(self,prompts,pixel_size=512,n_steps=100,batch_size=1,last_itr=None):
"""
Given a prompt get output from reverse diffusion process.
"""
if type(prompts) != list:
prompts = [prompts]
......@@ -56,7 +67,9 @@ class SD(torch.nn.Module):
@torch.no_grad()
def reverse_diffusion(self,latents,embeddings,last_itr=1000,first_itr=0,original=False):
latents_steps = []
"""
Implementation of reverse diffusion process.
"""
for itr in tqdm(range(first_itr, last_itr)):
noise_pred = self.predict_noise(itr, latents, embeddings)
......@@ -75,6 +88,9 @@ class SD(torch.nn.Module):
def encode_text(self,prompts, count):
"""
Encode the text using the text tokenizer and encoder from CLIP model.
"""
tokens = self.text_tokenize(prompts)
text_encodings = self.text_encode(tokens)
......@@ -120,7 +136,9 @@ class SD(torch.nn.Module):
return pil_images
def predict_noise(self,iteration,latents,text_embeddings,guidance_scale=7.5):
"""
The function that predicts noise given a latents, text embedding.
"""
# Doing double forward pass
latents = torch.cat([latents] * 2)
latents = self.scheduler.scale_model_input(latents, self.scheduler.timesteps[iteration])
......
......@@ -9,6 +9,9 @@ ddpm = ddpm.to(DEVICE)
ddpm.train()
def train(prompt,epochs=100,eta=1.0,path='./saved_models/esd.pt'):
"""
Method that finetunes the ESD model.
"""
frozen_ddpm = deepcopy(ddpm)
frozen_ddpm.eval()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment