from os import device_encoding
from logger import Logger
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# Module-wide logger obtained from the project's Logger helper.
LOGGER = Logger().logger()
# Device that preprocessed tensors are moved to: CUDA when torch reports an
# available GPU, the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#Author: @meetdoshi
class Preprocessor:
    '''
    Static helpers that turn an image file into a batched float tensor on ``device``.
    '''

    # Per-channel ImageNet means in B, G, R order on the 0-255 scale, shaped
    # (1, 1, 3) so they broadcast over an H x W x 3 channel-last array.
    IMAGENET_BGR_MEAN = np.reshape([103.939, 116.779, 123.68], (1, 1, 3))

    @staticmethod
    def load_image(path, mode="RGB"):
        '''
        Function to load an image from disk
        @params
        path: os.path (or anything PIL.Image.open accepts)
        mode: PIL mode to convert to. The default "RGB" makes grayscale,
              palette and RGBA files all yield exactly 3 channels; pass None
              to get the lazily-opened image back unchanged
        @returns
        PIL.Image.Image
        '''
        if mode is None:
            return Image.open(path)
        # convert() loads the pixel data into a new image, so the source file
        # can be closed immediately instead of staying open until first access.
        with Image.open(path) as img:
            return img.convert(mode)

    @staticmethod
    def subtract_mean(img):
        '''
        Function to subtract the ImageNet per-channel mean values
        @params
        img: H x W x 3 numpy array, channels in B, G, R order, 0-255 scale
        @returns
        new numpy array (img - mean); the input is not modified
        '''
        return img - Preprocessor.IMAGENET_BGR_MEAN

    @staticmethod
    def reshape_img(img, size=(224, 224)):
        '''
        Function to resize an image and turn it into a batched tensor
        @params
        img: PIL image or H x W x C numpy array (anything transforms.ToTensor accepts)
        size: (height, width) of the output, 224x224 by default
        @returns
        float32 tensor of shape (1, C, size[0], size[1]) on ``device``
        '''
        # NOTE: no mean/std normalization is applied here; for 8-bit inputs
        # ToTensor yields values in [0, 1].
        loader = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size),
        ])
        batch = loader(img).unsqueeze(0)  # C x H x W -> 1 x C x H x W
        return batch.to(device, torch.float)

    @staticmethod
    def process(path, size=(224, 224)):
        '''
        Function to preprocess the image: load it from disk and return a
        (1, 3, size[0], size[1]) float tensor on ``device``
        @params
        path: os.path
        size: (height, width) of the output, 224x224 by default
        '''
        img = Preprocessor.load_image(path)
        # NOTE: subtract_mean is not applied; it expects a channel-last
        # 0-255 BGR array, whereas reshape_img returns a channel-first tensor.
        return Preprocessor.reshape_img(img, size)

# Disabled manual smoke test, kept inside a bare string literal so it never runs.
# If re-enabled: the np.zeros assignment is immediately overwritten by
# process(), and 'test/sem8.jpeg' is resolved relative to the working directory.
'''
if __name__=="__main__":
    prec = Preprocessor()
    img = np.zeros(shape=(4,4,3))
    img = prec.process('test/sem8.jpeg')
'''
