Saswat / CLIP · Commits

Commit b32fb1a1, authored Nov 27, 2022 by Saswat
parent c580318c

    add GUI

Showing 7 changed files with 227 additions and 106 deletions:

    config.py     +12   -3
    data.py       +54   -3   (renamed from dataset.py)
    encoders.py    +7   -1
    gui.py        +41   -0   (new file)
    infer.py      +26  -21
    train.py      +71  -41
    utils.py      +16  -37
config.py  +12 -3

@@ -2,8 +2,17 @@ import torch
 class CFG:
     debug = False
-    image_path = "/home/saswat/Desktop/flickr30k_images/flickr30k_images/"
     seed = 42
+
+    # Paths
+    dataset_path = "/home/saswat/Desktop/flickr30k_images/"
+    image_path = dataset_path + "flickr30k_images/"
     captions_path = "."
+    model_path = "./model/best.pt"
+    img_emb_path = "./model/image_embeddings.pt"
+
+    # Training Params
+    epochs = 2
     batch_size = 8
     num_workers = 4
     head_lr = 1e-3
@@ -12,9 +21,9 @@ class CFG:
     weight_decay = 1e-3
     patience = 1
     factor = 0.8
-    epochs = 2
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Pretrained model (Text Embeddings / Image Embeddings) params
     model_name = 'resnet50'
     image_embedding = 2048
     text_encoder_model = "distilbert-base-uncased"
@@ -34,4 +43,4 @@ class CFG:
     projection_dim = 256
-    dropout = 0.1
-    model_path = "./model/best.pt"
\ No newline at end of file
+    dropout = 0.1
\ No newline at end of file
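All paths now hang off a single dataset_path root, and the artifact locations (model_path, img_emb_path) live on CFG instead of being hard-coded in each script. A minimal sketch of how downstream code reads these values (the paths themselves are machine-specific):

# Sketch: the reorganized CFG as consumed by the rest of the repo.
from config import CFG

print(CFG.image_path)    # dataset_path + "flickr30k_images/"
print(CFG.model_path)    # "./model/best.pt", checkpoint written by train.py
print(CFG.img_emb_path)  # "./model/image_embeddings.pt", written once, reloaded by infer.py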
dataset.py → data.py  +54 -3

 import torch
 from config import CFG
 import cv2
 import torch.nn.functional as F
 import albumentations as A
 import numpy as np
 from utils import *
 import pandas as pd


 class CLIPDataset(torch.utils.data.Dataset):
+    """
+    Dataset class for getting image and text using a dataset loader.
+    """

     def __init__(self, image_filenames, captions, tokenizer, transforms):
         """
         image_filenames and captions must have the same length; so, if there are
         ...
@@ -19,6 +26,9 @@ class CLIPDataset(torch.utils.data.Dataset):
         self.transforms = transforms

     def __getitem__(self, idx):
+        """
+        Provided an index, return the encoded caption along with the image and
+        the original caption.
+        """
         item = {
             key: torch.tensor(values[idx])
             for key, values in self.encoded_captions.items()
@@ -34,11 +44,16 @@ class CLIPDataset(torch.utils.data.Dataset):
     def __len__(self):
+        """
+        Returns the size of our training data.
+        """
         return len(self.captions)


 def get_transforms(mode="train"):
+    """
+    Implements image transformations.
+    """
     if mode == "train":
         return A.Compose(
             [
@@ -52,4 +67,40 @@ def get_transforms(mode="train"):
             A.Resize(CFG.size, CFG.size, always_apply=True),
             A.Normalize(max_pixel_value=255.0, always_apply=True),
         ]
-    )
\ No newline at end of file
+    )
+
+
+def gen_train_valid_dfs():
+    """
+    Split the dataset into train and validation dataframes.
+    """
+    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
+    image_ids = np.arange(0, max_id)
+    np.random.seed(CFG.seed)
+    valid_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)
+    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
+    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
+    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+    return train_dataframe, valid_dataframe
+
+
+def get_dataset_loader(dataframe, tokenizer, mode):
+    """
+    Build a dataset loader using CLIPDataset.
+    """
+    transforms = get_transforms(mode=mode)
+    dataset = CLIPDataset(
+        dataframe["image"].values,
+        dataframe["caption"].values,
+        tokenizer=tokenizer,
+        transforms=transforms,
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=CFG.batch_size,
+        num_workers=CFG.num_workers,
+        shuffle=True if mode == "train" else False,
+    )
+    return dataloader
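gen_train_valid_dfs splits on image id (seeded with CFG.seed) so all five captions of an image land on the same side of the 80/20 split, and get_dataset_loader wraps the split in a CLIPDataset plus DataLoader. A hedged usage sketch, assuming captions.csv has already been produced by add_imageid in utils.py:

# Sketch of the new data.py API end to end.
from transformers import DistilBertTokenizer
from config import CFG
from data import gen_train_valid_dfs, get_dataset_loader

train_df, valid_df = gen_train_valid_dfs()   # seeded 80/20 split on image id
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
train_loader = get_dataset_loader(train_df, tokenizer, mode="train")   # shuffled
valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")   # not shuffled

batch = next(iter(train_loader))
print(batch["image"].shape)   # expected: (CFG.batch_size, 3, CFG.size, CFG.size)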
encoders.py  +7 -1

@@ -5,7 +5,7 @@ from transformers import DistilBertModel, DistilBertConfig
 class ImageEncoder(nn.Module):
     """
     Encode images
-    to a fixed size vector
+    using a pretrained model like ResNet.
     """

     def __init__(
@@ -21,6 +21,9 @@ class ImageEncoder(nn.Module):
         return self.model(x)


 class TextEncoder(nn.Module):
+    """
+    Encode text using a pretrained language model such as DistilBERT.
+    """
     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
         super().__init__()
         if pretrained:
@@ -40,6 +43,9 @@ class TextEncoder(nn.Module):
         return last_hidden_state[:, self.target_token_idx, :]


 class ProjectionHead(nn.Module):
+    """
+    Project both image and text embeddings down to a common fixed-size embedding.
+    """
     def __init__(
         self,
         embedding_dim,
         ...
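TextEncoder reduces DistilBERT's output to a single token's hidden state, last_hidden_state[:, self.target_token_idx, :]. A standalone sketch of that extraction, assuming target_token_idx = 0 (the [CLS] position, as in the tutorial this repo references):

# Sketch: what TextEncoder's forward boils down to, outside the repo.
import torch
from transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

batch = tokenizer(["a boy playing in the park"], return_tensors="pt")
with torch.no_grad():
    out = bert(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
cls_embedding = out.last_hidden_state[:, 0, :]   # one 768-dim vector per sentence
print(cls_embedding.shape)                       # torch.Size([1, 768])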
gui.py  +41 -0  (new file, 0 → 100644)

from tkinter import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from tkinter.ttk import *
from infer import find_matches
from data import gen_train_valid_dfs  # data.py's split helper (make_train_valid_dfs no longer exists after the rename)
from infer import *
from config import CFG

_, valid_df = gen_train_valid_dfs()
model = get_image_embeddings(valid_df, CFG.model_path)
plot = None


def get_figure():
    """Query the model with the entered text and embed the result grid in the window."""
    global plot
    print("getting figure for sentence: ", entry1.get())
    fig = find_matches(model, query=entry1.get(), image_filenames=valid_df['image'].values, n=9)
    print("Got figure for sentence: ", entry1.get())
    plot = FigureCanvasTkAgg(fig, root)
    plot.get_tk_widget().grid(row=2, column=0, rowspan=7, columnspan=7)
    #plot.get_tk_widget().pack(side='top')


root = Tk()
width = root.winfo_screenwidth()
height = root.winfo_screenheight()
root.geometry("%dx%d" % (width, height))
root.title('CLIP Model Demo')
label1 = Label(root, text="Enter Text")
label1.grid(row=0, column=0)
#label1.pack(side='top')
entry1 = Entry(root, width=50)
entry1.grid(row=0, column=1, columnspan=5)
#entry1.pack(side='top')
button1 = Button(root, text="Get Images", command=get_figure)
button1.grid(row=0, column=6)
#button1.pack(side='top')
root.mainloop()
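gui.py assumes both artifacts already exist: infer.py, imported here, calls load_img_embeddings() at import time and torch.loads CFG.img_emb_path. A hedged one-time precompute sketch; it mirrors store_image_embeddings from train.py but avoids importing train.py, whose module-level main() call would start a training run:

# One-time setup sketch (not part of the commit): embed every validation
# image once so that importing infer/gui does not fail on a missing file.
import torch
from tqdm import tqdm
from transformers import DistilBertTokenizer
from clip import CLIPModel
from config import CFG
from data import gen_train_valid_dfs, get_dataset_loader

_, valid_df = gen_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")

model = CLIPModel().to(CFG.device)
model.load_state_dict(torch.load(CFG.model_path, map_location=CFG.device))
model.eval()

embeddings = []
with torch.no_grad():
    for batch in tqdm(valid_loader):
        features = model.image_encoder(batch["image"].to(CFG.device))
        embeddings.append(model.image_projection(features))
torch.save(torch.cat(embeddings), CFG.img_emb_path)   # the file infer.py loads on import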
infer.py  +26 -21

@@ -7,27 +7,33 @@ from config import CFG
 import matplotlib.pyplot as plt
 import cv2
 import torch.nn.functional as F
 from data import *


-def get_image_embeddings(valid_df, model_path):
+img_embeddings = None
+
+
+def get_image_embeddings(valid_df, model_path, get_emb=False):
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
-    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+    valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
     model = CLIPModel().to(CFG.device)
     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
     model.eval()
     valid_image_embeddings = []
-    with torch.no_grad():
-        for batch in tqdm(valid_loader):
-            image_features = model.image_encoder(batch["image"].to(CFG.device))
-            image_embeddings = model.image_projection(image_features)
-            valid_image_embeddings.append(image_embeddings)
-    return model, torch.cat(valid_image_embeddings)
+    if get_emb:
+        with torch.no_grad():
+            for batch in tqdm(valid_loader):
+                image_features = model.image_encoder(batch["image"].to(CFG.device))
+                image_embeddings = model.image_projection(image_features)
+                valid_image_embeddings.append(image_embeddings)
+        torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
+    return model


-_, valid_df = make_train_valid_dfs()
-model, image_embeddings = get_image_embeddings(valid_df, "./model/best.pt")
+def load_img_embeddings():
+    global img_embeddings
+    img_embeddings = torch.load(CFG.img_emb_path, map_location=CFG.device)
+    img_embeddings = F.normalize(img_embeddings, p=2, dim=-1)


-def find_matches(model, image_embeddings, query, image_filenames, n=9):
+def find_matches(model, query, image_filenames, n=9):
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
     encoded_query = tokenizer([query])
     batch = {
@@ -40,25 +46,24 @@ def find_matches(model, query, image_filenames, n=9):
     )
     text_embeddings = model.text_projection(text_features)

-    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+    image_embeddings_n = img_embeddings  # already L2-normalized by load_img_embeddings()
     text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
     dot_similarity = text_embeddings_n @ image_embeddings_n.T

     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
     matches = [image_filenames[idx] for idx in indices[::5]]

-    _, axes = plt.subplots(3, 3, figsize=(10, 10))
+    plt.clf()
+    plt.cla()
+    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
     for match, ax in zip(matches, axes.flatten()):
         image = cv2.imread(f"{CFG.image_path}/{match}")
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         ax.imshow(image)
         ax.axis("off")
-    plt.show()
+    return fig


+load_img_embeddings()
+##_, valid_df = gen_train_valid_dfs()
+#model = get_image_embeddings(valid_df, "./model/best.pt")
+#find_matches(model, query="boy is playing", image_filenames=valid_df['image'].values, n=9)
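With the embeddings precomputed, a query is one text forward pass plus a dot product against the normalized image matrix; indices[::5] keeps every fifth hit because each image appears five times (once per caption) in the dataframe. A hedged end-to-end sketch, assuming CFG.img_emb_path and the checkpoint already exist:

# Sketch: text-to-image retrieval against the cached embeddings.
from config import CFG
from data import gen_train_valid_dfs
from infer import get_image_embeddings, find_matches   # importing infer runs load_img_embeddings()

_, valid_df = gen_train_valid_dfs()
model = get_image_embeddings(valid_df, CFG.model_path)   # loads weights; get_emb=False skips re-embedding
fig = find_matches(model, query="boy is playing", image_filenames=valid_df["image"].values, n=9)
fig.savefig("matches.png")   # find_matches now returns the figure instead of calling plt.show()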
train.py  +71 -41

@@ -6,80 +6,110 @@ from config import CFG
 import itertools
 from clip import CLIPModel
-from utils import AvgMeter, get_lr
-from dataset import CLIPDataset, get_transforms
+from data import *
 from transformers import DistilBertTokenizer
 import torch.nn.functional as F
+from utils import *
+import pickle

-def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
-    loss_meter = AvgMeter()
-    tqdm_object = tqdm(train_loader, total=len(train_loader))
-    for batch in tqdm_object:
-        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
-        loss = model(batch)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        if step == "batch":
-            lr_scheduler.step()
-        count = batch["image"].size(0)
-        loss_meter.update(loss.item(), count)
-        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
-    return loss_meter

+# Reference for loading dataset and model architecture:
+# https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2


+def train_epoch(model, train_data_loader, optimizer):
+    """
+    Run one training epoch.
+    """
+    loss_sum = 0
+    loss_count = 0
+    for batch in tqdm(train_data_loader, total=len(train_data_loader)):
+        batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+        loss = model(batch_data)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        batch_size = batch_data["image"].size(0)
+        loss_count += max(1, batch_size)
+        loss_sum += loss.item() * max(1, batch_size)
+    return loss_sum / loss_count


-def valid_epoch(model, valid_loader):
-    loss_meter = AvgMeter()
-    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
-    for batch in tqdm_object:
-        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
-        loss = model(batch)
-        count = batch["image"].size(0)
-        loss_meter.update(loss.item(), count)
-        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
-    return loss_meter

+def valid_epoch(model, valid_data_loader):
+    """
+    Run one validation epoch.
+    """
+    loss_sum = 0
+    loss_count = 0
+    #tqdm_object = tqdm(valid_loader, total=len(valid_loader))
+    for batch in tqdm(valid_data_loader, total=len(valid_data_loader)):
+        batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+        loss = model(batch_data)  # use the device-mapped dict, not the raw batch
+        batch_size = batch_data["image"].size(0)
+        loss_count += max(1, batch_size)
+        loss_sum += loss.item() * max(1, batch_size)
+    #tqdm_object.set_postfix(valid_loss=loss_sum/loss_count)
+    return loss_sum / loss_count


+def store_image_embeddings(valid_df, model):
+    """
+    Generate and store embeddings of the validation images.
+    """
+    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+    valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
+    #model = CLIPModel().to(CFG.device)
+    #model.load_state_dict(torch.load(model_path, map_location=CFG.device))
+    model.eval()
+    valid_image_embeddings = []
+    with torch.no_grad():
+        for batch in tqdm(valid_loader):
+            image_features = model.image_encoder(batch["image"].to(CFG.device))
+            image_embeddings = model.image_projection(image_features)
+            valid_image_embeddings.append(image_embeddings)
+    torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)  # cat the full list, not just the last batch
+    return


 def main():
-    train_df, valid_df = make_train_valid_dfs()
+    # Generate train and valid dataframes
+    train_data, valid_data = gen_train_valid_dfs()
+    # Use a pretrained tokenizer to tokenize the captions
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
-    train_loader = build_loaders(train_df, tokenizer, mode="train")
-    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+    # Create data loaders for both train and validation data
+    train_loader = get_dataset_loader(train_data, tokenizer, mode="train")
+    valid_loader = get_dataset_loader(valid_data, tokenizer, mode="valid")

     model = CLIPModel().to(CFG.device)
     params = [
         {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
         {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
         {"params": itertools.chain(
             model.image_projection.parameters(), model.text_projection.parameters()
         ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
     ]
     optimizer = torch.optim.AdamW(params, weight_decay=0.)
     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
         optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
     )
-    step = "epoch"

-    best_loss = float('inf')
+    min_loss = float('inf')
     for epoch in range(CFG.epochs):
         print(f"Epoch: {epoch + 1}")
         model.train()
-        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
+        train_loss = train_epoch(model, train_loader, optimizer)
         model.eval()
         with torch.no_grad():
             valid_loss = valid_epoch(model, valid_loader)

-        if valid_loss.avg < best_loss:
-            best_loss = valid_loss.avg
+        if valid_loss < min_loss:
+            min_loss = valid_loss
             torch.save(model.state_dict(), CFG.model_path)
             print("Saved Best Model!")

-        lr_scheduler.step(valid_loss.avg)
+        lr_scheduler.step(valid_loss)


 main()
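The AvgMeter bookkeeping is replaced by a sample-weighted mean: each batch contributes loss.item() * batch_size, so a short final batch no longer skews the epoch loss. A toy check of the arithmetic:

# Three batches: two full (8 samples) and one partial (3 samples).
batch_losses = [(0.9, 8), (0.7, 8), (0.5, 3)]   # (mean batch loss, batch size)
loss_sum = sum(loss * n for loss, n in batch_losses)
loss_count = sum(n for _, n in batch_losses)
print(loss_sum / loss_count)   # 0.7526..., the true per-sample average over all 19 samples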
utils.py  +16 -37

 import numpy as np
 import pandas as pd
-from dataset import *
 import torch.nn.functional as F
+from data import *  # caution: data.py also imports from utils (circular); works only because neither uses the other's names at import time
 from config import CFG


-def generate_context():
-    df = pd.read_csv(CFG.image_path + "/results.csv", delimiter="|")
+def add_imageid():
+    """
+    Rewrite the caption/image-name file, assigning an id to each image that
+    has more than one caption.
+    """
+    df = pd.read_csv(CFG.dataset_path + "/results.csv", delimiter="|")
     df.columns = ['image', 'caption_number', 'caption']
     df['caption'] = df['caption'].str.lstrip()
     df['caption_number'] = df['caption_number'].str.lstrip()
@@ -12,40 +16,12 @@ def generate_context():
     df.loc[19999, 'caption'] = "A dog runs across the grass ."
     ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
     df['id'] = ids
-    df.to_csv("captions.csv", index=False)
-    df.head()
+    df.to_csv(CFG.captions_path + "captions.csv", index=False)


-def make_train_valid_dfs():
-    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
-    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
-    image_ids = np.arange(0, max_id)
-    np.random.seed(42)
-    valid_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)
-    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
-    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
-    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
-    return train_dataframe, valid_dataframe
-
-
-def build_loaders(dataframe, tokenizer, mode):
-    transforms = get_transforms(mode=mode)
-    dataset = CLIPDataset(
-        dataframe["image"].values,
-        dataframe["caption"].values,
-        tokenizer=tokenizer,
-        transforms=transforms,
-    )
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=CFG.batch_size,
-        num_workers=CFG.num_workers,
-        shuffle=True if mode == "train" else False,
-    )
-    return dataloader


 class AvgMeter:
+    """
+    Keeps a running average of the loss across a batch/epoch.
+    """
     def __init__(self, name="Metric"):
         self.name = name
         self.reset()
@@ -63,5 +39,8 @@ class AvgMeter:
         return text


 def get_lr(optimizer):
+    """
+    Return the lr of the first parameter group in the optimizer.
+    """
     for param_group in optimizer.param_groups:
         return param_group["lr"]
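add_imageid relies on Flickr30k listing exactly five captions per image in results.csv, so consecutive groups of five rows share one id. A toy illustration of the id scheme:

n_rows = 15   # hypothetical results.csv slice: 3 images x 5 captions
ids = [id_ for id_ in range(n_rows // 5) for i in range(5)]
print(ids)    # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]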