import torch

class CFG:
    debug = False
    seed = 42

    # Paths
    dataset_path = "/home/saswat/Desktop/flickr30k_images/"
    image_path = dataset_path+"flickr30k_images/"
    captions_path = "."
    model_path = "./model/best.pt"
    img_emb_path = "./model/image_embeddings.pt"

    # Training Params
    epochs = 2
    batch_size = 8
    num_workers = 4
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pretrained model (Text Embeddings/ Image Embeddings) params
    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    pretrained = True # for both image encoder and text encoder
    trainable = True # for both image encoder and text encoder
    temperature = 1.0

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1