Commit fdb0f7ca authored by Saswat

change data.py

parent c09e3ce4
@@ -11,7 +11,7 @@ class CLIPDataset(torch.utils.data.Dataset):
     """
     Dataset class for getting image and text using a dataset loader.
     """
-    def __init__(self, image_filenames, captions, tokenizer, transforms):
+    def __init__(self, image_filenames, encoded_captions, data_len):
         """
         image_filenames and captions must have the same length; so, if there are
         multiple captions for each image, the image_filenames must have repetitive
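The second constructor argument is now the tokenizer output rather than raw caption strings. A minimal sketch of what that object looks like, assuming a HuggingFace-style tokenizer (the commit only shows the call signature; the model name and the max_length literal below are illustrative stand-ins for the project's CFG values):

from transformers import AutoTokenizer

# Illustrative tokenizer choice; the project's actual model lives in its config.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoded = tokenizer(
    ["a dog on the beach", "two bikes parked outside"],
    padding=True, truncation=True, max_length=200,  # 200 stands in for CFG.max_length
)
print(encoded.keys())              # dict_keys(['input_ids', 'attention_mask'])
print(len(encoded["input_ids"]))   # 2: one padded id list per caption

__getitem__ indexes these parallel lists per sample, which is why the dataset no longer needs the tokenizer itself.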
@@ -19,11 +19,20 @@ class CLIPDataset(torch.utils.data.Dataset):
         """
         self.image_filenames = image_filenames
-        self.captions = list(captions)
-        self.encoded_captions = tokenizer(
-            list(captions), padding=True, truncation=True, max_length=CFG.max_length
-        )
-        self.transforms = transforms
+        self.encoded_captions = encoded_captions
+        self.datalen = data_len
+        self.imgtransforms = A.Compose(
+            [
+                A.Resize(CFG.size, CFG.size, always_apply=True),
+                A.Normalize(max_pixel_value=255.0, always_apply=True),
+            ]
+        )
+
+    def __getimage(self, idx):
+        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        image = self.imgtransforms(image=image)['image']
+        return torch.tensor(image).permute(2, 0, 1).float()
 
     def __getitem__(self, idx):
         """
@@ -33,13 +42,7 @@ class CLIPDataset(torch.utils.data.Dataset):
             key: torch.tensor(values[idx])
             for key, values in self.encoded_captions.items()
         }
-
-        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        image = self.transforms(image=image)['image']
-        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
-        item['caption'] = self.captions[idx]
+        item['image'] = self.__getimage(idx)
 
         return item
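With the raw 'caption' string dropped from the item, every field __getitem__ returns is a tensor, so the default DataLoader collate can batch them without custom logic. Roughly what one item now contains (shapes sketched from the code above):

item = dataset[0]  # dataset: a CLIPDataset built as in the constructor sketch
# item["input_ids"]      -> tensor of token ids, shape (padded_caption_len,)
# item["attention_mask"] -> tensor of 0/1 mask, same shape
# item["image"]          -> float tensor, shape (3, CFG.size, CFG.size)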
@@ -47,34 +50,28 @@ class CLIPDataset(torch.utils.data.Dataset):
         """
         Returns size of our training data.
         """
-        return len(self.captions)
+        return self.datalen
 
 
-def image_transforms():
-    """
-    Implements image transformations.
-    """
-    return A.Compose(
-        [
-            A.Resize(CFG.size, CFG.size, always_apply=True),
-            A.Normalize(max_pixel_value=255.0, always_apply=True),
-        ]
-    )
-
-
 def gen_train_valid_dfs():
     """
     Split dataset into train and validation dataset.
     """
     dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
+    # Get max number of images
     max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
     image_ids = np.arange(0, max_id)
     np.random.seed(CFG.seed)
-    valid_ids = np.random.choice(
+    valid_data_ids = np.random.choice(
         image_ids, size=int(0.2 * len(image_ids)), replace=False
     )
-    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
-    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
-    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
+
+    # Generate train and validation data frames.
+    train_data_ids = [id_ for id_ in image_ids if id_ not in valid_data_ids]
+    train_dataframe = dataframe[dataframe["id"].isin(train_data_ids)].reset_index(drop=True)
+    valid_dataframe = dataframe[dataframe["id"].isin(valid_data_ids)].reset_index(drop=True)
     return train_dataframe, valid_dataframe
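A toy run of the renamed split logic, to make the 80/20 behavior concrete. Note that "id_ not in valid_data_ids" does a linear scan of a NumPy array per id; np.isin would be a vectorized equivalent, though that is not what this commit does:

import numpy as np

image_ids = np.arange(0, 10)
np.random.seed(42)
valid_data_ids = np.random.choice(image_ids, size=int(0.2 * len(image_ids)), replace=False)
train_data_ids = [id_ for id_ in image_ids if id_ not in valid_data_ids]
print(len(valid_data_ids), len(train_data_ids))  # 2 8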
@@ -82,13 +79,19 @@ def get_dataset_loader(dataframe, tokenizer, mode):
     """
     Build a dataset loader using CLIPDataset.
     """
-    transforms = image_transforms()
+    # Generate encoded captions with the tokenizer provided
+    encoded_captions = tokenizer(
+        list(dataframe["caption"].values), padding=True, truncation=True, max_length=CFG.max_length
+    )
+
+    # Pass images and encoded captions to the dataset
     dataset = CLIPDataset(
         dataframe["image"].values,
-        dataframe["caption"].values,
-        tokenizer=tokenizer,
-        transforms=transforms,
+        encoded_captions,
+        len(list(dataframe["caption"].values))
     )
+
+    # Create the dataset loader
     dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_size=CFG.batch_size,
...
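Putting the refactored pieces together, end to end. The tokenizer name and the mode string below are illustrative assumptions; CFG supplies the real batch size, image size, and paths:

from transformers import AutoTokenizer

train_df, valid_df = gen_train_valid_dfs()
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # illustrative
train_loader = get_dataset_loader(train_df, tokenizer, mode="train")
batch = next(iter(train_loader))
print(batch["image"].shape)  # (CFG.batch_size, 3, CFG.size, CFG.size)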