Saswat / CLIP · Commits

Commit fdb0f7ca
authored Nov 27, 2022 by Saswat
change data.py
parent c09e3ce4
Showing 1 changed file, data.py, with 35 additions and 32 deletions.
data.py (view file @ fdb0f7ca)
...
@@ -11,7 +11,7 @@ class CLIPDataset(torch.utils.data.Dataset):
     """
     Dataset class for getting image and text using a dataset loader.
     """
-    def __init__(self, image_filenames, captions, tokenizer, transforms):
+    def __init__(self, image_filenames, encoded_captions, data_len):
         """
         image_filenames and captions must have the same length; so, if there are
         multiple captions for each image, the image_filenames must have repetitive
...
@@ -19,11 +19,20 @@ class CLIPDataset(torch.utils.data.Dataset):
         """
         self.image_filenames = image_filenames
-        self.captions = list(captions)
-        self.encoded_captions = tokenizer(
-            list(captions), padding=True, truncation=True, max_length=CFG.max_length
-        )
+        self.encoded_captions = encoded_captions
+        self.datalen = data_len
+        self.imgtransforms = A.Compose(
+            [
+                A.Resize(CFG.size, CFG.size, always_apply=True),
+                A.Normalize(max_pixel_value=255.0, always_apply=True),
+            ]
+        )
-        self.transforms = transforms
+
+    def __getimage(self, idx):
+        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image = self.imgtransforms(image=image)['image']
+        return torch.tensor(image).permute(2, 0, 1).float()

     def __getitem__(self, idx):
         """
...
@@ -33,13 +42,7 @@ class CLIPDataset(torch.utils.data.Dataset):
             key: torch.tensor(values[idx])
             for key, values in self.encoded_captions.items()
         }
-        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        image = self.transforms(image=image)['image']
-        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
-        item['caption'] = self.captions[idx]
+        item['image'] = self.__getimage(idx)

         return item
...
@@ -47,34 +50,28 @@ class CLIPDataset(torch.utils.data.Dataset):
         """
         Returns size of our training data.
         """
-        return len(self.captions)
+        return self.datalen


-def image_transforms():
-    """
-    Implements image transformations.
-    """
-    return A.Compose(
-        [
-            A.Resize(CFG.size, CFG.size, always_apply=True),
-            A.Normalize(max_pixel_value=255.0, always_apply=True),
-        ]
-    )


 def gen_train_valid_dfs():
     """
     Split dataset into train and validation dataset.
     """
     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
     # Get max number of images
     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
     image_ids = np.arange(0, max_id)
     np.random.seed(CFG.seed)
-    valid_ids = np.random.choice(
+    valid_data_ids = np.random.choice(
         image_ids, size=int(0.2 * len(image_ids)), replace=False
     )
-    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
-    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
-    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+    # Generate train and validation data frames.
+    train_data_ids = [id_ for id_ in image_ids if id_ not in valid_data_ids]
+    train_dataframe = dataframe[dataframe["id"].isin(train_data_ids)].reset_index(drop=True)
+    valid_dataframe = dataframe[dataframe["id"].isin(valid_data_ids)].reset_index(drop=True)
     return train_dataframe, valid_dataframe
...
@@ -82,13 +79,19 @@ def get_dataset_loader(dataframe, tokenizer, mode):
     """
     Build a dataset loader using CLIPDataset.
     """
-    transforms = image_transforms()
+    # Generate encoded captions with the tokenizer provided
+    encoded_captions = tokenizer(
+        list(dataframe["caption"].values), padding=True, truncation=True, max_length=CFG.max_length
+    )
+    # Pass image filenames and encoded captions to the dataset
     dataset = CLIPDataset(
         dataframe["image"].values,
-        dataframe["caption"].values,
-        tokenizer=tokenizer,
-        transforms=transforms,
+        encoded_captions,
+        len(list(dataframe["caption"].values))
     )
+    # Create dataset loader
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=CFG.batch_size,
...
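For orientation, here is a minimal usage sketch of the pipeline as it looks after this commit. It is an illustration under stated assumptions, not code from the repository: the import path, the mode argument values, and the DistilBERT tokenizer choice are guesses; only gen_train_valid_dfs, get_dataset_loader, and the batch contents follow from the diff above.

# Hypothetical usage sketch; tokenizer choice and import path are assumptions,
# and CFG (paths, image size, batch size, seed) must already be configured.
from transformers import DistilBertTokenizer

from data import gen_train_valid_dfs, get_dataset_loader

# 80/20 split over image ids, seeded by CFG.seed (see gen_train_valid_dfs above)
train_df, valid_df = gen_train_valid_dfs()

# Any HuggingFace-style tokenizer works; get_dataset_loader now does the encoding
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

train_loader = get_dataset_loader(train_df, tokenizer, mode="train")
valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")

batch = next(iter(train_loader))
# batch["image"]          -> float tensor of shape (batch_size, 3, CFG.size, CFG.size)
# batch["input_ids"],
# batch["attention_mask"] -> captions tokenized and padded once in get_dataset_loader;
#                            the raw 'caption' string is no longer returned per item

The net effect visible in the sketch: tokenization runs once per dataframe in get_dataset_loader instead of inside CLIPDataset, the Albumentations pipeline is now owned by the dataset (imgtransforms), and __len__ reports the explicit data_len passed to the constructor.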