Commit d5a9e675 authored by Meet Narendra's avatar Meet Narendra 💬

Feature map extractor

parent 107e376e
*pycache*
*.pdf
*.csv
*.ipynb
*Logs*
*.log
\ No newline at end of file
import torch
from logger import Logger
LOGGER = Logger().logger()
LOGGER.info("Started Feature Maps")
class FeatureMaps:
def __init__(self,arch="vgg19"):
'''
Init function
@params
arch: str {vgg11,vgg13,vgg16,vgg19,vgg19bn}
'''
try:
self.model = torch.hub.load('pytorch/vision:v0.10.0',arch,pretrained=True)
except:
LOGGER.error("Could not load model")
return
def get_model(self):
return self.model
def get_layers(self,layers=[]):
'''
Function to extract layers
@params
layers: list
'''
weights = []
for layer in layers:
try:
weights.append(self.model.features[layer].weight)
except:
LOGGER.error("Could not fetch layer "+str(layer))
return weights
if __name__ == "__main__":
fmap = FeatureMaps()
model = fmap.get_model()
print(model.features)
weights = fmap.get_layers([4,2,6])
print(len(weights))
for weight in weights:
print(type(weight),weight.shape)
from distutils.log import Log
import logging
import os
class Logger:
'''
Singleton logger class
'''
_instance = None
_logHandler = None
_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
def __new__(cls,*args,**kwargs):
if not cls._instance:
os.system("rm -rf Logs/")
os.mkdir("Logs/")
logHandler = logging.FileHandler("Logs/style_transfer.log")
logHandler.setFormatter(cls._formatter)
cls._logHandler = logging.getLogger("Logs/style_transfer.log")
cls._logHandler.setLevel(logging.INFO)
cls._logHandler.addHandler(logHandler)
cls._instance = super(Logger, cls).__new__(cls,*args,**kwargs)
return cls._instance
def logger(self):
return self._logHandler
'''
#Demo use
if __name__ == "__main__":
a = Logger()
b = Logger()
print(a is b)
INFO = a.logger()
ERROR = b.logger()
INFO.info("TEST")
ERROR.info("ERROR")
'''
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment