Commit c09e3ce4 authored by Saswat

fix train: pass the device-moved batch_data (not the raw batch) to the model in valid_epoch; also collapse the duplicated train/valid transforms into a single image_transforms(), drop the now-unused AvgMeter and get_lr helpers, and switch CFG.debug on

parent b32fb1a1
config.py

 import torch
 class CFG:
-    debug = False
+    debug = True
     seed = 42
     # Paths
data.py

@@ -50,24 +50,16 @@ class CLIPDataset(torch.utils.data.Dataset):
         return len(self.captions)
 
-def get_transforms(mode="train"):
+def image_transforms():
     """
     Implements image transformations.
     """
-    if mode == "train":
-        return A.Compose(
-            [
-                A.Resize(CFG.size, CFG.size, always_apply=True),
-                A.Normalize(max_pixel_value=255.0, always_apply=True),
-            ]
-        )
-    else:
-        return A.Compose(
-            [
-                A.Resize(CFG.size, CFG.size, always_apply=True),
-                A.Normalize(max_pixel_value=255.0, always_apply=True),
-            ]
-        )
+    return A.Compose(
+        [
+            A.Resize(CFG.size, CFG.size, always_apply=True),
+            A.Normalize(max_pixel_value=255.0, always_apply=True),
+        ]
+    )
 
 
 def gen_train_valid_dfs():
     """
@@ -90,7 +82,7 @@ def get_dataset_loader(dataframe, tokenizer, mode):
     """
     Build a dataset loader using CLIPDataset.
     """
-    transforms = get_transforms(mode=mode)
+    transforms = image_transforms()
     dataset = CLIPDataset(
         dataframe["image"].values,
         dataframe["caption"].values,
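Aside on the image_transforms() change above: the train and valid branches of the old get_transforms were byte-for-byte identical, so collapsing them loses nothing. As a quick reference, a minimal sketch of how the returned albumentations pipeline would be applied to one image; the cv2 loading, file name, and tensor conversion are illustrative assumptions, not code from this repo:

    import cv2
    import torch

    transforms = image_transforms()
    image = cv2.imread("example.jpg")                     # illustrative path (BGR, HWC, uint8)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)        # albumentations expects RGB
    image = transforms(image=image)["image"]              # resized to CFG.size, normalized to float
    image = torch.tensor(image).permute(2, 0, 1).float()  # HWC -> CHW for the model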
train.py

@@ -5,7 +5,6 @@ from tqdm import tqdm
 from config import CFG
 import itertools
 from clip import CLIPModel
-from utils import AvgMeter, get_lr
 from data import *
 from transformers import DistilBertTokenizer
 import torch.nn.functional as F
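The dropped import matches the deletion of AvgMeter and get_lr from utils.py (last hunk below). If anything in train.py still needs the current learning rate, a one-line equivalent of the deleted helper, assuming a standard torch.optim optimizer:

    current_lr = optimizer.param_groups[0]["lr"]  # same value the removed get_lr(optimizer) returned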
@@ -45,7 +44,7 @@ def valid_epoch(model, valid_data_loader):
     #tqdm_object = tqdm(valid_loader, total=len(valid_loader))
     for batch in tqdm(valid_data_loader, total=len(valid_data_loader)):
         batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
-        loss = model(batch)
+        loss = model(batch_data)
         batch_size = batch_data["image"].size(0)
         loss_count += max(1, batch_size)
         loss_sum += loss.item() * max(1, batch_size)
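This one-line change is the actual training fix: model() was being called with the raw batch, whose tensors still live on the CPU and which still carries the caption strings, instead of the device-moved batch_data. For context, a sketch of the corrected loop in full; the initialization and the return line are assumptions inferred from the visible bookkeeping, not shown in the hunk:

    def valid_epoch(model, valid_data_loader):
        loss_sum, loss_count = 0.0, 0  # assumed initialization
        for batch in tqdm(valid_data_loader, total=len(valid_data_loader)):
            # move tensors to the target device; drop the raw caption strings
            batch_data = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
            loss = model(batch_data)  # fixed: was model(batch)
            batch_size = batch_data["image"].size(0)
            loss_count += max(1, batch_size)
            loss_sum += loss.item() * max(1, batch_size)
        return loss_sum / loss_count  # assumed: size-weighted mean validation loss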
@@ -104,7 +103,6 @@ def main():
         model.eval()
         with torch.no_grad():
             valid_loss = valid_epoch(model, valid_loader)
-
         if valid_loss < min_loss:
             min_loss = valid_loss
             torch.save(model.state_dict(), CFG.model_path)
utils.py

@@ -17,30 +17,3 @@ def add_imageid():
     ids = [id_ for id_ in range(len(df) // 5) for i in range(5)]
     df['id'] = ids
     df.to_csv(CFG.captions_path + "captions.csv", index=False)
-
-class AvgMeter:
-    """
-    Stores a running average of the loss across a batch/epoch.
-    """
-    def __init__(self, name="Metric"):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.avg, self.sum, self.count = [0] * 3
-
-    def update(self, val, count=1):
-        self.count += count
-        self.sum += val * count
-        self.avg = self.sum / self.count
-
-    def __repr__(self):
-        text = f"{self.name}: {self.avg:.4f}"
-        return text
-
-def get_lr(optimizer):
-    """
-    Return the learning rate of the optimizer's first parameter group.
-    """
-    for param_group in optimizer.param_groups:
-        return param_group["lr"]
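The deleted AvgMeter tracked a count-weighted running average, which is exactly what the loss_sum/loss_count bookkeeping now inlined in valid_epoch computes. A small demonstration of the equivalence, with made-up numbers:

    batches = [(0.52, 32), (0.47, 32), (0.61, 16)]  # (mean batch loss, batch size), illustrative
    loss_sum = sum(v * c for v, c in batches)
    loss_count = sum(c for _, c in batches)
    avg = loss_sum / loss_count                     # equals AvgMeter.avg after three update(v, c) calls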