Commit b32fb1a1 authored by Saswat

add GUI

parent c580318c
@@ -2,8 +2,17 @@ import torch
class CFG:
    debug = False
    seed = 42
    # Paths
    dataset_path = "/home/saswat/Desktop/flickr30k_images/"
    image_path = dataset_path+"flickr30k_images/"
    captions_path = "."
    model_path = "./model/best.pt"
    img_emb_path = "./model/image_embeddings.pt"
    # Training Params
    epochs = 2
    batch_size = 8
    num_workers = 4
    head_lr = 1e-3
@@ -12,9 +21,9 @@ class CFG:
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pretrained model (Text Embeddings / Image Embeddings) params
    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "distilbert-base-uncased"
@@ -34,4 +43,4 @@ class CFG:
    projection_dim = 256
    dropout = 0.1
\ No newline at end of file
import torch
from config import CFG
import cv2
import torch.nn.functional as F
import albumentations as A
import numpy as np
from utils import *
import pandas as pd

class CLIPDataset(torch.utils.data.Dataset):
    """
    Dataset class that returns an image together with its encoded and original caption.
    """
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
@@ -19,6 +26,9 @@ class CLIPDataset(torch.utils.data.Dataset):
        self.transforms = transforms

    def __getitem__(self, idx):
        """
        Given an index, return the encoded caption together with the image and the original caption.
        """
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
@@ -34,11 +44,16 @@ class CLIPDataset(torch.utils.data.Dataset):

    def __len__(self):
        """
        Return the number of caption/image pairs in the dataset.
        """
        return len(self.captions)

def get_transforms(mode="train"):
    """
    Return the image transformations for the given mode.
    """
    if mode == "train":
        return A.Compose(
            [
@@ -53,3 +68,39 @@ def get_transforms(mode="train"):
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
def gen_train_valid_dfs():
    """
    Split the captions dataset into train and validation dataframes.
    """
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(CFG.seed)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

def get_dataset_loader(dataframe, tokenizer, mode):
    """
    Build a torch DataLoader over a CLIPDataset.
    """
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader
\ No newline at end of file
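# Illustration (not part of this commit): a minimal sketch of how the split and the
# loader are meant to be used together, assuming captions.csv has already been written
# by add_imageid() in utils.
#
#   from transformers import DistilBertTokenizer
#   train_df, valid_df = gen_train_valid_dfs()
#   tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
#   train_loader = get_dataset_loader(train_df, tokenizer, mode="train")
#   valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")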
@@ -5,7 +5,7 @@ from transformers import DistilBertModel, DistilBertConfig
class ImageEncoder(nn.Module):
    """
    Encode images using a pretrained model such as ResNet.
    """
    def __init__(
@@ -21,6 +21,9 @@ class ImageEncoder(nn.Module):
        return self.model(x)

class TextEncoder(nn.Module):
    """
    Encode text using a pretrained language model such as DistilBERT.
    """
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
@@ -40,6 +43,9 @@ class TextEncoder(nn.Module):
        # use the hidden state at the target token position as the sentence embedding
        return last_hidden_state[:, self.target_token_idx, :]

class ProjectionHead(nn.Module):
    """
    Project both image and text embeddings to a shared, fixed-size embedding dimension.
    """
    def __init__(
        self,
        embedding_dim,
...
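# Hedged sketch (the ProjectionHead body is collapsed in this diff): a typical head
# consistent with the docstring and CFG (projection_dim = 256, dropout = 0.1) projects an
# encoder output to the shared embedding size. The exact layers below are an assumption,
# not necessarily this repository's code.
import torch.nn as nn

class ProjectionHeadSketch(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.dropout(self.fc(self.gelu(projected)))
        x = x + projected                 # residual connection around the MLP
        return self.layer_norm(x)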
from tkinter import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from tkinter.ttk import *
from infer import find_matches
from data import gen_train_valid_dfs
from infer import *
from config import CFG

# Build the validation split and load the trained CLIP model once at startup.
_, valid_df = gen_train_valid_dfs()
model = get_image_embeddings(valid_df, CFG.model_path)
plot = None

def get_figure():
    """
    Embed the query typed in the entry box and display the top matching images.
    """
    global plot
    print("getting figure for sentence: ", entry1.get())
    fig = find_matches(model, query=entry1.get(), image_filenames=valid_df['image'].values, n=9)
    print("Got figure for sentence: ", entry1.get())
    plot = FigureCanvasTkAgg(fig, root)
    plot.get_tk_widget().grid(row=2, column=0, rowspan=7, columnspan=7)
    #plot.get_tk_widget().pack(side='top')

root = Tk()
width = root.winfo_screenwidth()
height = root.winfo_screenheight()
root.geometry("%dx%d" % (width, height))
root.title('CLIP Model Demo')
label1 = Label(root, text="Enter Text")
label1.grid(row=0, column=0)
#label1.pack(side='top')
entry1 = Entry(root, width=50)
entry1.grid(row=0, column=1, columnspan=5)
#entry1.pack(side='top')
button1 = Button(root, text="Get Images", command=get_figure)
button1.grid(row=0, column=6)
#button1.pack(side='top')
root.mainloop()
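# How the pieces fit together: importing infer runs load_img_embeddings() at module
# level, so the cached image embeddings at CFG.img_emb_path must exist before the GUI
# starts (they can be produced once with get_image_embeddings(valid_df, CFG.model_path,
# get_emb=True)). Typing a caption and pressing "Get Images" calls find_matches, which
# returns a matplotlib figure that FigureCanvasTkAgg embeds into the Tk window.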
@@ -7,27 +7,33 @@ from config import CFG
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
from data import *

img_embeddings = None

def get_image_embeddings(valid_df, model_path, get_emb=False):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()

    valid_image_embeddings = []
    # Only recompute and cache the image embeddings when explicitly requested.
    if get_emb:
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                image_features = model.image_encoder(batch["image"].to(CFG.device))
                image_embeddings = model.image_projection(image_features)
                valid_image_embeddings.append(image_embeddings)
        torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
    return model

def load_img_embeddings():
    """
    Load the cached image embeddings and L2-normalize them once.
    """
    global img_embeddings
    img_embeddings = torch.load(CFG.img_emb_path, map_location=CFG.device)
    img_embeddings = F.normalize(img_embeddings, p=2, dim=-1)

def find_matches(model, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
@@ -40,25 +46,24 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
    )
    text_embeddings = model.text_projection(text_features)

    image_embeddings_n = img_embeddings
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    # Each image appears with 5 caption rows, so over-fetch and keep every 5th hit.
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    plt.cla()
    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{CFG.image_path}/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")
    return fig

load_img_embeddings()
##_, valid_df = gen_train_valid_dfs()
#model = get_image_embeddings(valid_df, "./model/best.pt")
#find_matches(model, query="boy is playing", image_filenames=valid_df['image'].values, n=9)
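# Illustration (not part of this commit): the retrieval in find_matches is a cosine
# similarity between L2-normalized embeddings. Dummy tensors stand in for the projected
# CLIP embeddings; shapes follow CFG.projection_dim = 256.
import torch
import torch.nn.functional as F

caption_rows = F.normalize(torch.randn(100, 256), p=2, dim=-1)   # 100 rows = 20 images x 5 captions
query = F.normalize(torch.randn(1, 256), p=2, dim=-1)            # one encoded text query
similarity = query @ caption_rows.T                              # (1, 100) cosine similarities
values, indices = torch.topk(similarity.squeeze(0), 9 * 5)       # over-fetch: 5 rows per image
top_rows = indices[::5]                                          # rough de-duplication, 9 rows kept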
@@ -6,80 +6,110 @@ from config import CFG
import itertools
from clip import CLIPModel
from utils import AvgMeter, get_lr
from data import *
from transformers import DistilBertTokenizer
import torch.nn.functional as F
from utils import *

# Reference for loading dataset and model architecture:
# https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2
def train_epoch(model, train_data_loader, optimizer):
    """
    Run one training epoch and return the average loss per sample.
    """
    loss_sum = 0
    loss_count = 0
    for batch in tqdm(train_data_loader, total=len(train_data_loader)):
        batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch_data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = batch_data["image"].size(0)
        loss_count += max(1, batch_size)
        loss_sum += loss.item() * max(1, batch_size)
    return loss_sum / loss_count

def valid_epoch(model, valid_data_loader):
    """
    Run one validation epoch and return the average loss per sample.
    """
    loss_sum = 0
    loss_count = 0
    for batch in tqdm(valid_data_loader, total=len(valid_data_loader)):
        batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch_data)
        batch_size = batch_data["image"].size(0)
        loss_count += max(1, batch_size)
        loss_sum += loss.item() * max(1, batch_size)
    return loss_sum / loss_count

def store_image_embeddings(valid_df, model):
    """
    Generate and store image embeddings for the validation set.
    """
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
    model.eval()
    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
    return
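# Note: the cached embeddings read by the GUI can be produced either by calling
# store_image_embeddings(valid_df, model) after training here, or by running
# get_image_embeddings(valid_df, CFG.model_path, get_emb=True) from infer; both save
# the concatenated image embeddings to CFG.img_emb_path.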
def main():
    # Generate train and valid dataframes
    train_data, valid_data = gen_train_valid_dfs()

    # Use a pretrained tokenizer to tokenize the captions
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)

    # Create data loaders for both train and validation data
    train_loader = get_dataset_loader(train_data, tokenizer, mode="train")
    valid_loader = get_dataset_loader(valid_data, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )

    min_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        if valid_loss < min_loss:
            min_loss = valid_loss
            torch.save(model.state_dict(), CFG.model_path)
            print("Saved Best Model!")
        lr_scheduler.step(valid_loss)

main()
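# Hedged sketch (assumption, not shown in this diff): clip.py's CLIPModel.forward is
# expected to return a CLIP-style symmetric contrastive loss over the projected image
# and text embeddings. A common formulation of that loss looks like this:
import torch
import torch.nn.functional as F

def clip_symmetric_loss(image_emb, text_emb, temperature=1.0):
    # image_emb, text_emb: (batch, projection_dim) projected embeddings
    logits = (text_emb @ image_emb.T) / temperature            # (batch, batch) similarity matrix
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_texts = F.cross_entropy(logits, targets)              # match each text to its image
    loss_images = F.cross_entropy(logits.T, targets)           # match each image to its text
    return (loss_texts + loss_images) / 2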
import pandas as pd
from data import *
from config import CFG

def add_imageid():
    """
    Rewrite the captions file, assigning one id to each image so that all of its
    captions share that id.
    """
    df = pd.read_csv(CFG.dataset_path+"/results.csv", delimiter="|")
    df.columns = ['image', 'caption_number', 'caption']
    df['caption'] = df['caption'].str.lstrip()
    df['caption_number'] = df['caption_number'].str.lstrip()
@@ -12,40 +16,12 @@ def generate_context():
    df.loc[19999, 'caption'] = "A dog runs across the grass ."
    ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
    df['id'] = ids
    df.to_csv(CFG.captions_path+"/captions.csv", index=False)
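# Worked example of the id assignment above: with 10 caption rows (2 images x 5 captions),
# [id_ for id_ in range(10 // 5) for i in range(5)] == [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
# so every caption of the same image ends up with the same id, which gen_train_valid_dfs
# later uses to split train/validation by image rather than by caption.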
class AvgMeter:
    """
    Keeps a running average of the loss across a batch/epoch.
    """
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()
@@ -63,5 +39,8 @@ class AvgMeter:
        return text

def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group.
    """
    for param_group in optimizer.param_groups:
        return param_group["lr"]