Commit 63fcce2f authored by Meet Narendra

abstract ddpm

parent 66f9ad5f
python eval.py --ckpt_path runs/n_dim=3,n_steps=1000,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=200/last.ckpt \
--hparams_path runs/n_dim=3,n_steps=1000,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=200/lightning_logs/version_0/hparams.yaml \
--eval_nll --vis_diffusion --vis_overlay
python eval.py --ckpt_path runs/n_dim=3,n_steps=100,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=500/last.ckpt \
--hparams_path runs/n_dim=3,n_steps=100,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=500/lightning_logs/version_0/hparams.yaml \
--eval_nll --vis_diffusion --vis_overlay
python eval.py --ckpt_path ft_runs/n_dim=3,n_steps=100,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=500/last.ckpt \
--hparams_path ft_runs/n_dim=3,n_steps=100,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=500/lightning_logs/version_0/hparams.yaml \
--eval_nll --vis_diffusion --vis_overlay
import numpy as np
'''
sin_train = np.load('./data/3d_sin_5_5_train.npy')
sin_test = np.load('./data/3d_sin_5_5_test.npy')
helix_train = np.load('./data/helix_3D_train.npy')
helix_test = np.load('./data/helix_3D_test.npy')
sin_bounds = np.load('./data/3d_sin_5_5_bounds.npy')
helix_bounds = np.load('./data/helix_3D_bounds.npy')
merged_train = np.concatenate((sin_train, helix_train), axis=0)
merged_test = np.concatenate((sin_test, helix_test), axis=0)
merged_bounds = np.concatenate((sin_bounds, helix_bounds), axis=0)
np.save('./data/3d_sin_helix_train.npy', merged_train)
np.save('./data/3d_sin_helix_test.npy', merged_test)
np.save('./data/3d_sin_helix_bounds.npy', merged_bounds)
'''
from generate import *
train_data = get_cylinder_helix('train')
print(train_data.shape)
test_data = get_cylinder_helix('test')
print(test_data.shape)
np.save('./data/3d_cylinder_helix_train.npy',train_data)
np.save('./data/3d_cylinder_helix_test.npy',test_data)
helix_3d_train = np.load('./data/helix_3D_train.npy')
helix_3d_test = np.load('./data/helix_3D_test.npy')
helix_3d_train[:,2] += 2
helix_3d_test[:,2] += 2
np.save('./data/shifted_helix_3D_train.npy',helix_3d_train)
np.save('./data/shifted_helix_3D_test.npy',helix_3d_test)
import numpy as np
from torch.utils.data import Dataset
class ThreeDSinDataset(Dataset):
def __init__(self, npy_path, mean=None, std=None):
super().__init__()
if isinstance(npy_path, str):
self.data = np.load(npy_path)
else:
self.data1 = np.load(npy_path[0])
self.data2 = np.load(npy_path[1])
self.data = np.concatenate((self.data1, self.data2), axis=1)
print('Data shape',self.data.shape)
if mean is None:
mean = np.mean(self.data, axis=0)
if std is None:
std = np.std(self.data, axis=0)
#self.data = (self.data - mean) / std
def __len__(self):
return self.data.shape[0]
def __getitem__(self, index):
return self.data[index]
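# Usage sketch (not part of the original file), assuming train.py's defaults: the dataset is
# wrapped in a DataLoader, and passing a list of two paths concatenates the arrays along the
# feature axis (axis=1), which is how finetune.py pairs the training data with its conditioning data.
#
# from torch.utils.data import DataLoader
# dataset = ThreeDSinDataset('./data/gaussian_3D_train.npy')
# loader = DataLoader(dataset, batch_size=1024, shuffle=True)
# for batch in loader:
#     print(batch.shape)  # (batch_size, 3)
#     break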
import numpy as np
import matplotlib.pyplot as plt
data = np.load('data/gaussian_3D_train.npy')
data_test = np.load('data/gaussian_3D_test.npy')
print(data.shape)
print(data_test.shape)
#data= np.concatenate((data, data_test), axis=0)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(data[:,0], data[:,1], data[:,2], s=0.1)
plt.show()
import os
import argparse
import torch
import numpy as np
from model import LitDiffusionModel
from finetune_model import FineTuneLitDiffusionModel
from eval_utils import *
from chamferdist import ChamferDistance
parser = argparse.ArgumentParser()
model_args = parser.add_argument_group('model')
model_args.add_argument('--ckpt_path', type=str, help='Path to the model checkpoint', required=True)
model_args.add_argument('--hparams_path', type=str, help='Path to model hyperparameters', required=True)
data_args = parser.add_argument_group('data')
data_args.add_argument('--train_data_path', type=str, default='./data/gaussian_3D_test.npy', help='Path to training data numpy file')
data_args.add_argument('--test_data_path', type=str, default='./data/gaussian_3D_test.npy', help='Path to test data numpy file')
eval_args = parser.add_argument_group('evaluation')
eval_args.add_argument('--savedir', type=str, default='./results/', help='Path to directory for saving evaluation results')
eval_args.add_argument('--n_runs', type=int, default=3, help='Number of runs of evaluation')
eval_args.add_argument('--eval_emd', action='store_true', help='Calculate Earth Mover\'s Distance')
eval_args.add_argument('--eval_emd_samples', type=int, default=128, help='Number of random samples to be sampled for calculating EMD')
eval_args.add_argument('--eval_nll', action='store_true', help='Calculate negative log likelihood')
eval_args.add_argument('--eval_chamfer', action='store_true', help='Calculate Chamfer Distance (using `chamferdist`)')
eval_args.add_argument('--vis_overlay', action='store_true', help='Overlays predicted distribution on top of ground truth')
eval_args.add_argument('--vis_diffusion', action='store_true', help='Shows the evolution of samples through the diffusion process via an animation')
eval_args.add_argument('--vis_track_max', action='store_true', help='Track the point with highest Z in the predicted distribution in diffusion animation')
eval_args.add_argument('--vis_track_min', action='store_true', help='Track the point with lowest Z in the predicted distribution in diffusion animation')
eval_args.add_argument('--vis_smoothed_end', action='store_true', help='Smooths the end of animation by repeating the last frame of animation')
args = parser.parse_args()
# litmodel = LitDiffusionModel.load_from_checkpoint(
# args.ckpt_path,
# hparams_file=args.hparams_path
# )
litmodel = FineTuneLitDiffusionModel.load_from_checkpoint(
args.ckpt_path,
hparams_file=args.hparams_path
)
litmodel.eval()
traindata = np.load(args.train_data_path)
testdata = np.load(args.test_data_path)
# mean = np.mean(traindata, axis=0)
# std = np.std(traindata, axis=0)
# mean[0] = 0
# mean[1] = 0
# std[0] = 1
# std[1] = 1
# print(f'Mean = {mean}')
# print(f'Std dev = {std}')
# traindata = (traindata-mean)/std
# testdata = (testdata-mean)/std
traindata = torch.from_numpy(traindata)
testdata = torch.from_numpy(testdata)
os.makedirs(args.savedir, exist_ok=True)
for i_run in range(args.n_runs):
print(64*'-')
print(f'Evaluation run {i_run+1}/{args.n_runs}')
print(64*'-')
with torch.no_grad():
gendata, intermediate = litmodel.sample(testdata.size(0), progress=True, return_intermediate=True)
#temp_data = np.concatenate((testdata.numpy(), gendata.numpy()), axis=0)
#print(f'gendata.shape = {gendata.shape}')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(gendata[:, 0], gendata[:, 1], gendata[:, 2], c=gendata[:, 2], cmap=cm.spring, alpha=0.1)
#ax.scatter(testdata[:, 0], testdata[:, 1], testdata[:, 2], c=testdata[:, 2], alpha=0.1)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)
#ax.set_zlim(0, 1)
plt.show()
with open(f'{args.savedir}/{i_run:02d}_log.txt', 'w') as f:
f.write('Results\n')
# EMD
if args.eval_emd:
idx = np.random.choice(np.arange(gendata.size(0)), size=args.eval_emd_samples, replace=False)
test_emd = get_emd(testdata[idx].numpy(), gendata[idx].numpy())
train_emd = get_emd(traindata[idx].numpy(), gendata[idx].numpy())
print(f'test_emd: {test_emd}')
print(f'train_emd: {train_emd}')
with open(f'{args.savedir}/{i_run:02d}_log.txt', 'a') as f:
f.write(f'test_emd: {test_emd}\n')
f.write(f'train_emd: {train_emd}\n')
# NLL
if args.eval_nll:
test_nll = get_nll(testdata, gendata).item()
train_nll = get_nll(traindata, gendata).item()
print(f'test_nll: {test_nll}')
print(f'train_nll: {train_nll}')
with open(f'{args.savedir}/{i_run:02d}_log.txt', 'a') as f:
f.write(f'test_nll: {test_nll}\n')
f.write(f'train_nll: {train_nll}\n')
# Chamfer
if args.eval_chamfer:
cd = ChamferDistance()
test_chamfer = cd(
testdata.unsqueeze(0).float(),
gendata.unsqueeze(0).float()
).item()
train_chamfer = cd(
traindata.unsqueeze(0).float(),
gendata.unsqueeze(0).float()
).item()
print(f'test_chamfer: {test_chamfer}')
print(f'train_chamfer: {train_chamfer}')
with open(f'{args.savedir}/{i_run:02d}_log.txt', 'a') as f:
f.write(f'test_chamfer: {test_chamfer}\n')
f.write(f'train_chamfer: {train_chamfer}\n')
# Visualize overlay
if args.vis_overlay:
# Only performed with test since it allows for sparser and better visualizations
print('Visualizing predicted distribution by overlaying it on top of ground truth distribution')
plot_final_distributions(
f'{args.savedir}/{i_run:02d}_overlayvis.pdf',
testdata, gendata
)
print(f'Output: {args.savedir}/{i_run:02d}_overlayvis.pdf')
# Visualize diffusion
if args.vis_diffusion:
print('Visualizing evolution of samples through the diffusion process')
print('WARN: this will take a long time depending on the number of diffusion steps')
fname = f'{i_run:02d}.diffusionvis.track_max={args.vis_track_max}.track_min={args.vis_track_min}.smoothed_end={args.vis_smoothed_end}.gif'
animate_intermediate_distributions(
f'{args.savedir}/{fname}',
testdata, intermediate,
track_max=args.vis_track_max,
track_min=args.vis_track_min,
smoothed_end=args.vis_smoothed_end
)
print(f'Output: {args.savedir}/{fname}')
print(64*'-')
import torch
import numpy as np
from scipy.spatial.distance import cdist
from pyemd import emd
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import cm
def gaussian_kernel(x, x0, temperature=1e-1):
dim = x0.size(1)
x = x.view((1, -1))
exp_term = torch.sum(- 0.5 * (x - x0) ** 2, dim=1)
main_term = torch.exp(exp_term / (2 * temperature))
coeff = 1. / torch.sqrt(torch.Tensor([2 * torch.pi * temperature])) ** dim
prod = coeff * main_term
return torch.sum(prod) / x0.size(0)
def get_likelihood(data, pred, temperature):
lh = torch.zeros(pred.size(0))
dim = pred.size(1)
for i in range(pred.size(0)):
lh[i] = gaussian_kernel(pred[i,:], data, temperature)
return torch.mean(lh)
def get_ll(data, pred, temperature=1e-1):
return torch.log(get_likelihood(data, pred, temperature))
def get_nll(data, pred, temperature=1e-1):
return -get_ll(data, pred, temperature)
def get_nll_bits_per_dim(data, pred, temperature=1e-1):
return get_nll(data, pred, temperature) / (torch.log(torch.Tensor([2])) * data.shape[0])
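# Usage sketch (not part of the original file): get_nll(data, pred) evaluates a Gaussian-kernel
# density fitted on the reference points `data` at each generated point in `pred` (kernel width
# controlled by `temperature`), averages those densities, and returns the negative log of the
# average. eval.py calls it as get_nll(testdata, gendata). A minimal, self-contained example:
#
# ref = torch.randn(1000, 3)        # stand-in for test data
# gen = torch.randn(1000, 3)        # stand-in for generated samples
# print(get_nll(ref, gen).item())   # lower means `gen` lies in higher-density regions of `ref`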
def get_emd(d1, d2):
d_comb = np.concatenate((d1, d2), axis=0)
dist = np.linalg.norm((d_comb), axis=1).reshape((-1,1))
d1 = np.concatenate((np.zeros((d1.shape[0], 1)), d1), axis=1)
d2 = np.concatenate((np.ones((d2.shape[0], 1)), d2), axis=1)
d_comb = np.concatenate((d1, d2), axis=0)
app = np.concatenate((dist, d_comb), axis=1)
app = app[app[:, 0].argsort()]
d1_sig, d2_sig = 1 - app[:, 1], app[:, 1]
dist_sorted = app[:, 2:]
dist = cdist(dist_sorted, dist_sorted)
d1_sig = d1_sig.copy(order='C')
d2_sig = d2_sig.copy(order='C')
dist = dist.copy(order='C')
return emd(d1_sig, d2_sig, dist)
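# Usage sketch (not part of the original file): get_emd places the two equally-sized point sets
# on a shared support (their union, sorted by distance from the origin), builds 0/1 membership
# signatures for each set, and calls pyemd.emd with the pairwise Euclidean ground-distance matrix.
# eval.py subsamples --eval_emd_samples points from each set before calling it, e.g.
#
# a = np.random.randn(128, 3)
# b = np.random.randn(128, 3)
# print(get_emd(a, b))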
def plot_final_distributions(fname, testdata, gendata):
plt.close()
plt.clf()
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(testdata[:, 0], testdata[:, 1], testdata[:, 2], marker='+', c=testdata[:, 2], cmap=cm.spring, alpha=0.5)
ax.scatter(gendata[:, 0], gendata[:, 1], gendata[:, 2], marker='.', c=gendata[:, 2], cmap=cm.cool, alpha=0.3)
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])
ax.set_zlim([-0.5, 1])
fig.savefig(fname, dpi=300, bbox_inches='tight')
def animate_intermediate_distributions(fname, testdata, intermediate, track_max=False, track_min=False, smoothed_end=True):
plt.close()
plt.clf()
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(testdata[:, 0], testdata[:, 1], testdata[:, 2], marker='+', c=testdata[:, 2], cmap=cm.spring, alpha=0.5)
diffused = ax.scatter(intermediate[0][:, 0], intermediate[0][:, 1], intermediate[0][:, 2], marker='.', c=intermediate[0][:, 2], cmap=cm.cool, alpha=0.1)
if track_max:
max_x, max_y, max_z = [], [], []
max_idx = torch.argmax(intermediate[-1][:, 2])
max_trace, = ax.plot(max_x, max_y, max_z, color='blue')
if track_min:
min_x, min_y, min_z = [], [], []
min_idx = torch.argmin(intermediate[-1][:, 2])
min_trace, = ax.plot(min_x, min_y, min_z, color='red')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])
ax.set_zlim([-0.5, 1])
fig.set_size_inches(5, 5)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
if smoothed_end:
# Repeat the last frame a few times for a smoother end to the animation
for _ in range(len(intermediate) // 5):
intermediate.append(intermediate[-1])
def animate_diffused(i):
# the trace lists live in the enclosing function's scope, not at module level
nonlocal max_x, max_y, max_z, min_x, min_y, min_z
# https://stackoverflow.com/a/41609238
diffused._offsets3d = (intermediate[i][:, 0].detach().cpu().numpy(), intermediate[i][:, 1].detach().cpu().numpy(), intermediate[i][:, 2].detach().cpu().numpy())
diffused._c = intermediate[i][:, 2].detach().cpu().numpy()
if i == 0:
if track_max:
max_x, max_y, max_z = [], [], []
if track_min:
min_x, min_y, min_z = [], [], []
if track_max:
max_x.append(intermediate[i][max_idx, 0])
max_y.append(intermediate[i][max_idx, 1])
max_z.append(intermediate[i][max_idx, 2])
max_trace.set_data(max_x, max_y)
max_trace.set_3d_properties(max_z)
if track_min:
min_x.append(intermediate[i][min_idx, 0])
min_y.append(intermediate[i][min_idx, 1])
min_z.append(intermediate[i][min_idx, 2])
min_trace.set_data(min_x, min_y)
min_trace.set_3d_properties(min_z)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=None, hspace=None)
fig.tight_layout()
ret = (diffused,)
if track_max:
ret = ret + (max_trace,)
if track_min:
ret = ret + (min_trace,)
return ret
anim = animation.FuncAnimation(fig, animate_diffused, repeat=True, frames=len(intermediate)-1, interval=50)
writer = animation.PillowWriter(fps=60,
metadata=dict(artist='CS726-2023 diffusion model HW2'),
bitrate=1800)
def print_anim_progress(i, n):
msg = 'Starting GIF creation' if i == n else f'Rendering frame {i}/{n}'
print(msg, end='\r', flush=True)
anim.save(fname, writer=writer, dpi=100, progress_callback=print_anim_progress)
print(f'\rAnimation written to "{fname}"\n')
import argparse
import torch
import pytorch_lightning as pl
from finetune_model import FineTuneLitDiffusionModel
from dataset import ThreeDSinDataset
parser = argparse.ArgumentParser()
model_args = parser.add_argument_group('model')
model_args.add_argument('--n_dim', type=int, default=3, help='Number of dimensions')
model_args.add_argument('--n_steps', type=int, default=100, help='Number of diffusion steps')
model_args.add_argument('--lbeta', type=float, default=1e-5, help='Lower bound of beta')
model_args.add_argument('--ubeta', type=float, default=1.28e-2, help='Upper bound of beta')
model_args.add_argument('--scheduler_type', type=str, default="linear", help='Variance Scheduling Function')
training_args = parser.add_argument_group('training')
training_args.add_argument('--seed', type=int, default=1618, help='Random seed for experiments')
training_args.add_argument('--n_epochs', type=int, default=500, help='Number of training epochs')
training_args.add_argument('--batch_size', type=int, default=1024, help='Batch size for training dataloader')
training_args.add_argument('--train_data_path', type=str, default='./data/gaussian_3D_train.npy', help='Path to training data numpy file')
training_args.add_argument('--ft_data_path', type=str, default='./data/gaussian_3D_ft.npy', help='Path to conditioning data numpy file')
training_args.add_argument('--savedir', type=str, default='./ft_runs/', help='Root directory where all checkpoint and logs will be saved')
args = parser.parse_args()
n_dim = args.n_dim
n_steps = args.n_steps
lbeta = args.lbeta
ubeta = args.ubeta
scheduler_type=args.scheduler_type
pl.seed_everything(args.seed)
batch_size = args.batch_size
n_epochs = args.n_epochs
savedir = args.savedir
litmodel = FineTuneLitDiffusionModel(
n_dim=n_dim,
n_steps=n_steps,
lbeta=lbeta,
ubeta=ubeta,
scheduler_type=scheduler_type,
model_chkpt='saved_models/last.ckpt',
model_hparams='saved_models/hparams.yaml'
)
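# Freeze every parameter belonging to the wrapped `frozen_model` (the pretrained diffusion model
# loaded from saved_models/) so that only the remaining fine-tuning parameters receive gradient
# updates; the loop below also prints which parameters stay trainable.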
for name, param in litmodel.named_parameters():
if 'frozen_model' in name:
param.requires_grad = False
print(name, param.requires_grad)
train_dataset = ThreeDSinDataset([args.train_data_path,args.ft_data_path])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
run_name = f'n_dim={n_dim},n_steps={n_steps},lbeta={lbeta:.3e},ubeta={ubeta:.3e},scheduler_type={scheduler_type},batch_size={batch_size},n_epochs={n_epochs}'
trainer = pl.Trainer(
deterministic=True,
logger=pl.loggers.TensorBoardLogger(f'{savedir}/{run_name}/'),
max_epochs=n_epochs,
log_every_n_steps=1,
callbacks=[
# A dummy model checkpoint callback that stores the latest model at the end of every epoch
pl.callbacks.ModelCheckpoint(
dirpath=f'{savedir}/{run_name}/',
filename='{epoch:04d}-{train_loss:.3f}',
save_top_k=1,
monitor='epoch',
mode='max',
save_last=True,
every_n_epochs=1,
),
]
)
trainer.fit(model=litmodel, train_dataloaders=train_dataloader)
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
def gaussian(x, mu, sig):
return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))
def gaussian_3d(x, y, mu, sig):
return np.exp(-np.power(x - mu[0], 2.) / (2 * np.power(sig[0], 2.))) * np.exp(-np.power(y - mu[1], 2.) / (2 * np.power(sig[1], 2.)))
def generate_3d_gaussian(n_samples=100):
x = np.linspace(-3, 3, n_samples)
y = np.linspace(-3, 3, n_samples)
X, Y = np.meshgrid(x, y)
mu = [-1, -1, 0]
sig = [0.5, 0.5, 0.5]
z = gaussian_3d(X, Y, mu, sig)
mu = [1, 1, 0]
new_z = gaussian_3d(X, Y, mu, sig)
z = z + new_z
data = np.concatenate((X.reshape(-1,1), Y.reshape(-1,1), z.reshape(-1,1)), axis=1)
#print(data.shape)
return data
def generate_3d_gaussian_ft(n_samples=100):
x = np.linspace(-3, 3, n_samples)
y = np.linspace(-3, 3, n_samples)
X, Y = np.meshgrid(x, y)
mu = [-1, -1, 0]
sig = [0.5, 0.5, 0.5]
z = gaussian_3d(X, Y, mu, sig)
data = np.concatenate((X.reshape(-1,1), Y.reshape(-1,1), z.reshape(-1,1)), axis=1)
#print(data.shape)
return data
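# Note (added for clarity): generate_3d_gaussian builds a two-bump surface (Gaussians centred at
# (-1,-1) and (1,1)) on a 100x100 grid over [-3, 3]^2 and is saved below as the training set;
# generate_3d_gaussian_ft keeps only the (-1,-1) bump and is saved as the fine-tuning set; the
# test set saved at the end is their pointwise difference in z, i.e. the remaining (1,1) bump.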
data = generate_3d_gaussian()
print(data.shape)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=data[:, 2], cmap=cm.spring, alpha=0.1)
#ax.set_zlim(0, 3)
plt.show()
np.save('./data/gaussian_3D_train.npy', data)
new_data = generate_3d_gaussian_ft()
print(new_data.shape)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(new_data[:, 0], new_data[:, 1], new_data[:, 2], c=new_data[:, 2], cmap=cm.spring, alpha=0.1)
#ax.set_zlim(0, 3)
plt.show()
np.save('./data/gaussian_3D_ft.npy', new_data)
new_z = data
new_z[:, 2] -= new_data[:, 2]
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(new_z[:, 0], new_z[:, 1], new_z[:, 2], c=new_z[:, 2], cmap=cm.spring, alpha=0.1)
#ax.set_zlim(0, 3)
plt.show()
np.save('./data/gaussian_3D_test.npy', new_z)
'''
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import pandas as pd
from scipy.stats import multivariate_normal
import torch
#3d cylinder data generate
def cylinder(r=0.25, h=1, n=10000):
#function which generates 3d points on a cylinder
n_samples = n
points = np.zeros((n_samples, 3))
theta = np.random.uniform(0, 2*np.pi, n_samples)
z = np.random.uniform(0, h, n_samples)
points[:, 0] = r * np.cos(theta)
points[:, 1] = r * np.sin(theta)
points[:, 2] = z
return points
def get_cylinder_helix(type_data='train'):
helix_3d = np.load(f'./data/helix_3D_{type_data}.npy')
#shift z axis of helix_3d
helix_3d[:, 2] += 2
n_samples = helix_3d.shape[0]
#print(n_samples)
cylinder_data = cylinder(0.05, 1,n=n_samples)
concat = np.concatenate((cylinder_data, helix_3d), axis=0)
np.random.shuffle(concat)
df = pd.DataFrame(concat, columns=['x','y','z'])
df = df.drop_duplicates(['x','y'])
df = df.drop_duplicates(['x','z'])
df = df.drop_duplicates(['y','z'])
concat = df.values
mean = np.mean(concat, axis=0)
std = np.std(concat, axis=0)
concat = (concat - mean) / std
#print('type',type(concat),'concat',concat.shape)
return concat
# data = multivariate_gaussian()
# print(data.shape)
data = get_cylinder_helix('test')
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=data[:, 2], cmap=cm.spring, alpha=0.1)
#ax.set_zlim(0, 3)
plt.show()
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import numpy as np
def expand_alphas(batch, alpha, t):
"""
If `t` is an int, broadcast the scalar alpha[t] to shape (batch_size, 1);
otherwise `t` holds one index per sample, so gather alpha[t] and reshape it to (batch_size, 1)
"""
if isinstance(t, int):
return alpha[t].expand(batch.size(0),1)
else:
return torch.zeros((batch.size(0),1)) + alpha[t][:,None]
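# Usage sketch (not part of the original file): expand_alphas broadcasts one schedule entry per
# batch row, so the same helper serves both the scalar-t case (sampling) and the per-sample-t
# case (training), e.g.
#
# x = torch.randn(4, 3)
# betas = torch.linspace(1e-5, 1.28e-2, 100)
# expand_alphas(x, betas, 10)                          # shape (4, 1), every row equals betas[10]
# expand_alphas(x, betas, torch.tensor([0, 5, 7, 9]))  # shape (4, 1), one entry per sample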
class LitDiffusionModel(pl.LightningModule):
def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear"):
super().__init__()
"""
If you include more hyperparams (e.g. `n_layers`), be sure to add that to `argparse` from `train.py`.
Also, manually make sure that this new hyperparameter is being saved in `hparams.yaml`.
"""
self.save_hyperparameters()
"""
Your model implementation starts here. We have separate learnable modules for `time_embed` and `model`.
You may choose a different architecture altogether. Feel free to explore what works best for you.
If your architecture is just a sequence of `torch.nn.XXX` layers, using `torch.nn.Sequential` will be easier.
`time_embed` can be learned or a fixed function based on the insights you get from visualizing the data.
If your `model` is different for different datasets, you can use a hyperparameter to switch between them.
Make sure that your hyperparameter behaves as expected and is being saved correctly in `hparams.yaml`.
"""
embedding_dim = 8
# nn.Embedding is used to embed the timestep in a latent dimension that can be passed to the model
self.time_embed = nn.Embedding(n_steps,embedding_dim)
# The model is a sequential model consisting of linear layers with ReLU as activation function
self.model = nn.Sequential(
nn.Linear(embedding_dim + 3, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,3)
)
"""
Be sure to save at least these 2 parameters in the model instance.
"""
self.n_steps = n_steps
self.n_dim = n_dim
# New hyperparameter
# Scheduler type (linear/sigmoid/cosine) is stored
self.scheduler_type=scheduler_type
"""
Sets up variables for noise schedule
"""
self.init_alpha_beta_schedule(lbeta, ubeta)
def forward(self, x, t):
"""
Similar to `forward` function in `nn.Module`.
Notice here that `x` and `t` are passed separately. If you are using an architecture that combines
`x` and `t` in a different way, modify this function appropriately.
"""
if not isinstance(t, torch.Tensor):
t = torch.LongTensor([t]).expand(x.size(0))
t_embed = self.time_embed(t)
return self.model(torch.cat((x, t_embed), dim=1).float())
def init_alpha_beta_schedule(self, lbeta, ubeta):
"""
Set up your noise schedule. You can perhaps have an additional hyperparameter that allows you to
switch between various schedules for answering q4 in depth. Make sure that this hyperparameter
is included correctly while saving and loading your checkpoints.
"""
if self.scheduler_type=="linear":
self.beta = torch.linspace(lbeta,ubeta,self.n_steps)
elif self.scheduler_type == "sigmoid":
self.beta = torch.linspace(-6, 6, self.n_steps)
self.beta = torch.sigmoid(self.beta) * (ubeta - lbeta) + lbeta
elif self.scheduler_type == "cosine":
self.beta = torch.linspace(0, 1, self.n_steps)
self.beta = (torch.cos((self.beta)*np.pi/2)+0.002)/1.002 * (ubeta - lbeta) + lbeta
# Precompute the schedule quantities used by the forward and reverse processes
self.alpha = 1 - self.beta
self.alpha_cum = torch.cumprod(self.alpha, 0)                 # \bar{alpha}_t
self.alpha_cum_sqrt = torch.sqrt(self.alpha_cum)              # sqrt(\bar{alpha}_t)
self.one_min_alphas_sum_sqrt = torch.sqrt(1 - self.alpha_cum) # sqrt(1 - \bar{alpha}_t)
def q_sample(self, x, t):
"""
Sample x_t from the forward process q(x_t | x_0) given x_0.
"""
alpha = expand_alphas(x, self.alpha_cum_sqrt, t)
one_minus_alpha = expand_alphas(x, self.one_min_alphas_sum_sqrt, t)
_q_sample = (alpha * x) + one_minus_alpha * torch.randn_like(x)
return _q_sample
def p_sample(self, x, t):
"""
Sample x_{t-1} from the reverse process p(x_{t-1} | x_t).
"""
epsilon_factor = (expand_alphas(x, self.beta, t) / expand_alphas(x, self.one_min_alphas_sum_sqrt, t))
epsilon_theta = self.forward(x, t)
mean = (1 / expand_alphas(x, torch.sqrt(self.alpha), t)) * (x - (epsilon_factor * epsilon_theta))
sigma = expand_alphas(x, torch.sqrt(self.beta), t)
sample = mean + sigma * torch.randn_like(x)
return sample
def training_step(self, batch, batch_idx):
"""
Implements one training step.
Given a batch of samples (n_samples, n_dim) from the distribution you must calculate the loss
for this batch. Simply return this loss from this function so that PyTorch Lightning will
automatically do the backprop for you.
Refer to the DDPM paper [1] for more details about equations that you need to implement for
calculating loss. Make sure that all the operations preserve gradients for proper backprop.
Refer to PyTorch Lightning documentation [2,3] for more details about how the automatic backprop
will update the parameters based on the loss you return from this function.
References:
[1]: https://arxiv.org/abs/2006.11239
[2]: https://pytorch-lightning.readthedocs.io/en/stable/
[3]: https://www.pytorchlightning.ai/tutorials
"""
t = torch.randint(0, self.n_steps, size=(batch.shape[0],))
#print(t.shape)
alpha_sqrt = expand_alphas(batch, self.alpha_cum_sqrt, t)
#print(alpha_sqrt.shape)
one_min_alpha_sqrt = expand_alphas(batch, self.one_min_alphas_sum_sqrt, t)
#print(one_min_alpha_sqrt.shape)
noise = torch.randn_like(batch)
#print(noise.shape)
x = batch * alpha_sqrt + noise * one_min_alpha_sqrt
output = self.forward(x, t)
return (noise - output).square().mean()
def sample(self, n_samples, progress=False, return_intermediate=False):
"""
Implements inference step for the DDPM.
`progress` is an optional flag to implement -- it should just show the current step in diffusion
reverse process.
If `return_intermediate` is `False`,
the function returns `n_samples` points sampled from the learned DDPM,
i.e. a Tensor of size (n_samples, n_dim).
Return: (n_samples, n_dim)(final result from diffusion)
Else
the function returns all the intermediate steps in the diffusion process as well
i.e. a Tensor of size (n_samples, n_dim) and a list of `self.n_steps` Tensors of size (n_samples, n_dim) each.
Return: (n_samples, n_dim)(final result), [(n_samples, n_dim)(intermediate) x n_steps]
"""
if return_intermediate:
out_samples = []
out_samples.append(torch.randn((n_samples, self.n_dim)))
for t in range(self.n_steps-1, -1, -1):
out_samples.append(self.p_sample(out_samples[-1], t))
return out_samples[-1], out_samples
else:
out_sample = torch.randn((n_samples, self.n_dim))
for t in range(self.n_steps-1, -1, -1):
out_sample = self.p_sample(out_sample, t)
return out_sample
def configure_optimizers(self):
"""
Sets up the optimizer to be used for backprop.
Must return a `torch.optim.XXX` instance.
You may choose to add certain hyperparameters of the optimizers to the `train.py` as well.
In our experiments, we chose one good value of optimizer hyperparameters for all experiments.
"""
# We have experimented with both SGD and Adam optimiser
#return torch.optim.SGD(self.model.parameters(), lr=0.01)
return torch.optim.Adam(self.model.parameters(), lr=1e-3)
lbeta: 1.0e-05
n_dim: 3
n_steps: 100
scheduler_type: sigmoid
ubeta: 0.0128
lbeta: 1.0e-05
n_dim: 3
n_steps: 100
scheduler_type: linear
ubeta: 0.0128
import torch
import matplotlib.pyplot as plt
# Define the lower and upper bounds of beta for the noise schedule
start_lr = 1e-5
end_lr = 1.28e-2
# Define the number of steps in the schedule
num_steps = 1000
# Define the different schedules
schedules = {
"linear": lambda step: start_lr + (end_lr - start_lr) * step / num_steps,
"cosine": lambda step: end_lr + (start_lr - end_lr) / 2 * (1 + torch.cos(torch.tensor(step / num_steps * 3.1415))),
"sigmoid": lambda step: (start_lr - end_lr) / (1 + torch.exp(torch.tensor((step - num_steps / 2) / (num_steps / 10)))) + end_lr,
"squared": lambda step: start_lr + (end_lr - start_lr) * (step / num_steps) ** 2,
"cube": lambda step: start_lr + (end_lr - start_lr) * (step / num_steps) ** 3,
"exponential": lambda step: start_lr * (end_lr / start_lr) ** (step / num_steps)
}
# Plot the schedules
for name, func in schedules.items():
lr_values = [func(step) for step in range(num_steps)]
plt.plot(lr_values, label=name)
# Set the axis labels and a legend
plt.xlabel("Step")
plt.ylabel("Beta")
plt.yscale("log")
plt.legend()
# Display the plot
plt.show()
from model import LitDiffusionModel
import matplotlib.pyplot as plt
import torch
import numpy as np
from dataset import ThreeDSinDataset
import pytorch_lightning as pl
litmodel = LitDiffusionModel()
litmodel = LitDiffusionModel.load_from_checkpoint(
"./runs/n_dim=3,n_steps=50,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=500/last.ckpt",
hparams_file="./runs/n_dim=3,n_steps=50,lbeta=1.000e-05,ubeta=1.280e-02,scheduler_type=linear,batch_size=1024,n_epochs=500/lightning_logs/version_0/hparams.yaml"
)
print(litmodel.beta)
print(litmodel.alpha_cum)
q_samp = []
x_sam = torch.randn((7781, 3))
q_samp.append(x_sam)
for t in range(49, -1, -1):
print(t)
# p_sample returns only the next sample; the intermediate quantities
# (eps_factor, eps_theta, mean, sigma_t) are internal to the model
x_sam = litmodel.p_sample(x_sam, t)
print(x_sam[0])
q_samp.append(x_sam)
import argparse
import torch
import pytorch_lightning as pl
from model import LitDiffusionModel
from dataset import ThreeDSinDataset
parser = argparse.ArgumentParser()
model_args = parser.add_argument_group('model')
model_args.add_argument('--n_dim', type=int, default=3, help='Number of dimensions')
model_args.add_argument('--n_steps', type=int, default=100, help='Number of diffusion steps')
model_args.add_argument('--lbeta', type=float, default=1e-5, help='Lower bound of beta')
model_args.add_argument('--ubeta', type=float, default=1.28e-2, help='Upper bound of beta')
model_args.add_argument('--scheduler_type', type=str, default="linear", help='Variance Scheduling Function')
training_args = parser.add_argument_group('training')
training_args.add_argument('--seed', type=int, default=1618, help='Random seed for experiments')
training_args.add_argument('--n_epochs', type=int, default=500, help='Number of training epochs')
training_args.add_argument('--batch_size', type=int, default=1024, help='Batch size for training dataloader')
training_args.add_argument('--train_data_path', type=str, default='./data/gaussian_3D_train.npy', help='Path to training data numpy file')
training_args.add_argument('--savedir', type=str, default='./runs/', help='Root directory where all checkpoint and logs will be saved')
args = parser.parse_args()
n_dim = args.n_dim
n_steps = args.n_steps
lbeta = args.lbeta
ubeta = args.ubeta
scheduler_type=args.scheduler_type
pl.seed_everything(args.seed)
batch_size = args.batch_size
n_epochs = args.n_epochs
savedir = args.savedir
litmodel = LitDiffusionModel(
n_dim=n_dim,
n_steps=n_steps,
lbeta=lbeta,
ubeta=ubeta,
scheduler_type=scheduler_type
)
train_dataset = ThreeDSinDataset(args.train_data_path)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
run_name = f'n_dim={n_dim},n_steps={n_steps},lbeta={lbeta:.3e},ubeta={ubeta:.3e},scheduler_type={scheduler_type},batch_size={batch_size},n_epochs={n_epochs}'
trainer = pl.Trainer(
deterministic=True,
logger=pl.loggers.TensorBoardLogger(f'{savedir}/{run_name}/'),
max_epochs=n_epochs,
log_every_n_steps=1,
callbacks=[
# A dummy model checkpoint callback that stores the latest model at the end of every epoch
pl.callbacks.ModelCheckpoint(
dirpath=f'{savedir}/{run_name}/',
filename='{epoch:04d}-{train_loss:.3f}',
save_top_k=1,
monitor='epoch',
mode='max',
save_last=True,
every_n_epochs=1,
),
]
)
trainer.fit(model=litmodel, train_dataloaders=train_dataloader)