Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
C
CS726-ESD
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Meet Narendra
CS726-ESD
Commits
6b63ef62
Commit
6b63ef62
authored
May 03, 2023
by
Saswat
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove diff file
parent
92edbf29
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
321 deletions
+0
-321
diff
diff
+0
-321
No files found.
diff
deleted
100644 → 0
View file @
92edbf29
diff --git a/abstract_ddpm/convert.py b/abstract_ddpm/convert.py
deleted file mode 100644
index 84b274f..0000000
--- a/abstract_ddpm/convert.py
+++ /dev/null
@@ -1,36 +0,0 @@
-
-import numpy as np
-'''
-sin_train = np.load('./data/3d_sin_5_5_train.npy')
-sin_test = np.load('./data/3d_sin_5_5_test.npy')
-helix_train = np.load('./data/helix_3D_train.npy')
-helix_test = np.load('./data/helix_3D_test.npy')
-sin_bounds = np.load('./data/3d_sin_5_5_bounds.npy')
-helix_bounds = np.load('./data/helix_3D_bounds.npy')
-
-merged_train = np.concatenate((sin_train, helix_train), axis=0)
-merged_test = np.concatenate((sin_test, helix_test), axis=0)
-merged_bounds = np.concatenate((sin_bounds, helix_bounds), axis=0)
-
-np.save('./data/3d_sin_helix_train.npy', merged_train)
-np.save('./data/3d_sin_helix_test.npy', merged_test)
-np.save('./data/3d_sin_helix_bounds.npy', merged_bounds)
-'''
-from generate import *
-
-train_data = get_cylinder_helix('train')
-print(train_data.shape)
-test_data = get_cylinder_helix('test')
-print(test_data.shape)
-
-np.save('./data/3d_cylinder_helix_train.npy',train_data)
-np.save('./data/3d_cylinder_helix_test.npy',test_data)
-
-helix_3d_train = np.load('./data/helix_3D_train.npy')
-helix_3d_test = np.load('./data/helix_3D_test.npy')
-
-helix_3d_train[:,2] += 2
-helix_3d_test[:,2] += 2
-
-np.save('./data/shifted_helix_3D_train.npy',helix_3d_train)
-np.save('./data/shifted_helix_3D_test.npy',helix_3d_test)
\
No newline at end of file
diff --git a/abstract_ddpm/finetune.py b/abstract_ddpm/finetune.py
index 5fea7c8..932feaa 100644
--- a/abstract_ddpm/finetune.py
+++ b/abstract_ddpm/finetune.py
@@ -3,6 +3,9 @@
import torch
import pytorch_lightning as pl
from finetune_model import FineTuneLitDiffusionModel
from dataset import ThreeDSinDataset
+"""
+Python file to perform fine-tuning of the diffusion model.
+"""
parser = argparse.ArgumentParser()
diff --git a/abstract_ddpm/finetune_model.py b/abstract_ddpm/finetune_model.py
index 9c07a0f..39b5e65 100644
--- a/abstract_ddpm/finetune_model.py
+++ b/abstract_ddpm/finetune_model.py
@@ -18,6 +18,9 @@
def expand_alphas(batch, alpha, t):
class FineTuneLitDiffusionModel(pl.LightningModule):
+ """
+ Class for fine-tuning the diffusion model.
+ """
def __init__(self, n_dim=3, n_steps=200, lbeta=1e-5, ubeta=1e-2, scheduler_type="linear",model_chkpt=None,model_hparams=None):
super().__init__()
"""
@@ -162,59 +165,11 @@
class FineTuneLitDiffusionModel(pl.LightningModule):
frozen_unconditioning_score = self.frozen_model(x,t)
frozen_nc_score = self.frozen_model(nc,t)
- #guidance = frozen_unconditioning_score - eta*(frozen_conditioning_score-frozen_unconditioning_score)
- #a = frozen_conditioning_score
- #b = frozen_unconditioning_score
- #guidance =
- #guidance = frozen_unconditioning_score + frozen_inv_conditioning_score
- #temp = frozen_unconditioning_score - frozen_conditioning_score
- #guidance = frozen_unconditioning_score + frozen_conditioning_score
- #temp = ((eta)*(frozen_empty_score - frozen_unconditioning_score) - (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score))/(eta-1.0)
- # temp = frozen_unconditioning_score - temp
- #guidance = frozen_unconditioning_score - frozen_empty_score + frozen_unconditioning_score + (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score)
- #guidance = (eta-1.0)*frozen_conditioning_score - (3.0-eta)*(frozen_unconditioning_score) - frozen_empty_score
- #guidance = frozen_unconditioning_score - (eta)*(frozen_conditioning_score-frozen_unconditioning_score)
- #temp = (eta/2.0)*((frozen_empty_score - frozen_unconditioning_score) - (eta-1.0)*(frozen_conditioning_score-frozen_unconditioning_score))
- #guidance = frozen_unconditioning_score - temp
- #a = frozen_conditioning_score
- #b = frozen_unconditioning_score
temp = frozen_unconditioning_score - frozen_nc_score
guidance = (frozen_unconditioning_score + (eta)*(temp))/(eta)
a = output
b = guidance - output
- # a = frozen_empty_score - frozen_unconditioning_score
- # b = frozen_empty_score - frozen_conditioning_score
- # temp = a-b
- #guidance = frozen_unconditioning_score - temp
- #guidance = frozen_unconditioning_score - b
- # fig = plt.figure()
- # fig.add_subplot(141,projection='3d')
- # ax = fig.add_subplot(141, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), a[:,2].detach().numpy(), c=batch_x[:,2], marker='o')
- # ax2 = fig.add_subplot(142, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), b[:,2].detach().numpy(), c=batch_x[:,2], marker='o')
- # ax3 = fig.add_subplot(143, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), temp[:,2].cpu().numpy(), c=batch_x[:,2], marker='o')
- # ax4 = fig.add_subplot(144, projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2], marker='o')
- # plt.show()
- # fig = plt.figure()
- # fig.add_subplot(121,projection='3d')
- # ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), output[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
- # ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2].cpu().numpy(), marker='o')
- # plt.show()
- # plt.clf()
- # guidance = frozen_empty_score + (eta-1.0)*(frozen_conditioning_score - frozen_unconditioning_score)
- # fig = plt.figure()
- # fig.add_subplot(121,projection='3d')
- # ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), output[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
- # ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), guidance[:,2].cpu().numpy(), c=batch_x[:,2].cpu().numpy(), marker='o')
- # plt.show()
- # plt.clf()
-
- # fig = plt.figure()
- # fig.add_subplot(121,projection='3d')
- # ax = fig.add_subplot(121,projection='3d').scatter(batch_x[:,0].detach().numpy(), batch_x[:,1].detach().numpy(), x[:,2].detach().numpy(), c=batch_x[:,2].detach().numpy(), marker='o')
- # ax2 = fig.add_subplot(122,projection='3d').scatter(batch_x[:,0].cpu().numpy(), batch_x[:,1].cpu().numpy(), c[:,2].cpu().numpy(), c=batch_conditioning[:,2].cpu().numpy(), marker='o')
- # plt.show()
- # plt.clf()
criteria = torch.nn.MSELoss()
return criteria(output, guidance)
diff --git a/abstract_ddpm/generate.py b/abstract_ddpm/generate.py
index 2c2e120..6af11c0 100644
--- a/abstract_ddpm/generate.py
+++ b/abstract_ddpm/generate.py
@@ -2,7 +2,9 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm
-
+"""
+Contains various methods for generating distributions.
+"""
def gaussian(x, mu, sig):
return np.exp(-np.power(x - mu, 2.) / (2 * np.power(sig, 2.)))
diff --git a/conditional_ddpm/ddpm.py b/conditional_ddpm/ddpm.py
index 8a4a4a4..5eb018e 100644
--- a/conditional_ddpm/ddpm.py
+++ b/conditional_ddpm/ddpm.py
@@ -37,7 +37,7 @@
class DDPM(pl.LightningModule):
embedding_dim = 8
num_classes = 2
-
+ self.guidance_scale = 1
# nn.Embedding is used to embed time in a latent dimension that can be passed to the model
self.time_embed = nn.Embedding(n_steps,embedding_dim)
self.label_embed = nn.Embedding(num_classes, embedding_dim)
@@ -81,10 +81,15 @@
class DDPM(pl.LightningModule):
if not isinstance(t, torch.Tensor):
t = torch.LongTensor([t]).expand(x.size(0))
t_embed = self.time_embed(t)
+
+ noise_prediction_uncond = self.model(torch.cat((x, t_embed), dim=1).float())
if label != None:
label = label.to(torch.int)
- t_embed += self.label_embed(label)
- return self.model(torch.cat((x, t_embed), dim=1).float())
+ noise_prediction_text = self.model(torch.cat((x, t_embed+self.label_embed(label)), dim=1).float())
+ else:
+ noise_prediction_text = self.model(torch.cat((x, t_embed), dim=1).float())
+ noise_prediction = noise_prediction_uncond + self.guidance_scale * (noise_prediction_text - noise_prediction_uncond)
+ return noise_prediction
def init_alpha_beta_schedule(self, lbeta, ubeta):
diff --git a/conditional_ddpm/results/ddpm_0.pdf b/conditional_ddpm/results/ddpm_0.pdf
index ab58b03..45d0600 100644
Binary files a/conditional_ddpm/results/ddpm_0.pdf and b/conditional_ddpm/results/ddpm_0.pdf differ
diff --git a/conditional_ddpm/results/ddpm_1.pdf b/conditional_ddpm/results/ddpm_1.pdf
index 54e41b8..843f831 100644
Binary files a/conditional_ddpm/results/ddpm_1.pdf and b/conditional_ddpm/results/ddpm_1.pdf differ
diff --git a/conditional_ddpm/results/ddpm_None.pdf b/conditional_ddpm/results/ddpm_None.pdf
index 6a4264a..4deaa32 100644
Binary files a/conditional_ddpm/results/ddpm_None.pdf and b/conditional_ddpm/results/ddpm_None.pdf differ
diff --git a/conditional_ddpm/results/eddpm_0.pdf b/conditional_ddpm/results/eddpm_0.pdf
index 4b5028f..9593d2d 100644
Binary files a/conditional_ddpm/results/eddpm_0.pdf and b/conditional_ddpm/results/eddpm_0.pdf differ
diff --git a/conditional_ddpm/results/eddpm_1.pdf b/conditional_ddpm/results/eddpm_1.pdf
index aa73ee3..72a6f3e 100644
Binary files a/conditional_ddpm/results/eddpm_1.pdf and b/conditional_ddpm/results/eddpm_1.pdf differ
diff --git a/conditional_ddpm/results/eddpm_None.pdf b/conditional_ddpm/results/eddpm_None.pdf
index fe62651..5a06aef 100644
Binary files a/conditional_ddpm/results/eddpm_None.pdf and b/conditional_ddpm/results/eddpm_None.pdf differ
diff --git a/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt b/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt
index b78fac6..9db9168 100644
Binary files a/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt and b/conditional_ddpm/run_ddpm/epoch=0099-train_loss=0.000.ckpt differ
diff --git a/conditional_ddpm/run_ddpm/last.ckpt b/conditional_ddpm/run_ddpm/last.ckpt
index b78fac6..9db9168 100644
Binary files a/conditional_ddpm/run_ddpm/last.ckpt and b/conditional_ddpm/run_ddpm/last.ckpt differ
diff --git a/conditional_ddpm/run_ddpm/lightning_logs/version_0/events.out.tfevents.1683082330.saswat-HP-Pavilion.67046.0 b/conditional_ddpm/run_ddpm/lightning_logs/version_0/events.out.tfevents.1683082330.saswat-HP-Pavilion.67046.0
deleted file mode 100644
index b1ed0d9..0000000
Binary files a/conditional_ddpm/run_ddpm/lightning_logs/version_0/events.out.tfevents.1683082330.saswat-HP-Pavilion.67046.0 and /dev/null differ
diff --git a/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt b/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt
index 1a11410..b9579aa 100644
Binary files a/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt and b/conditional_ddpm/run_eddpm/epoch=0099-train_loss=0.000.ckpt differ
diff --git a/conditional_ddpm/run_eddpm/last.ckpt b/conditional_ddpm/run_eddpm/last.ckpt
index 1a11410..b9579aa 100644
Binary files a/conditional_ddpm/run_eddpm/last.ckpt and b/conditional_ddpm/run_eddpm/last.ckpt differ
diff --git a/conditional_ddpm/run_eddpm/lightning_logs/version_0/events.out.tfevents.1683082525.saswat-HP-Pavilion.67404.0 b/conditional_ddpm/run_eddpm/lightning_logs/version_0/events.out.tfevents.1683082525.saswat-HP-Pavilion.67404.0
deleted file mode 100644
index dd8faf7..0000000
Binary files a/conditional_ddpm/run_eddpm/lightning_logs/version_0/events.out.tfevents.1683082525.saswat-HP-Pavilion.67404.0 and /dev/null differ
diff --git a/stable_df_esd/app.py b/stable_df_esd/app.py
index 9402f6b..3509fa9 100644
--- a/stable_df_esd/app.py
+++ b/stable_df_esd/app.py
@@ -6,6 +6,9 @@
from PIL import Image
from inference import inference
import numpy as np
+"""
+Python file to run a demo on the finetuned model.
+"""
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
SD_PATH = 'CompVis/stable-diffusion-v1-4'
diff --git a/stable_df_esd/inference.py b/stable_df_esd/inference.py
index f7ecf81..9b348b6 100644
--- a/stable_df_esd/inference.py
+++ b/stable_df_esd/inference.py
@@ -6,6 +6,9 @@
from PIL import Image
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
def inference(prompt,sd_path, esd_path):
+ """
+ Given a prompt and a path to the ESD weights, load the ESD and generate an image.
+ """
ddpm = SD()
ddpm = ddpm.to(DEVICE)
ddpm.eval()
diff --git a/stable_df_esd/models.py b/stable_df_esd/models.py
index 8574174..5e575c5 100644
--- a/stable_df_esd/models.py
+++ b/stable_df_esd/models.py
@@ -17,7 +17,16 @@
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
class SD(torch.nn.Module):
+ """
+ Implementation of the Stable Diffusion model.
+ """
def __init__(self):
+ """
+ Initialise the SD model with various pretrained models.
+ AutoencoderKL for dimensionality reduction of the image, before passing it to
+ the UNet model for getting the score.
+ CLIPTokenizer and CLIPTextModel are used for tokenization and encoding of the text prompt.
+ """
super().__init__()
self.encoder = AutoencoderKL.from_pretrained(ENCODER_PATH,subfolder=ENCODER_FOLDER)
self.ddpm = UNet2DConditionModel.from_pretrained(DECODER_PATH,subfolder=DECODER_FOLDER)
@@ -28,7 +37,9 @@
class SD(torch.nn.Module):
@torch.no_grad()
def __call__(self,prompts,pixel_size=512,n_steps=100,batch_size=1,last_itr=None):
-
+ """
+ Given a prompt, get output from the reverse diffusion process.
+ """
if type(prompts) != list:
prompts = [prompts]
@@ -56,7 +67,9 @@
class SD(torch.nn.Module):
@torch.no_grad()
def reverse_diffusion(self,latents,embeddings,last_itr=1000,first_itr=0,original=False):
latents_steps = []
-
+ """
+ Implementation of reverse diffusion process.
+ """
for itr in tqdm(range(first_itr, last_itr)):
noise_pred = self.predict_noise(itr, latents, embeddings)
@@ -75,6 +88,9 @@
class SD(torch.nn.Module):
def encode_text(self,prompts, count):
+ """
+ Encode the text using the text tokenizer and encoder from CLIP model.
+ """
tokens = self.text_tokenize(prompts)
text_encodings = self.text_encode(tokens)
@@ -120,7 +136,9 @@
class SD(torch.nn.Module):
return pil_images
def predict_noise(self,iteration,latents,text_embeddings,guidance_scale=7.5):
-
+ """
+ The function that predicts noise given latents and a text embedding.
+ """
# Doing double forward pass
latents = torch.cat([latents] * 2)
latents = self.scheduler.scale_model_input(latents, self.scheduler.timesteps[iteration])
diff --git a/stable_df_esd/train.py b/stable_df_esd/train.py
index bb72028..4bb1883 100644
--- a/stable_df_esd/train.py
+++ b/stable_df_esd/train.py
@@ -9,6 +9,9 @@
ddpm = ddpm.to(DEVICE)
ddpm.train()
def train(prompt,epochs=100,eta=1.0,path='./saved_models/esd.pt'):
+ """
+ Method that finetunes the ESD model.
+ """
frozen_ddpm = deepcopy(ddpm)
frozen_ddpm.eval()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment