import os
import re
import shutil

from django.contrib import messages
from django.shortcuts import render
from django.urls import reverse
from django.views import View
from django.views.generic.base import TemplateView
from django_drf_filepond.models import TemporaryUpload
#from .vggModel import vgg_visualization
# Create your views here.

class homeView(TemplateView):
    """Landing page: renders ``home.html`` with the default TemplateView context."""

    template_name = 'home.html'

class mnistVisualView(TemplateView):
    """Render ``mnist.html`` with the image data produced by ``load_images('mnist')``."""

    template_name = 'mnist.html'

    def get_context_data(self, **kwargs):
        # Layer the model's image metadata over the default template context;
        # on a key collision the value from load_images() wins.
        base = super().get_context_data(**kwargs)
        return {**base, **load_images('mnist')}

class vggVisualView(TemplateView):
    """Render ``vgg.html`` with the image data produced by ``load_images('vgg19')``."""

    template_name = 'vgg.html'

    def get_context_data(self, **kwargs):
        # Layer the model's image metadata over the default template context;
        # on a key collision the value from load_images() wins.
        base = super().get_context_data(**kwargs)
        return {**base, **load_images('vgg19')}

class flowerVisualView(TemplateView):
    """Render ``flower.html`` with the image data produced by ``load_images('flower')``."""

    template_name = 'flower.html'

    def get_context_data(self, **kwargs):
        # Layer the model's image metadata over the default template context;
        # on a key collision the value from load_images() wins.
        base = super().get_context_data(**kwargs)
        return {**base, **load_images('flower')}

class catDogVisualView(TemplateView):
    """Render ``catDog.html`` with the image data produced by ``load_images('catDog')``."""

    template_name = 'catDog.html'

    def get_context_data(self, **kwargs):
        # Layer the model's image metadata over the default template context;
        # on a key collision the value from load_images() wins.
        base = super().get_context_data(**kwargs)
        return {**base, **load_images('catDog')}

class vgguploadView(View):
    """Upload page for the VGG19 visualization.

    GET renders ``vggUpload.html``. POST reads the django-drf-filepond upload
    id from the ``filepond`` form field and moves that temporary upload to
    ``<cwd>/visualization/static/visualization/vgg19/example/input.jpg``.
    It then removes the temporary upload directory and database row, and
    re-renders the upload page. Missing or unknown upload ids re-render the
    page with an error message.
    """

    def get(self, request, *args, **kwargs):
        return render(request, 'vggUpload.html')

    def post(self, request, *args, **kwargs):
        upload_id = request.POST.get('filepond')
        # A missing field yields None and an empty widget yields ''; both mean
        # "no image selected", so test truthiness rather than != "".
        if not upload_id:
            messages.error(request, 'You have not selected any input image')
            return render(request, 'vggUpload.html')

        try:
            uploaded_file = TemporaryUpload.objects.get(upload_id=upload_id)
        except TemporaryUpload.DoesNotExist:
            # Stale or forged id: report it instead of failing with a 500.
            messages.error(request, 'The selected image could not be found, please upload it again')
            return render(request, 'vggUpload.html')

        # upload_id is only used in filesystem paths after it has matched an
        # existing TemporaryUpload row above.
        prefix_path = os.path.join(os.getcwd(), 'visualization', 'static')
        upload_dir = os.path.join(prefix_path, 'temp-image-uploads', upload_id)
        source = os.path.join(upload_dir, uploaded_file.file_id)
        destination_dir = os.path.join(prefix_path, 'visualization', 'vgg19', 'example')
        destination = os.path.join(destination_dir, 'input.jpg')

        os.makedirs(destination_dir, exist_ok=True)
        shutil.move(source, destination)
        os.rmdir(upload_dir)
        uploaded_file.delete()
        # NOTE: visualization step is disabled, matching the commented-out
        # vgg_visualization import at the top of the module.
        # vgg_visualization(destination, destination_dir)

        # Django requires every view to return an HttpResponse.
        return render(request, 'vggUpload.html')
        

# Helper codes
def load_images(model_type):
    """Collect static image paths and layer metadata for one model's visualization.

    Scans ``<cwd>/visualization/static/visualization/<model_type>/``. Every
    sub-directory there is treated as one example (``class_type``). Each one
    must contain a ``meta.txt`` file and one sub-directory per convolution
    layer. ``meta.txt`` is read three lines per layer: kernel size, previous
    input depth, current depth. Paths to the example's ``input.jpg`` and
    ``output.jpg`` are built without checking that the files exist.

    Inside a layer directory:

    - Files without ``-`` are feature-map outputs named ``<n>.<ext>``.
    - Files with ``-`` are weight images named ``<n>-<m>.<ext>``, grouped by
      filter ``n`` and ordered by ``m``.

    Args:
        model_type: Name of the model directory, e.g. ``'mnist'``.

    Returns:
        A dict with two keys:

        - ``model_name``: ``model_type.capitalize()``.
        - ``images``: a list of ``{'class_type': str, 'images': dict}``
          entries. The inner dict holds ``input``, ``output`` and
          ``convolution`` (one dict per layer with ``kernel_size``,
          ``prev_input_depth``, ``current_depth``, ``weights``, ``outputs``).

        All image paths are '/'-separated and relative to the app's
        ``static`` directory.

    Raises:
        OSError: If the model directory or an example's ``meta.txt`` is
            missing.
        IndexError: If ``meta.txt`` has fewer than three lines per layer
            directory.
    """
    root_dir = os.path.join(os.getcwd(), 'visualization', 'static',
                            'visualization', model_type)
    context = {'model_name': model_type.capitalize(), 'images': []}
    # Sorted so the page order is deterministic instead of filesystem-dependent.
    for class_type in sorted(os.listdir(root_dir), key=_natural_key):
        class_dir = os.path.join(root_dir, class_type)
        if not os.path.isdir(class_dir):
            # Only sub-directories describe an example; a stray file would
            # otherwise fail on the meta.txt lookup below.
            continue
        with open(os.path.join(class_dir, 'meta.txt'), 'r') as meta_file:
            meta_data = meta_file.readlines()
        class_static = _static_path('visualization', model_type, class_type)
        images = {
            'input': _static_path(class_static, 'input.jpg'),
            'output': _static_path(class_static, 'output.jpg'),
            'convolution': _load_conv_layers(class_dir, class_static, meta_data),
        }
        context['images'].append({'class_type': class_type, 'images': images})
    return context


def _load_conv_layers(class_dir, class_static, meta_data):
    """Return one metadata/image dict per convolution-layer sub-directory of *class_dir*.

    Layer directories are visited in natural-sort order and paired, three
    lines at a time, with *meta_data* (the lines of ``meta.txt``).
    NOTE(review): assumes ``meta.txt`` lists layers in the natural-sort order
    of their directory names; confirm against the script that generates them.
    """
    layer_names = sorted(
        (name for name in os.listdir(class_dir)
         if os.path.isdir(os.path.join(class_dir, name))),
        key=_natural_key,
    )
    layers = []
    for index, conv in enumerate(layer_names):
        line_no = 3 * index
        conv_static = _static_path(class_static, conv)
        # List the layer directory once and reuse it for outputs and weights.
        file_names = os.listdir(os.path.join(class_dir, conv))
        outputs = sorted((f for f in file_names if '-' not in f), key=key_outputs)
        # Group weight files by the exact filter number before the first '-'.
        # A prefix test would let filter 1 also claim filters 10, 11, ...
        weights_by_filter = {}
        for name in file_names:
            if '-' in name:
                weights_by_filter.setdefault(name.split('-', 1)[0], []).append(name)
        weights = [
            [_static_path(conv_static, name)
             for name in sorted(weights_by_filter.get(str(n), []), key=key_weights)]
            for n in range(1, len(outputs) + 1)
        ]
        layers.append({
            'kernel_size': meta_data[line_no],
            'prev_input_depth': meta_data[line_no + 1],
            'current_depth': meta_data[line_no + 2],
            'weights': weights,
            'outputs': [_static_path(conv_static, name) for name in outputs],
        })
    return layers


def _static_path(*parts):
    """Join static-file path components with '/' regardless of the host OS."""
    return '/'.join(parts)


def _natural_key(name):
    """Sort key that orders embedded numbers numerically ('conv2' before 'conv10')."""
    return [int(part) if part.isdecimal() else part
            for part in re.split(r'(\d+)', name)]

def key_outputs(x):
    """Sort key for feature-map files named ``<n>.<ext>``: return ``n`` as an int."""
    stem, _sep, _rest = x.partition('.')
    return int(stem)

def key_weights(x):
    """Sort key for weight files named ``<filter>-<m>.<ext>``: return ``m`` as an int."""
    stem = x.partition('.')[0]
    fields = stem.split('-')
    return int(fields[1])