from keras.applications.vgg19 import VGG19
from keras.applications.vgg19 import preprocess_input
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.models import Model
from matplotlib import pyplot as plt
import os
import numpy as np
from matplotlib import pyplot as plt
import cv2 as cv

def vgg_visualization(inputImage, rootDir):
    """Load an image file, run it through VGG19 and write conv-layer visualizations.

    Args:
        inputImage: path to the image file; it is resized to 224x224 on load.
        rootDir: directory that receives the visualization output
            (passed straight through to ``generate_visualizations``).
    """
    network = VGG19()
    pixels = img_to_array(load_img(inputImage, target_size=(224, 224)))
    # Add a leading batch axis of size 1: (H, W, C) -> (1, H, W, C).
    batch = preprocess_input(pixels.reshape((1,) + pixels.shape))
    generate_visualizations(network, batch, rootDir)

def _save_image(fig, ax, path, data, **imshow_kwargs):
    """Draw one 2-D array on the shared axes, save the figure, then remove the image.

    Removing the image after saving keeps artists from piling up on the axes
    across hundreds of saves.
    """
    img = ax.imshow(data, **imshow_kwargs)
    try:
        fig.savefig(path, transparent=True)
    finally:
        img.remove()


def _save_activation_maps(model, layer, input_image, layer_dir, fig, ax, max_maps):
    """Save the first ``max_maps`` output channels of ``layer`` for batch item 0."""
    # Truncate the network at this conv layer so predict() returns its activations.
    extractor = Model(inputs=model.inputs, outputs=layer.output)
    activations = extractor.predict(input_image)
    for index in range(min(activations.shape[-1], max_maps)):
        _save_image(
            fig, ax,
            os.path.join(layer_dir, '{}.jpg'.format(index + 1)),
            activations[0, :, :, index],
        )


def _save_filter_weights(layer, root_dir, layer_dir, fig, ax, max_filters, max_channels):
    """Record the kernel shape in meta.txt and save normalized per-channel filter images."""
    # Index 0 is the kernel; conv layers built with use_bias=False have no bias entry.
    weights = layer.get_weights()[0]
    # Presumably the Keras Conv2D layout (kh, kw, in_channels, filters) -- verify
    # for other layer types.
    # NOTE(review): append mode means re-running into the same root_dir duplicates entries.
    with open(os.path.join(root_dir, 'meta.txt'), 'a') as meta:
        meta.write('{}\n{}\n{}\n'.format(weights.shape[0], weights.shape[2], weights.shape[3]))

    # Min-max normalise over the whole kernel so every image shares one gray scale.
    min_val, max_val = weights.min(), weights.max()
    weights = (weights - min_val) / (max_val - min_val + 1e-8)

    for index in range(min(weights.shape[-1], max_filters)):
        for channel in range(min(weights.shape[2], max_channels)):
            final_image = cv.resize(
                weights[:, :, channel, index], (100, 100), interpolation=cv.INTER_AREA
            )
            _save_image(
                fig, ax,
                os.path.join(layer_dir, '{}-{}.jpg'.format(index + 1, channel + 1)),
                final_image,
                cmap='gray',
            )


def generate_visualizations(model, input_image, root_dir, max_maps=9, max_channels=5):
    """Save activation maps and filter-weight images for every conv layer of ``model``.

    Each layer whose name contains ``'conv'`` is numbered 1, 2, ... in model
    order. For each such layer this function creates ``root_dir/conv<k>/`` and writes:

    * ``<i>.jpg``: channel ``i`` of that layer's output for ``input_image``
      (batch item 0), for the first ``max_maps`` channels.
    * ``<i>-<c>.jpg``: input channel ``c`` of filter ``i`` of the layer's kernel,
      min-max normalised over the whole kernel and resized to 100x100, for the
      first ``max_maps`` filters and the first ``max_channels`` input channels.

    It also appends three lines per conv layer to ``root_dir/meta.txt``:
    ``kernel.shape[0]``, ``kernel.shape[2]`` and ``kernel.shape[3]``.

    Args:
        model: Keras model whose layers are inspected.
        input_image: batch accepted by ``model.predict``; only item 0 is drawn.
        root_dir: output directory, created (with parents) if missing.
        max_maps: maximum number of activation maps, and of filters, saved per
            layer. The default of 9 matches the historical output.
        max_channels: maximum number of kernel input channels saved per filter.
            The default of 5 matches the historical output.
    """
    os.makedirs(root_dir, exist_ok=True)
    # One figure is reused for every image and always closed, even on error.
    fig, ax = plt.subplots()
    ax.axis('off')
    try:
        conv_number = 0
        for layer in model.layers:
            if 'conv' not in layer.name:
                continue
            conv_number += 1
            layer_dir = os.path.join(root_dir, 'conv{}'.format(conv_number))
            os.makedirs(layer_dir, exist_ok=True)
            _save_activation_maps(model, layer, input_image, layer_dir, fig, ax, max_maps)
            _save_filter_weights(layer, root_dir, layer_dir, fig, ax, max_maps, max_channels)
    finally:
        plt.close(fig)
# Example usage: vgg_visualization('bird.jpg', 'vggImages')
