Commit fdb0f7ca authored by Saswat

data.py: pass pre-encoded captions into CLIPDataset and apply image transforms inside the dataset

parent c09e3ce4
@@ -11,7 +11,7 @@ class CLIPDataset(torch.utils.data.Dataset):
"""
Dataset class for getting image and text using a dataset loader.
"""
def __init__(self, image_filenames, captions, tokenizer, transforms):
def __init__(self, image_filenames, encoded_captions, data_len):
"""
image_filenames and captions must have the same length; so, if there are
multiple captions for each image, the image_filenames must have repetitive
@@ -19,11 +19,20 @@ class CLIPDataset(torch.utils.data.Dataset):
"""
self.image_filenames = image_filenames
self.captions = list(captions)
self.encoded_captions = tokenizer(
list(captions), padding=True, truncation=True, max_length=CFG.max_length
self.encoded_captions = encoded_captions
self.datalen = data_len
self.imgtransforms = A.Compose(
[
A.Resize(CFG.size, CFG.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)
self.transforms = transforms
def __getimage(self, idx):
image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = self.imgtransforms(image=image)['image']
return torch.tensor(image).permute(2, 0, 1).float()
def __getitem__(self, idx):
"""
@@ -33,13 +42,7 @@
key: torch.tensor(values[idx])
for key, values in self.encoded_captions.items()
}
image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = self.transforms(image=image)['image']
item['image'] = torch.tensor(image).permute(2, 0, 1).float()
item['caption'] = self.captions[idx]
item['image'] = self.__getimage(idx)
return item
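Reviewer note: a minimal usage sketch of the refactored class, assuming a Hugging Face tokenizer such as DistilBertTokenizer (the commit itself only requires that the caller pass in already-encoded captions). The file names, captions and max_length value below are placeholders, and CLIPDataset is assumed to be in scope from this module.

```python
from transformers import DistilBertTokenizer  # assumed tokenizer; not specified by this commit

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
image_filenames = ["1.jpg", "2.jpg"]                  # placeholder file names under CFG.image_path
captions = ["a dog on the grass", "a cat on a sofa"]  # placeholder captions

# Tokenization now happens outside the dataset; the encoded dict and the
# dataset length are passed to the constructor.
encoded = tokenizer(captions, padding=True, truncation=True, max_length=200)  # 200 stands in for CFG.max_length
dataset = CLIPDataset(image_filenames, encoded, len(captions))

item = dataset[0]
# item["input_ids"] / item["attention_mask"] come from encoded_captions;
# item["image"] is a (3, H, W) float tensor (H = W = CFG.size) built by __getimage.
```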
@@ -47,34 +50,28 @@ class CLIPDataset(torch.utils.data.Dataset):
"""
Returns size of our training data.
"""
return len(self.captions)
return self.datalen
def image_transforms():
"""
Implements image transformations.
"""
return A.Compose(
[
A.Resize(CFG.size, CFG.size, always_apply=True),
A.Normalize(max_pixel_value=255.0, always_apply=True),
]
)
def gen_train_valid_dfs():
"""
Split the dataset into train and validation dataframes.
"""
dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
# Number of image ids (limited to 100 in debug mode)
max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
image_ids = np.arange(0, max_id)
np.random.seed(CFG.seed)
valid_ids = np.random.choice(
valid_data_ids = np.random.choice(
image_ids, size=int(0.2 * len(image_ids)), replace=False
)
train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
# Generate train and validation data frames.
train_data_ids = [id_ for id_ in image_ids if id_ not in valid_data_ids]
train_dataframe = dataframe[dataframe["id"].isin(train_data_ids)].reset_index(drop=True)
valid_dataframe = dataframe[dataframe["id"].isin(valid_data_ids)].reset_index(drop=True)
return train_dataframe, valid_dataframe
@@ -82,13 +79,19 @@ def get_dataset_loader(dataframe, tokenizer, mode):
"""
Build a dataset loader using CLIPDataset.
"""
transforms = image_transforms()
# Encode all captions up front with the provided tokenizer
encoded_captions = tokenizer(
list(dataframe["caption"].values), padding=True, truncation=True, max_length=CFG.max_length
)
# Pass image filenames, encoded captions and the dataset length to CLIPDataset
dataset = CLIPDataset(
dataframe["image"].values,
dataframe["caption"].values,
tokenizer=tokenizer,
transforms=transforms,
encoded_captions,
len(list(dataframe["caption"].values))
)
# Create dataset loader
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=CFG.batch_size,
......
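Reviewer note: roughly how the refactored pieces would be wired together end to end. The tokenizer class and checkpoint name are assumptions; the function names and the mode argument come from this diff.

```python
from transformers import DistilBertTokenizer  # assumed; the diff only requires "a tokenizer"

train_df, valid_df = gen_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

train_loader = get_dataset_loader(train_df, tokenizer, mode="train")
valid_loader = get_dataset_loader(valid_df, tokenizer, mode="valid")

batch = next(iter(train_loader))
# batch["input_ids"], batch["attention_mask"]: stacked tokenizer outputs
# batch["image"]: float tensor of shape (batch_size, 3, CFG.size, CFG.size)
```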