Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
C
CLIP
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Saswat
CLIP
Commits
c09e3ce4
Commit
c09e3ce4
authored
Nov 27, 2022
by
Saswat
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix train
parent
b32fb1a1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
47 deletions
+10
-47
config.py
config.py
+1
-1
data.py
data.py
+8
-16
train.py
train.py
+1
-3
utils.py
utils.py
+0
-27
No files found.
config.py
View file @
c09e3ce4
import
torch
import
torch
class
CFG
:
class
CFG
:
debug
=
Fals
e
debug
=
Tru
e
seed
=
42
seed
=
42
# Paths
# Paths
...
...
data.py
View file @
c09e3ce4
...
@@ -50,24 +50,16 @@ class CLIPDataset(torch.utils.data.Dataset):
...
@@ -50,24 +50,16 @@ class CLIPDataset(torch.utils.data.Dataset):
return
len
(
self
.
captions
)
return
len
(
self
.
captions
)
def image_transforms():
    """
    Implements image transformations.

    Returns a composed transform that resizes images to a fixed square of
    (CFG.size, CFG.size) and normalizes pixel values with
    max_pixel_value=255.0.

    NOTE(review): `A` is presumably the albumentations package imported at
    module level, and `CFG` the project config class — confirm against the
    file's imports.
    """
    return A.Compose(
        [
            # Fixed square resize; always_apply=True so it runs unconditionally.
            A.Resize(CFG.size, CFG.size, always_apply=True),
            # Scale raw [0, 255] pixel values before they reach the model.
            A.Normalize(max_pixel_value=255.0, always_apply=True),
        ]
    )
def
gen_train_valid_dfs
():
def
gen_train_valid_dfs
():
"""
"""
...
@@ -90,7 +82,7 @@ def get_dataset_loader(dataframe, tokenizer, mode):
...
@@ -90,7 +82,7 @@ def get_dataset_loader(dataframe, tokenizer, mode):
"""
"""
Build a dataset loader using CLIPDataset.
Build a dataset loader using CLIPDataset.
"""
"""
transforms
=
get_transforms
(
mode
=
mode
)
transforms
=
image_transforms
(
)
dataset
=
CLIPDataset
(
dataset
=
CLIPDataset
(
dataframe
[
"image"
]
.
values
,
dataframe
[
"image"
]
.
values
,
dataframe
[
"caption"
]
.
values
,
dataframe
[
"caption"
]
.
values
,
...
...
train.py
View file @
c09e3ce4
...
@@ -5,7 +5,6 @@ from tqdm import tqdm
...
@@ -5,7 +5,6 @@ from tqdm import tqdm
from
config
import
CFG
from
config
import
CFG
import
itertools
import
itertools
from
clip
import
CLIPModel
from
clip
import
CLIPModel
from
utils
import
AvgMeter
,
get_lr
from
data
import
*
from
data
import
*
from
transformers
import
DistilBertTokenizer
from
transformers
import
DistilBertTokenizer
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -45,7 +44,7 @@ def valid_epoch(model, valid_data_loader):
...
@@ -45,7 +44,7 @@ def valid_epoch(model, valid_data_loader):
#tqdm_object = tqdm(valid_loader, total=len(valid_loader))
#tqdm_object = tqdm(valid_loader, total=len(valid_loader))
for
batch
in
tqdm
(
valid_data_loader
,
total
=
len
(
valid_data_loader
)):
for
batch
in
tqdm
(
valid_data_loader
,
total
=
len
(
valid_data_loader
)):
batch_data
=
{
k
:
v
.
to
(
CFG
.
device
)
for
k
,
v
in
batch
.
items
()
if
k
!=
"caption"
}
batch_data
=
{
k
:
v
.
to
(
CFG
.
device
)
for
k
,
v
in
batch
.
items
()
if
k
!=
"caption"
}
loss
=
model
(
batch
)
loss
=
model
(
batch
_data
)
batch_size
=
batch_data
[
"image"
]
.
size
(
0
)
batch_size
=
batch_data
[
"image"
]
.
size
(
0
)
loss_count
+=
max
(
1
,
batch_size
)
loss_count
+=
max
(
1
,
batch_size
)
loss_sum
+=
loss
.
item
()
*
max
(
1
,
batch_size
)
loss_sum
+=
loss
.
item
()
*
max
(
1
,
batch_size
)
...
@@ -104,7 +103,6 @@ def main():
...
@@ -104,7 +103,6 @@ def main():
model
.
eval
()
model
.
eval
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
valid_loss
=
valid_epoch
(
model
,
valid_loader
)
valid_loss
=
valid_epoch
(
model
,
valid_loader
)
if
valid_loss
<
min_loss
:
if
valid_loss
<
min_loss
:
min_loss
=
valid_loss
min_loss
=
valid_loss
torch
.
save
(
model
.
state_dict
(),
CFG
.
model_path
)
torch
.
save
(
model
.
state_dict
(),
CFG
.
model_path
)
...
...
utils.py
View file @
c09e3ce4
...
@@ -17,30 +17,3 @@ def add_imageid():
...
@@ -17,30 +17,3 @@ def add_imageid():
ids
=
[
id_
for
id_
in
range
(
len
(
df
)
//
5
)
for
i
in
range
(
5
)]
ids
=
[
id_
for
id_
in
range
(
len
(
df
)
//
5
)
for
i
in
range
(
5
)]
df
[
'id'
]
=
ids
df
[
'id'
]
=
ids
df
.
to_csv
(
CFG
.
captions_path
+
"captions.csv"
,
index
=
False
)
df
.
to_csv
(
CFG
.
captions_path
+
"captions.csv"
,
index
=
False
)
class AvgMeter:
    """
    Helps in storing average of loss across a batch/epoch.

    Accumulates a running sum and count via :meth:`update` and exposes
    the current mean as ``self.avg``.
    """

    def __init__(self, name="Metric"):
        # Human-readable label used by __repr__.
        self.name = name
        self.reset()

    def reset(self):
        """Zero out the running average, sum, and count."""
        self.avg = self.sum = self.count = 0

    def update(self, val, count=1):
        """Fold in `val` observed `count` times and refresh the mean."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Any object exposing a ``param_groups`` list of dicts with an
        ``"lr"`` key.

    Returns
    -------
    float or None
        The ``"lr"`` value of the first parameter group, or ``None`` if
        ``param_groups`` is empty.
    """
    # Early return inside the loop: only the first group's lr is reported.
    for param_group in optimizer.param_groups:
        return param_group["lr"]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment