Saswat / CLIP · commit b32fb1a1

add GUI

Authored Nov 27, 2022 by Saswat · parent c580318c

Showing 7 changed files with 227 additions and 106 deletions:

  config.py    +12  -3
  data.py      +54  -3   (renamed from dataset.py)
  encoders.py  +7   -1
  gui.py       +41  -0   (new file)
  infer.py     +26  -21
  train.py     +71  -41
  utils.py     +16  -37
config.py

@@ -2,8 +2,17 @@ import torch
 class CFG:
     debug = False
-    image_path = "/home/saswat/Desktop/flickr30k_images/flickr30k_images/"
+    seed = 42
+
+    # Paths
+    dataset_path = "/home/saswat/Desktop/flickr30k_images/"
+    image_path = dataset_path + "flickr30k_images/"
     captions_path = "."
+    model_path = "./model/best.pt"
+    img_emb_path = "./model/image_embeddings.pt"
+
+    # Training Params
+    epochs = 2
     batch_size = 8
     num_workers = 4
     head_lr = 1e-3

@@ -12,9 +21,9 @@ class CFG:
     weight_decay = 1e-3
     patience = 1
     factor = 0.8
-    epochs = 2
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Pretrained model (Text Embeddings/ Image Embeddings) params
     model_name = 'resnet50'
     image_embedding = 2048
     text_encoder_model = "distilbert-base-uncased"

@@ -34,4 +43,4 @@ class CFG:
     projection_dim = 256
     dropout = 0.1
-    model_path = "./model/best.pt"
\ No newline at end of file
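
The reorganized CFG now carries every path and hyperparameter the pipeline reads. A minimal usage sketch, assuming this commit's config.py is importable; the print statements are illustrative only:

from config import CFG

# Every module reads its paths and hyperparameters from the single CFG class.
print("images:     ", CFG.image_path)
print("checkpoint: ", CFG.model_path)
print("embeddings: ", CFG.img_emb_path)
print(f"{CFG.epochs} epochs, batch size {CFG.batch_size}, on {CFG.device}")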
dataset.py → data.py

 import torch
 from config import CFG
 import cv2
+import torch.nn.functional as F
 import albumentations as A
+import numpy as np
+from utils import *
+import pandas as pd


 class CLIPDataset(torch.utils.data.Dataset):
+    """
+    Dataset class for getting image and text using a dataset loader.
+    """
     def __init__(self, image_filenames, captions, tokenizer, transforms):
         """
         image_filenames and captions must have the same length; so, if there are

@@ -19,6 +26,9 @@ class CLIPDataset(torch.utils.data.Dataset):
         self.transforms = transforms

     def __getitem__(self, idx):
+        """
+        Provided an index, return the encoded caption with image and original caption.
+        """
         item = {
             key: torch.tensor(values[idx])
             for key, values in self.encoded_captions.items()

@@ -34,11 +44,16 @@ class CLIPDataset(torch.utils.data.Dataset):
     def __len__(self):
+        """
+        Returns the size of our training data.
+        """
         return len(self.captions)


 def get_transforms(mode="train"):
+    """
+    Implements image transformations.
+    """
     if mode == "train":
         return A.Compose(
             [

@@ -53,3 +68,39 @@ def get_transforms(mode="train"):
             A.Normalize(max_pixel_value=255.0, always_apply=True),
         ]
     )
+
+
+def gen_train_valid_dfs():
+    """
+    Split the dataset into train and validation dataframes.
+    """
+    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
+    image_ids = np.arange(0, max_id)
+    np.random.seed(CFG.seed)
+    valid_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)
+    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
+    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
+    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+    return train_dataframe, valid_dataframe
+
+
+def get_dataset_loader(dataframe, tokenizer, mode):
+    """
+    Build a dataset loader using CLIPDataset.
+    """
+    transforms = get_transforms(mode=mode)
+    dataset = CLIPDataset(
+        dataframe["image"].values,
+        dataframe["caption"].values,
+        tokenizer=tokenizer,
+        transforms=transforms,
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=CFG.batch_size,
+        num_workers=CFG.num_workers,
+        shuffle=True if mode == "train" else False,
+    )
+    return dataloader
\ No newline at end of file
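
The two new helpers replace make_train_valid_dfs and build_loaders from utils.py. A short sketch of the intended call pattern, assuming captions.csv has already been produced by utils.add_imageid; the batch key names follow DistilBertTokenizer's standard output:

from transformers import DistilBertTokenizer
from config import CFG
from data import gen_train_valid_dfs, get_dataset_loader

# 80/20 split keyed on image id, seeded with CFG.seed for reproducibility.
train_df, valid_df = gen_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
loader = get_dataset_loader(train_df, tokenizer, mode="train")

# Each batch is a dict of tokenized caption tensors plus the transformed image.
batch = next(iter(loader))
print(batch["image"].shape, batch["input_ids"].shape)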
encoders.py

@@ -5,7 +5,7 @@ from transformers import DistilBertModel, DistilBertConfig
 class ImageEncoder(nn.Module):
     """
-    Encode images to a fixed size vector
+    Encode images using a pretrained model like resnet.
     """
     def __init__(

@@ -21,6 +21,9 @@ class ImageEncoder(nn.Module):
         return self.model(x)


 class TextEncoder(nn.Module):
+    """
+    Encode text using a pretrained language model such as DistilBert.
+    """
     def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
         super().__init__()
         if pretrained:

@@ -40,6 +43,9 @@ class TextEncoder(nn.Module):
         return last_hidden_state[:, self.target_token_idx, :]


 class ProjectionHead(nn.Module):
+    """
+    Convert the dimensions of both image and text embeddings to a fixed embedding size.
+    """
     def __init__(
         self,
         embedding_dim,
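
The encoders.py changes are documentation only; the module contract is unchanged. A sanity-check sketch of that contract — the 224 input resolution and the assumption that ProjectionHead defaults its output width to CFG.projection_dim do not appear in this diff and are assumptions here:

import torch
from config import CFG
from encoders import ImageEncoder, ProjectionHead

images = torch.randn(2, 3, 224, 224)        # 224x224 input resolution is assumed
features = ImageEncoder()(images)           # -> (2, CFG.image_embedding) = (2, 2048) for resnet50
embeddings = ProjectionHead(CFG.image_embedding)(features)
print(embeddings.shape)                     # -> torch.Size([2, 256]) if projection_dim defaults to CFG.projection_dim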
gui.py  (new file)

from tkinter import *
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from tkinter.ttk import *
from infer import find_matches
from data import gen_train_valid_dfs  # the commit imports make_train_valid_dfs, which this commit renamed in data.py
from infer import *
from config import CFG

_, valid_df = gen_train_valid_dfs()
model = get_image_embeddings(valid_df, CFG.model_path)
plot = None


def get_figure():
    """Run the query in the entry box and embed the resulting image grid in the window."""
    global plot
    print("getting figure for sentence: ", entry1.get())
    fig = find_matches(model, query=entry1.get(), image_filenames=valid_df['image'].values, n=9)
    print("Got figure for sentence: ", entry1.get())
    plot = FigureCanvasTkAgg(fig, root)
    plot.get_tk_widget().grid(row=2, column=0, rowspan=7, columnspan=7)


root = Tk()
width = root.winfo_screenwidth()
height = root.winfo_screenheight()
root.geometry("%dx%d" % (width, height))
root.title('CLIP Model Demo')

label1 = Label(root, text="Enter Text")
label1.grid(row=0, column=0)
entry1 = Entry(root, width=50)
entry1.grid(row=0, column=1, columnspan=5)
button1 = Button(root, text="Get Images", command=get_figure)
button1.grid(row=0, column=6)

root.mainloop()
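
gui.py leans on FigureCanvasTkAgg to place a matplotlib figure inside the Tk grid. The same embedding pattern in isolation, as a self-contained sketch with no CLIP dependencies:

import tkinter as tk
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

root = tk.Tk()
fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1, 2], [0, 1, 4])

canvas = FigureCanvasTkAgg(fig, root)     # bind the figure to the Tk window
canvas.draw()                             # render once before entering the event loop
canvas.get_tk_widget().grid(row=0, column=0)
root.mainloop()

Note that because gui.py star-imports tkinter.ttk after tkinter, its Label, Entry, and Button resolve to the themed ttk widgets.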
infer.py

@@ -7,27 +7,33 @@ from config import CFG
 import matplotlib.pyplot as plt
 import cv2
 import torch.nn.functional as F
+from data import *
+
+img_embeddings = None

-def get_image_embeddings(valid_df, model_path):
+def get_image_embeddings(valid_df, model_path, get_emb=False):
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
-    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+    valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
     model = CLIPModel().to(CFG.device)
     model.load_state_dict(torch.load(model_path, map_location=CFG.device))
     model.eval()

     valid_image_embeddings = []
-    with torch.no_grad():
-        for batch in tqdm(valid_loader):
-            image_features = model.image_encoder(batch["image"].to(CFG.device))
-            image_embeddings = model.image_projection(image_features)
-            valid_image_embeddings.append(image_embeddings)
-    return model, torch.cat(valid_image_embeddings)
+    if get_emb:
+        with torch.no_grad():
+            for batch in tqdm(valid_loader):
+                image_features = model.image_encoder(batch["image"].to(CFG.device))
+                image_embeddings = model.image_projection(image_features)
+                valid_image_embeddings.append(image_embeddings)
+        torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
+    return model

-_, valid_df = make_train_valid_dfs()
+def load_img_embeddings():
+    global img_embeddings
+    img_embeddings = torch.load(CFG.img_emb_path, map_location=CFG.device)
+    img_embeddings = F.normalize(img_embeddings, p=2, dim=-1)

-def find_matches(model, image_embeddings, query, image_filenames, n=9):
+def find_matches(model, query, image_filenames, n=9):
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
     encoded_query = tokenizer([query])
     batch = {

@@ -40,25 +46,24 @@ def find_matches(model, image_embeddings, query, image_filenames, n=9):
     )
     text_embeddings = model.text_projection(text_features)

-    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
+    image_embeddings_n = img_embeddings
     text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
     dot_similarity = text_embeddings_n @ image_embeddings_n.T

     values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
     matches = [image_filenames[idx] for idx in indices[::5]]

-    plt.clf()
-    _, axes = plt.subplots(3, 3, figsize=(10, 10))
+    plt.cla()
+    fig, axes = plt.subplots(3, 3, figsize=(10, 10))
     for match, ax in zip(matches, axes.flatten()):
         image = cv2.imread(f"{CFG.image_path}/{match}")
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         ax.imshow(image)
         ax.axis("off")
-    plt.show()
+    return fig

-model, image_embeddings = get_image_embeddings(valid_df, "./model/best.pt")
-find_matches(model, image_embeddings, query="boy is playing",
-             image_filenames=valid_df['image'].values, n=9)
+load_img_embeddings()
+##_, valid_df = gen_train_valid_dfs()
+#model = get_image_embeddings(valid_df, "./model/best.pt")
+#find_matches(model, query="boy is playing", image_filenames=valid_df['image'].values, n=9)
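
With the image matrix normalized once at load time, each query reduces to a single matrix product over unit vectors. The core retrieval math of find_matches in isolation, a self-contained sketch with toy tensors standing in for the cached embeddings:

import torch
import torch.nn.functional as F

n = 3
image_emb = F.normalize(torch.randn(100, 256), p=2, dim=-1)  # stands in for load_img_embeddings' cache
text_emb = F.normalize(torch.randn(1, 256), p=2, dim=-1)     # one encoded, projected query

similarity = text_emb @ image_emb.T                # cosine similarity: both sides are unit norm
values, indices = torch.topk(similarity.squeeze(0), n * 5)
picked = indices[::5]                              # stride 5: each Flickr30k image appears under five captions
print(picked.tolist())                             # roughly n distinct image rows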
train.py

@@ -6,80 +6,110 @@ from config import CFG
 import itertools
 from clip import CLIPModel
 from utils import AvgMeter, get_lr
-from dataset import CLIPDataset, get_transforms
+from data import *
 from transformers import DistilBertTokenizer
 import torch.nn.functional as F
 from utils import *
+import pickle
+
+# Reference for loading dataset and model architecture:
+# https://towardsdatascience.com/simple-implementation-of-openai-clip-model-a-tutorial-ace6ff01d9f2

-def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
-    loss_meter = AvgMeter()
-    tqdm_object = tqdm(train_loader, total=len(train_loader))
-    for batch in tqdm_object:
-        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
-        loss = model(batch)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-        if step == "batch":
-            lr_scheduler.step()
-        count = batch["image"].size(0)
-        loss_meter.update(loss.item(), count)
-        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
-    return loss_meter
+def train_epoch(model, train_data_loader, optimizer):
+    """
+    Run one training epoch and return the sample-weighted mean loss.
+    """
+    loss_sum = 0
+    loss_count = 0
+    for batch in tqdm(train_data_loader, total=len(train_data_loader)):
+        batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+        loss = model(batch_data)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        batch_size = batch_data["image"].size(0)
+        loss_count += max(1, batch_size)
+        loss_sum += loss.item() * max(1, batch_size)
+    return loss_sum / loss_count

-def valid_epoch(model, valid_loader):
-    loss_meter = AvgMeter()
-    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
-    for batch in tqdm_object:
-        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
-        loss = model(batch)
-        count = batch["image"].size(0)
-        loss_meter.update(loss.item(), count)
-        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
-    return loss_meter
+def valid_epoch(model, valid_data_loader):
+    """
+    Run one validation epoch and return the sample-weighted mean loss.
+    """
+    loss_sum = 0
+    loss_count = 0
+    for batch in tqdm(valid_data_loader, total=len(valid_data_loader)):
+        batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
+        loss = model(batch_data)
+        batch_size = batch_data["image"].size(0)
+        loss_count += max(1, batch_size)
+        loss_sum += loss.item() * max(1, batch_size)
+    return loss_sum / loss_count
+
+def store_image_embeddings(valid_df, model):
+    """
+    Generate and store embeddings of the validation images.
+    """
+    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
+    valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")
+    model.eval()
+    valid_image_embeddings = []
+    with torch.no_grad():
+        for batch in tqdm(valid_loader):
+            image_features = model.image_encoder(batch["image"].to(CFG.device))
+            image_embeddings = model.image_projection(image_features)
+            valid_image_embeddings.append(image_embeddings)
+    # The commit saves torch.cat(image_embeddings), i.e. only the last batch;
+    # concatenating the accumulated list is what the loop above intends.
+    torch.save(torch.cat(valid_image_embeddings), CFG.img_emb_path)
+    return

 def main():
-    train_df, valid_df = make_train_valid_dfs()
+    # Generate train and valid dataframes
+    train_data, valid_data = gen_train_valid_dfs()
+    # Use a pretrained tokenizer to tokenize the captions
     tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
-    train_loader = build_loaders(train_df, tokenizer, mode="train")
-    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
+    # Create data loaders for both train and validation data
+    train_loader = get_dataset_loader(train_data, tokenizer, mode="train")
+    valid_loader = get_dataset_loader(valid_data, tokenizer, mode="valid")

     model = CLIPModel().to(CFG.device)
     params = [
         {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
         {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
         {"params": itertools.chain(
             model.image_projection.parameters(), model.text_projection.parameters()
         ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
     ]
     optimizer = torch.optim.AdamW(params, weight_decay=0.)
     lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
         optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
     )
-    step = "epoch"
-    best_loss = float('inf')
+    min_loss = float('inf')
     for epoch in range(CFG.epochs):
         print(f"Epoch: {epoch + 1}")
         model.train()
-        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
+        train_loss = train_epoch(model, train_loader, optimizer)
         model.eval()
         with torch.no_grad():
             valid_loss = valid_epoch(model, valid_loader)

-        if valid_loss.avg < best_loss:
-            best_loss = valid_loss.avg
+        if valid_loss < min_loss:
+            min_loss = valid_loss
             torch.save(model.state_dict(), CFG.model_path)
             print("Saved Best Model!")

-        lr_scheduler.step(valid_loss.avg)
+        lr_scheduler.step(valid_loss)

 main()
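
The rewritten epoch functions return a plain sample-weighted mean instead of an AvgMeter object, which is what lets main() compare valid_loss directly against min_loss and pass it to ReduceLROnPlateau. The averaging in isolation, with toy numbers:

# (mean batch loss, batch size); the final batch is smaller than the rest.
batch_losses = [(0.9, 8), (0.7, 8), (0.5, 3)]

loss_sum = 0.0
loss_count = 0
for loss, batch_size in batch_losses:
    loss_count += max(1, batch_size)
    loss_sum += loss * max(1, batch_size)

# Weighted by samples: 14.3 / 19 ≈ 0.7526. A plain mean of batch means would give 0.7.
print(loss_sum / loss_count)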
utils.py

-import numpy as np
 import pandas as pd
-from dataset import *
+from data import *
 import torch.nn.functional as F
 from config import CFG

-def generate_context():
-    df = pd.read_csv(CFG.image_path + "/results.csv", delimiter="|")
+def add_imageid():
+    """
+    Modify the caption/image-name file, assigning an id to each image that has
+    more than one caption.
+    """
+    df = pd.read_csv(CFG.dataset_path + "/results.csv", delimiter="|")
     df.columns = ['image', 'caption_number', 'caption']
     df['caption'] = df['caption'].str.lstrip()
     df['caption_number'] = df['caption_number'].str.lstrip()

@@ -12,40 +16,12 @@ def generate_context():
     df.loc[19999, 'caption'] = "A dog runs across the grass ."
     ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
     df['id'] = ids
-    df.to_csv("captions.csv", index=False)
-    df.head()
+    df.to_csv(CFG.captions_path + "captions.csv", index=False)

-def make_train_valid_dfs():
-    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
-    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
-    image_ids = np.arange(0, max_id)
-    np.random.seed(42)
-    valid_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)
-    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
-    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
-    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
-    return train_dataframe, valid_dataframe
-
-def build_loaders(dataframe, tokenizer, mode):
-    transforms = get_transforms(mode=mode)
-    dataset = CLIPDataset(
-        dataframe["image"].values,
-        dataframe["caption"].values,
-        tokenizer=tokenizer,
-        transforms=transforms,
-    )
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=CFG.batch_size,
-        num_workers=CFG.num_workers,
-        shuffle=True if mode == "train" else False,
-    )
-    return dataloader

 class AvgMeter:
+    """
+    Helps in storing the running average of a loss across a batch/epoch.
+    """
     def __init__(self, name="Metric"):
         self.name = name
         self.reset()

@@ -63,5 +39,8 @@ class AvgMeter:
     return text

 def get_lr(optimizer):
+    """
+    Return the learning rate of the first parameter group in the optimizer.
+    """
     for param_group in optimizer.param_groups:
         return param_group["lr"]
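
add_imageid exploits Flickr30k's layout of five captions per image: each consecutive block of five rows gets one shared id, which gen_train_valid_dfs later uses to split by image rather than by caption. The comprehension in isolation:

n_rows = 15   # toy stand-in for len(df)
ids = [id_ for id_ in range(n_rows // 5) for i in range(5)]
print(ids)    # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]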