import torch
from logger import Logger
# Module-level logger obtained from the project's ``logger`` helper; used by
# FeatureMaps to report load/lookup failures.
LOGGER = Logger().logger()
# NOTE(review): this runs at import time, not only when executed as a script.
LOGGER.info("Started Feature Maps")
class FeatureMaps:
    '''
    Load a pretrained VGG model through torch.hub and expose the weight
    tensors of selected entries of its ``features`` container.

    Loading is best-effort: failures are logged and ``self.model`` is left
    as ``None`` instead of being undefined.
    '''

    def __init__(self, arch="vgg19"):
        '''
        Init function
        @params
        arch: str {vgg11,vgg13,vgg16,vgg19,vgg19_bn}
            Entrypoint name passed to ``torch.hub.load`` on the
            ``pytorch/vision:v0.10.0`` repository.
        '''
        # Bind the attribute up front so it always exists; previously a
        # failed load left ``self.model`` undefined and later access raised
        # AttributeError far from the real cause.
        self.model = None
        try:
            self.model = torch.hub.load('pytorch/vision:v0.10.0', arch, pretrained=True)
        except Exception as err:
            # Deliberately best-effort, but unlike a bare ``except`` this does
            # not swallow KeyboardInterrupt/SystemExit, and it logs the cause.
            LOGGER.error("Could not load model: " + str(err))

    def get_model(self):
        '''
        Return the loaded model, or None if loading failed in __init__.
        '''
        return self.model

    def get_layers(self, layers=None):
        '''
        Function to extract layers
        @params
        layers: iterable of int
            Indices into ``self.model.features``. Defaults to no layers.
        @returns
        list of the ``weight`` parameters of the requested layers, in the
        order given. Indices that cannot be fetched (out of range, a layer
        without ``weight`` such as an activation or pooling layer, or a
        non-integer index) are logged and skipped.
        '''
        # None sentinel instead of a shared mutable ``[]`` default.
        if layers is None:
            layers = []
        weights = []
        for layer in layers:
            try:
                weights.append(self.model.features[layer].weight)
            except (IndexError, AttributeError, TypeError) as err:
                # AttributeError also covers self.model being None after a
                # failed load.
                LOGGER.error("Could not fetch layer " + str(layer) + ": " + str(err))
        return weights

if __name__ == "__main__":
    # Demo: load the default VGG19, show its feature extractor, then pull
    # the weights of a few feature layers and report their shapes.
    feature_maps = FeatureMaps()
    vgg = feature_maps.get_model()
    print(vgg.features)
    # NOTE(review): in torchvision's VGG19 layout, indices 4 and 6 appear to
    # be pooling/activation layers without weights, so only index 2 may
    # yield a tensor here — confirm the intended indices.
    extracted = feature_maps.get_layers([4, 2, 6])
    print(len(extracted))
    for w in extracted:
        print(type(w), w.shape)
