Commit b32fb1a1 authored by Saswat's avatar Saswat

add GUI

parent c580318c
......@@ -2,8 +2,17 @@ import torch
class CFG:
debug = False
image_path = "/home/saswat/Desktop/flickr30k_images/flickr30k_images/"
seed = 42
# Paths
dataset_path = "/home/saswat/Desktop/flickr30k_images/"
image_path = dataset_path+"flickr30k_images/"
captions_path = "."
model_path = "./model/best.pt"
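# Cached validation image embeddings, written by the embedding helpers and loaded at inference time.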
img_emb_path = "./model/image_embeddings.pt"
# Training Params
epochs = 2
batch_size = 8
num_workers = 4
head_lr = 1e-3
......@@ -12,9 +21,9 @@ class CFG:
weight_decay = 1e-3
patience = 1
factor = 0.8
epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pretrained model (Text Embeddings/ Image Embeddings) params
model_name = 'resnet50'
image_embedding = 2048
text_encoder_model = "distilbert-base-uncased"
......@@ -34,4 +43,4 @@ class CFG:
projection_dim = 256
dropout = 0.1
model_path = "./model/best.pt"
\ No newline at end of file
\ No newline at end of file
import torch
from config import CFG
import cv2
import torch.nn.functional as F
import albumentations as A
import numpy as np
from utils import *
import pandas as pd
class CLIPDataset(torch.utils.data.Dataset):
"""
Dataset class that returns an image and its encoded caption, for use with a DataLoader.
"""
def __init__(self, image_filenames, captions, tokenizer, transforms):
"""
image_filenames and captions must have the same length; so, if there are
......@@ -19,6 +26,9 @@ class CLIPDataset(torch.utils.data.Dataset):
self.transforms = transforms
def __getitem__(self, idx):
"""
Given an index, return the encoded caption together with the image and the original caption.
"""
item = {
key: torch.tensor(values[idx])
for key, values in self.encoded_captions.items()
......@@ -34,11 +44,16 @@ class CLIPDataset(torch.utils.data.Dataset):
def __len__(self):
"""
Return the number of caption/image pairs in the dataset.
"""
return len(self.captions)
def get_transforms(mode="train"):
"""
Return the albumentations transforms for the given mode (train/valid).
"""
if mode == "train":
return A.Compose(
[
......@@ -52,4 +67,40 @@ def get_transforms(mode="train"):
A.Resize(CFG.size, CFG.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)
\ No newline at end of file
)
def gen_train_valid_dfs():
"""
Split the dataset into train and validation dataframes.
"""
dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
image_ids = np.arange(0, max_id)
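# Hold out 20% of the image ids for validation; fixing the seed keeps the split reproducible across runs.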
np.random.seed(CFG.seed)
valid_ids = np.random.choice(
image_ids, size=int(0.2 * len(image_ids)), replace=False
)
train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
return train_dataframe, valid_dataframe
def get_dataset_loader(dataframe, tokenizer, mode):
"""
Build a dataset loader using CLIPDataset.
"""
transforms = get_transforms(mode=mode)
dataset = CLIPDataset(
dataframe["image"].values,
dataframe["caption"].values,
tokenizer=tokenizer,
transforms=transforms,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=CFG.batch_size,
num_workers=CFG.num_workers,
shuffle=(mode == "train"),
)
return dataloader
\ No newline at end of file
......@@ -5,7 +5,7 @@ from transformers import DistilBertModel, DistilBertConfig
class ImageEncoder(nn.Module):
"""
Encode images to a fixed size vector
Encode images using a pretrained model such as ResNet.
"""
def __init__(
......@@ -21,6 +21,9 @@ class ImageEncoder(nn.Module):
return self.model(x)
class TextEncoder(nn.Module):
"""
Encode text using a pretrained language model such as DistilBERT.
"""
def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
super().__init__()
if pretrained:
......@@ -40,6 +43,9 @@ class TextEncoder(nn.Module):
return last_hidden_state[:, self.target_token_idx, :]
class ProjectionHead(nn.Module):
"""
Project both image and text embeddings to a fixed, common embedding dimension.
"""
def __init__(
self,
embedding_dim,
......
from tkinter import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from tkinter.ttk import *
from infer import find_matches, get_image_embeddings
from data import gen_train_valid_dfs
from config import CFG
_, valid_df = gen_train_valid_dfs()
model = get_image_embeddings(valid_df, CFG.model_path)
plot = None
def get_figure():
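"""
Run find_matches on the text typed into the entry box and embed the returned
matplotlib figure into the window grid (the global keeps the canvas alive).
"""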
global plot
print("getting figure for sentence: ", entry1.get())
fig = find_matches(model, query=entry1.get(), image_filenames=valid_df['image'].values, n=9)
print("Got figure for sentence: ", entry1.get())
plot = FigureCanvasTkAgg(fig, root)
plot.get_tk_widget().grid(row=2, column=0, rowspan=7, columnspan=7)
#plot.get_tk_widget().pack(side='top')
root = Tk()
width = root.winfo_screenwidth()
height = root.winfo_screenheight()
root.geometry("%dx%d"%(width,height))
root.title('CLIP Model Demo')
label1 = Label(root, text="Enter Text")
label1.grid(row=0, column=0)
#label1.pack(side='top')
entry1 = Entry(root, width=50)
entry1.grid(row=0, column=1, columnspan=5)
#entry1.pack(side='top')
button1 = Button(root, text="Get Images", command=get_figure)
button1.grid(row=0, column=6)
#button1.pack(side='top')
root.mainloop()
......@@ -7,27 +7,33 @@ from config import CFG
import matplotlib.pyplot as plt
import cv2
import torch.nn.functional as F
from data import *
def get_image_embeddings(valid_df, model_path):
img_embeddings = None
def get_image_embeddings(valid_df, model_path, get_emb=False):
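"""
Load the trained CLIP model from model_path. When get_emb is True, also project
all validation images and cache the stacked embeddings at CFG.img_emb_path;
otherwise only the model is returned and the previously cached embeddings are used.
"""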
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
model = CLIPModel().to(CFG.device)
model.load_state_dict(torch.load(model_path, map_location=CFG.device))
model.eval()
valid_image_embeddings = []
with torch.no_grad():
for batch in tqdm(valid_loader):
image_features = model.image_encoder(batch["image"].to(CFG.device))
image_embeddings = model.image_projection(image_features)
valid_image_embeddings.append(image_embeddings)
return model, torch.cat(valid_image_embeddings)
if get_emb:
with torch.no_grad():
for batch in tqdm(valid_loader):
image_features = model.image_encoder(batch["image"].to(CFG.device))
image_embeddings = model.image_projection(image_features)
valid_image_embeddings.append(image_embeddings)
torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
return model
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "./model/best.pt")
def load_img_embeddings():
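"""
Load the cached image embeddings from CFG.img_emb_path and L2-normalize them
once, so find_matches only needs to normalize the query text.
"""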
global img_embeddings
img_embeddings = torch.load(CFG.img_emb_path, map_location=CFG.device)
img_embeddings = F.normalize(img_embeddings, p=2, dim=-1)
def find_matches(model, image_embeddings, query, image_filenames, n=9):
def find_matches(model, query, image_filenames, n=9):
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
encoded_query = tokenizer([query])
batch = {
......@@ -40,25 +46,24 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
)
text_embeddings = model.text_projection(text_features)
image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
image_embeddings_n = img_embeddings
text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
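# Cosine similarity of the query against every cached image embedding. Each image
# appears five times in valid_df (one row per caption), so the top n * 5 hits are
# taken and every 5th index is kept to avoid returning the same image repeatedly.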
dot_similarity = text_embeddings_n @ image_embeddings_n.T
values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
matches = [image_filenames[idx] for idx in indices[::5]]
_, axes = plt.subplots(3, 3, figsize=(10, 10))
plt.clf()
plt.cla()
fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for match, ax in zip(matches, axes.flatten()):
image = cv2.imread(f"{CFG.image_path}/{match}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
ax.imshow(image)
ax.axis("off")
plt.show()
return fig
find_matches(model,
image_embeddings,
query="boy is playing",
image_filenames=valid_df['image'].values,
n=9)
load_img_embeddings()
##_, valid_df = gen_train_valid_dfs()
#model = get_image_embeddings(valid_df, "./model/best.pt")
#find_matches(model, query="boy is playing", image_filenames=valid_df['image'].values, n=9)
......@@ -6,80 +6,110 @@ from config import CFG
import itertools
from clip import CLIPModel
from utils import AvgMeter, get_lr
from dataset import CLIPDataset, get_transforms
from data import *
from transformers import DistilBertTokenizer
import torch.nn.functional as F
from utils import *
import pickle
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
loss_meter = AvgMeter()
tqdm_object = tqdm(train_loader, total=len(train_loader))
for batch in tqdm_object:
batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
loss = model(batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step == "batch":
lr_scheduler.step()
count = batch["image"].size(0)
loss_meter.update(loss.item(), count)
# Reference for loading dataset and model architecture.
# https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2
#
tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
return loss_meter
def train_epoch(model, train_data_loader, optimizer):
"""
Run one training epoch and return the average per-sample loss.
"""
loss_sum = 0
loss_count = 0
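# Accumulate the loss weighted by batch size so the returned value is the average loss per sample over the epoch.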
for batch in tqdm(train_data_loader, total=len(train_data_loader)):
batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
loss = model(batch_data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
batch_size = batch_data["image"].size(0)
loss_count += max(1, batch_size)
loss_sum += loss.item() * max(1, batch_size)
return loss_sum/loss_count
def valid_epoch(model, valid_loader):
loss_meter = AvgMeter()
tqdm_object = tqdm(valid_loader, total=len(valid_loader))
for batch in tqdm_object:
batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
def valid_epoch(model, valid_data_loader):
"""
Run one validation epoch and return the average per-sample loss.
"""
loss_sum = 0
loss_count = 0
#tqdm_object = tqdm(valid_loader, total=len(valid_loader))
for batch in tqdm(valid_data_loader, total=len(valid_data_loader)):
batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
loss = model(batch_data)
batch_size = batch_data["image"].size(0)
loss_count += max(1, batch_size)
loss_sum += loss.item() * max(1, batch_size)
#tqdm_object.set_postfix(valid_loss=loss_sum/loss_count)
return loss_sum/loss_count
count = batch["image"].size(0)
loss_meter.update(loss.item(), count)
tqdm_object.set_postfix(valid_loss=loss_meter.avg)
return loss_meter
def store_image_embeddings(valid_df, model):
"""
Generate and store image embeddings for the validation set.
"""
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
#model = CLIPModel().to(CFG.device)
#model.load_state_dict(torch.load(model_path, map_location=CFG.device))
model.eval()
valid_image_embeddings = []
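# Project every validation image once and cache the stacked embeddings to disk so inference can skip this pass.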
with torch.no_grad():
for batch in tqdm(valid_loader):
image_features = model.image_encoder(batch["image"].to(CFG.device))
image_embeddings = model.image_projection(image_features)
valid_image_embeddings.append(image_embeddings)
torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
return
def main():
train_df, valid_df = make_train_valid_dfs()
# Generate train and validation dataframes
train_data, valid_data = gen_train_valid_dfs()
# Use a pretrained tokenizer to tokenize the captions
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
# Create data loaders for both train and validation data
train_loader = get_dataset_loader(train_data, tokenizer, mode="train")
valid_loader = get_dataset_loader(valid_data, tokenizer, mode="valid")
model = CLIPModel().to(CFG.device)
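# Parameter groups: the pretrained image/text encoders get their own learning rates, while both projection heads share CFG.head_lr and CFG.weight_decay.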
params = [
{"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
{"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
{"params": itertools.chain(
model.image_projection.parameters(), model.text_projection.parameters()
{"params": itertools.chain( model.image_projection.parameters(), model.text_projection.parameters()
), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
]
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
)
step = "epoch"
best_loss = float('inf')
min_loss = float('inf')
for epoch in range(CFG.epochs):
print(f"Epoch: {epoch + 1}")
model.train()
train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
train_loss = train_epoch(model, train_loader, optimizer)
model.eval()
with torch.no_grad():
valid_loss = valid_epoch(model, valid_loader)
if valid_loss.avg < best_loss:
best_loss = valid_loss.avg
if valid_loss < min_loss:
min_loss = valid_loss
torch.save(model.state_dict(), CFG.model_path)
print("Saved Best Model!")
lr_scheduler.step(valid_loss.avg)
lr_scheduler.step(valid_loss)
main()
import numpy as np
import pandas as pd
from dataset import *
import torch.nn.functional as F
from data import *
from config import CFG
def generate_context():
df = pd.read_csv(CFG.image_path+"/results.csv", delimiter="|")
def add_imageid():
"""
Rewrite the captions file, assigning an id to each image (every image has more than one caption).
"""
df = pd.read_csv(CFG.dataset_path+"/results.csv", delimiter="|")
df.columns = ['image', 'caption_number', 'caption']
df['caption'] = df['caption'].str.lstrip()
df['caption_number'] = df['caption_number'].str.lstrip()
......@@ -12,40 +16,12 @@ def generate_context():
df.loc[19999, 'caption'] = "A dog runs across the grass ."
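# Flickr30k ships 5 captions per image, so consecutive groups of 5 rows are given the same image id.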
ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
df['id'] = ids
df.to_csv("captions.csv", index=False)
df.head()
def make_train_valid_dfs():
dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
image_ids = np.arange(0, max_id)
np.random.seed(42)
valid_ids = np.random.choice(
image_ids, size=int(0.2 * len(image_ids)), replace=False
)
train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
return train_dataframe, valid_dataframe
def build_loaders(dataframe, tokenizer, mode):
transforms = get_transforms(mode=mode)
dataset = CLIPDataset(
dataframe["image"].values,
dataframe["caption"].values,
tokenizer=tokenizer,
transforms=transforms,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=CFG.batch_size,
num_workers=CFG.num_workers,
shuffle=True if mode == "train" else False,
)
return dataloader
df.to_csv(CFG.captions_path+"captions.csv", index=False)
class AvgMeter:
"""
Tracks the running average of the loss across a batch/epoch.
"""
def __init__(self, name="Metric"):
self.name = name
self.reset()
......@@ -63,5 +39,8 @@ class AvgMeter:
return text
def get_lr(optimizer):
"""
Return the learning rate of the first parameter group in the optimizer.
"""
for param_group in optimizer.param_groups:
return param_group["lr"]