Commit 8d0daf49 authored by Meet Narendra

stable diffusion esd

parent 63fcce2f
# ---- Gradio demo app ----
# Note: this uses the legacy Gradio 3.x `gr.inputs` / `gr.outputs` API.
import gradio as gr
import torch
from models import SD
from configs import *
from PIL import Image
from inference import inference
import numpy as np

DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
SD_PATH = 'CompVis/stable-diffusion-v1-4'
ESD_PATH = './saved_models/esd.pt'  # ESD checkpoint trained to erase "ocean" (see train.py)

def generate_image(prompt):
    # Run the same prompt through the original SD model and the concept-erased (ESD) model.
    orig_images, ft_images = inference(prompt, sd_path=SD_PATH, esd_path=ESD_PATH)
    orig_images = orig_images.resize((200, 200))
    ft_images = ft_images.resize((200, 200))
    return orig_images, ft_images

inputs = gr.inputs.Textbox(lines=2, placeholder="Type here to generate an image...")
outputs = [
    gr.outputs.Image(type="pil", label="Original Image").style(height="200px", width="200px"),
    gr.outputs.Image(type="pil", label="Edited Image").style(height="200px", width="200px")
]
title = "Erasing Concepts from DDPMs"
description = "CS726 Project by Meet, Saswat, Osim."
gr.Interface(generate_image, inputs, outputs, title=title, description=description, examples=[["beachside sunset."]]).launch()

# ---- configs.py ----
ENCODER_PATH = 'CompVis/stable-diffusion-v1-4'
ENCODER_FOLDER = 'vae'
DECODER_PATH = 'CompVis/stable-diffusion-v1-4'
DECODER_FOLDER = 'unet'
TEXT_TOKENIZER_PATH = 'openai/clip-vit-large-patch14'
TEXT_ENCODER_PATH = 'openai/clip-vit-large-patch14'
DDIM_SCHEDULER_PATH = 'CompVis/stable-diffusion-v1-4'
DDIM_SCHEDULER_FOLDER = 'scheduler'
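# Note: "encoder" here refers to the VAE and "decoder" to the UNet; each (PATH, FOLDER)
# pair is loaded in models.py via from_pretrained(PATH, subfolder=FOLDER). All diffusion
# components come from the SD v1.4 checkpoint; the tokenizer/text encoder use CLIP ViT-L/14.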

# ---- inference.py ----
from models import SD
from configs import *
import torch
from PIL import Image

DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")

def inference(prompt, sd_path, esd_path):
    # sd_path is currently unused; SD() always loads the paths defined in configs.py.
    # Generate with the original Stable Diffusion weights.
    ddpm = SD()
    ddpm = ddpm.to(DEVICE)
    ddpm.eval()
    gen_images_sd = ddpm(prompt, n_steps=50)
    orig_images = gen_images_sd[0][0]
    del ddpm
    torch.cuda.empty_cache()
    # Generate with the fine-tuned (concept-erased) weights.
    esd = SD()
    esd.load_state_dict(torch.load(esd_path))
    esd = esd.to(DEVICE)
    esd.eval()
    gen_images_esd = esd(prompt, n_steps=50)
    ft_images = gen_images_esd[0][0]
    del esd
    torch.cuda.empty_cache()
    return orig_images, ft_images

if __name__ == '__main__':
    SD_PATH = 'CompVis/stable-diffusion-v1-4'
    ESD_PATH = './saved_models/esd.pt'  # ESD checkpoint trained to erase "ocean"
    orig_images, ft_images = inference('beachside sunset', SD_PATH, ESD_PATH)
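    # Quick visual check (output paths are illustrative, not part of the original script):
    # both return values are PIL images, so they can be saved for a side-by-side comparison.
    orig_images.save('./images/sd_beachside_sunset.png')
    ft_images.save('./images/esd_beachside_sunset.png')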

# ---- models.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from configs import *
from diffusers import AutoencoderKL
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPFeatureExtractor
from diffusers import UNet2DConditionModel
from tqdm import tqdm
from PIL import Image
from utils import *
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
class SD(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Stable Diffusion components: VAE, conditional UNet, CLIP text stack, DDIM scheduler.
        self.encoder = AutoencoderKL.from_pretrained(ENCODER_PATH, subfolder=ENCODER_FOLDER)
        self.ddpm = UNet2DConditionModel.from_pretrained(DECODER_PATH, subfolder=DECODER_FOLDER)
        self.tokenizer = CLIPTokenizer.from_pretrained(TEXT_TOKENIZER_PATH)
        self.text_encoder = CLIPTextModel.from_pretrained(TEXT_ENCODER_PATH)
        self.scheduler = DDIMScheduler.from_pretrained(DDIM_SCHEDULER_PATH, subfolder=DDIM_SCHEDULER_FOLDER)
        self.eval()

    @torch.no_grad()
    def __call__(self, prompts, pixel_size=512, n_steps=100, batch_size=1, last_itr=None):
        if not isinstance(prompts, list):
            prompts = [prompts]
        self.scheduler.set_timesteps(n_steps, DEVICE)
        # Latents are 8x smaller than the output image (VAE downsampling factor).
        noise = torch.randn(batch_size, self.ddpm.in_channels, pixel_size // 8, pixel_size // 8, device=DEVICE).repeat(len(prompts), 1, 1, 1)
        latent = self.scheduler.init_noise_sigma * noise
        text_encodings = self.encode_text(prompts=prompts, count=batch_size)
        last_itr = last_itr if last_itr is not None else n_steps
        latent_steps = self.reverse_diffusion(latent, text_encodings, last_itr=last_itr)
        latent_steps = [self.decode(latent.to(DEVICE)) for latent in latent_steps]
        image_steps = [self.to_image(image) for image in latent_steps]
        image_steps = list(zip(*image_steps))
        return image_steps
    @torch.no_grad()
    def reverse_diffusion(self, latents, embeddings, last_itr=1000, first_itr=0, original=False):
        latents_steps = []
        for itr in tqdm(range(first_itr, last_itr)):
            noise_pred = self.predict_noise(itr, latents, embeddings)
            # Calculate x_{t-1} from the predicted noise.
            output = self.scheduler.step(noise_pred, self.scheduler.timesteps[itr], latents)
            latents = output.prev_sample
            if itr == last_itr - 1:
                output = output.pred_original_sample if original else latents
                latents_steps.append(output)
        return latents_steps
    def encode_text(self, prompts, count):
        tokens = self.text_tokenize(prompts)
        text_encodings = self.text_encode(tokens)
        # Unconditional (near-empty prompt) embeddings for classifier-free guidance,
        # concatenated as [uncond, cond] to match the chunk(2) split in predict_noise.
        tokens_uncondition = self.text_tokenize([" "] * len(prompts))
        text_encodings_uncondition = self.text_encode(tokens_uncondition)
        embeddings = torch.cat([text_encodings_uncondition, text_encodings])
        embeddings = embeddings.repeat_interleave(count, 0)
        return embeddings

    def add_noise(self, latents, noise, step):
        return self.scheduler.add_noise(latents, noise, torch.tensor([self.scheduler.timesteps[step]]))

    def text_tokenize(self, prompts):
        return self.tokenizer(prompts, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")

    def text_detokenize(self, tokens):
        return [self.tokenizer.decode(token) for token in tokens if token != self.tokenizer.vocab_size - 1]

    def text_encode(self, tokens):
        return self.text_encoder(tokens.input_ids.to(self.ddpm.device))[0]

    def decode(self, latents):
        return self.encoder.decode(1 / self.encoder.config.scaling_factor * latents).sample

    def encode(self, tensors):
        # 0.18215 is the standard SD v1 latent scaling factor.
        return self.encoder.encode(tensors).latent_dist.mode() * 0.18215

    def to_image(self, image):
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images
    def predict_noise(self, iteration, latents, text_embeddings, guidance_scale=7.5):
        # Double forward pass: one unconditional branch, one text-conditioned branch.
        latents = torch.cat([latents] * 2)
        latents = self.scheduler.scale_model_input(latents, self.scheduler.timesteps[iteration])
        # Noise prediction from the UNet.
        noise_prediction = self.ddpm(latents, self.scheduler.timesteps[iteration], encoder_hidden_states=text_embeddings).sample
        # Classifier-free guidance: eps = eps_uncond + s * (eps_text - eps_uncond)
        noise_prediction_uncond, noise_prediction_text = noise_prediction.chunk(2)
        noise_prediction = noise_prediction_uncond + guidance_scale * (noise_prediction_text - noise_prediction_uncond)
        return noise_prediction
if __name__ == "__main__":
    model = SD().to(DEVICE).eval()
    generated_images = model(prompts="House", n_steps=20, batch_size=1)
    image_grid(generated_images, outpath='./images/out')

# ---- train.py ----
from models import SD
from configs import *
import torch
from copy import deepcopy
DEVICE = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
ddpm = SD()
ddpm = ddpm.to(DEVICE)
ddpm.train()
def train(prompt, epochs=100, eta=1.0, path='./saved_models/esd.pt'):
    # The frozen copy supplies the original model's noise predictions; `ddpm` is fine-tuned.
    frozen_ddpm = deepcopy(ddpm)
    frozen_ddpm.eval()
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=1e-5)
    criterion = torch.nn.MSELoss()
    with torch.no_grad():
        unconditioned_embeddings = frozen_ddpm.encode_text([''], count=1)
        conditioned_embeddings = frozen_ddpm.encode_text([prompt], count=1)
    # The text stack and VAE are no longer needed on the frozen copy; free GPU memory.
    del frozen_ddpm.tokenizer
    del frozen_ddpm.text_encoder
    del frozen_ddpm.encoder
    torch.cuda.empty_cache()
    for epoch in range(epochs):
        with torch.no_grad():
            frozen_ddpm.scheduler.set_timesteps(50, DEVICE)
            optimizer.zero_grad()
            # Sample a random DDIM step and denoise up to it with the current model.
            t = torch.randint(1, 50 - 1, (1,)).item()
            noise = torch.randn(1, frozen_ddpm.ddpm.in_channels, 512 // 8, 512 // 8, device=DEVICE).repeat(1, 1, 1, 1)
            latent = frozen_ddpm.scheduler.init_noise_sigma * noise
            ddpm.scheduler.set_timesteps(50, DEVICE)
            latent_steps = ddpm.reverse_diffusion(latent, conditioned_embeddings, last_itr=t, first_itr=0, original=False)
            # Switch to the full 1000-step schedule and map the sampled DDIM index onto it.
            frozen_ddpm.scheduler.set_timesteps(1000, DEVICE)
            ddpm.scheduler.set_timesteps(1000, DEVICE)
            t = int(t / 50 * 1000)
            latents_pos = frozen_ddpm.predict_noise(t, latent_steps[0], conditioned_embeddings)
            latents_neutral = frozen_ddpm.predict_noise(t, latent_steps[0], unconditioned_embeddings)
        # The trainable model's prediction is computed outside no_grad so gradients can flow.
        latents_neg = ddpm.predict_noise(t, latent_steps[0], conditioned_embeddings)
        latents_pos.requires_grad = False
        latents_neutral.requires_grad = False
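        # ESD objective (what the loss below encodes): push the fine-tuned model's conditioned
        # prediction latents_neg toward latents_neutral - eta * (latents_pos - latents_neutral),
        # i.e. guidance *away* from the erased concept, using the frozen model's predictions as the target.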
        loss = criterion(latents_neg, latents_neutral - (eta * (latents_pos - latents_neutral)))
        loss.backward()
        optimizer.step()
        print(f'Epoch: {epoch} Loss: {loss.item()}')
        torch.save(ddpm.state_dict(), path)
    torch.cuda.empty_cache()

if __name__ == '__main__':
    train('ocean', epochs=100, eta=1e-3, path='./saved_models/esd.pt')

# ---- utils.py ----
# Much of this code is borrowed.
import matplotlib.pyplot as plt
import textwrap
from PIL import Image

def image_grid(images, outpath=None, column_titles=None, row_titles=None):
    n_rows = len(images)
    n_cols = len(images[0])
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols,
                            figsize=(n_cols, n_rows), squeeze=False)
    for row, _images in enumerate(images):
        for column, image in enumerate(_images):
            ax = axs[row][column]
            ax.imshow(image)
            if column_titles and row == 0:
                ax.set_title(textwrap.fill(
                    column_titles[column], width=12), fontsize='x-small')
            if row_titles and column == 0:
                ax.set_ylabel(row_titles[row], rotation=0, fontsize='x-small', labelpad=1.6 * len(row_titles[row]))
            ax.set_xticks([])
            ax.set_yticks([])
    plt.subplots_adjust(wspace=0, hspace=0)
    if outpath is not None:
        plt.savefig(outpath, bbox_inches='tight', dpi=300)
        plt.close()
    else:
        plt.tight_layout(pad=0)
        image = figure_to_image(plt.gcf())
        plt.close()
        return image

def figure_to_image(figure):
    figure.set_dpi(300)
    figure.canvas.draw()
    return Image.frombytes('RGB', figure.canvas.get_width_height(), figure.canvas.tostring_rgb())