import copy
import torch
from torch import optim
from sklearn import svm, metrics
import numpy as np
from . import confusion_matrix as cm

def cross_validation(models, learning_rate, criterion, block_imgs, block_labels,
                     cuda_available, epochs, num_classes=8):
    """Run leave-one-block-out cross-validation for every model in ``models``.

    For each model and each block index, a deep copy of the model is trained
    with Adam on all the other blocks for ``epochs`` passes, then scored on
    the held-out block with ``get_accuracy``. The models passed in are never
    trained themselves; only their copies are.

    Args:
        models: Iterable of models. Each must support ``copy.deepcopy``,
            ``parameters()``, ``train()``, ``eval()``, ``to()`` and being
            called on a batch of images.
        learning_rate: Learning rate handed to ``optim.Adam``.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
        block_imgs: Sequence of image tensors, one per block.
        block_labels: Sequence of label tensors aligned with ``block_imgs``.
        cuda_available: When truthy, the model copy and every batch are moved
            to ``"cuda"``.
        epochs: Number of passes over the training blocks per fold.
        num_classes: Number of classes tracked by each confusion matrix
            (labels ``0 .. num_classes - 1``). Defaults to 8, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(loss_matrix, confusion_matrices)``. Despite its name,
        ``loss_matrix[m][k]`` is the accuracy of model ``m`` on held-out
        block ``k`` (as returned by ``get_accuracy``), and
        ``confusion_matrices[m][k]`` is the matching ``cm.Confusion_Matrix``.
    """
    loss_matrix = []
    confusion_matrices = []

    for model in models:
        print("Evaluating model: {}".format(model))
        accuracies = []
        con_mats = []

        for test_holdout in range(len(block_imgs)):
            temp_model = copy.deepcopy(model)
            # Move the model first and build the optimizer afterwards, so the
            # optimizer is guaranteed to track the on-device parameters.
            if cuda_available:
                temp_model.to("cuda")
            optimizer = optim.Adam(temp_model.parameters(), lr=learning_rate)

            for _ in range(epochs):
                temp_model.train()
                for block, (images, labels) in enumerate(zip(block_imgs, block_labels)):
                    if block == test_holdout:
                        continue
                    if cuda_available:
                        images, labels = images.to("cuda"), labels.to("cuda")

                    optimizer.zero_grad()
                    # Call the module (not .forward) so registered hooks run.
                    loss = criterion(temp_model(images), labels)
                    loss.backward()
                    optimizer.step()

            temp_model.eval()
            with torch.no_grad():
                held_out_imgs = block_imgs[test_holdout]
                held_out_labels = block_labels[test_holdout]
                acc, probabilities = get_accuracy(
                    temp_model, held_out_imgs, held_out_labels, cuda_available)
                accuracies.append(acc)
                print("Validation test block: {}\tAccuracy: {}".format(test_holdout, accuracies[-1]))

                y_pred = probabilities.max(dim=1)[1]
                con_mat = cm.Confusion_Matrix(range(num_classes))
                con_mat.update_matrix(y_pred.tolist(), held_out_labels.tolist())
                con_mats.append(con_mat)

            # The optimizer also references the parameters (and holds Adam
            # state), so drop it together with the model before emptying the
            # CUDA cache; otherwise that memory cannot be released yet.
            del temp_model, optimizer
            if cuda_available:
                torch.cuda.empty_cache()

        loss_matrix.append(accuracies)
        confusion_matrices.append(con_mats)

    return loss_matrix, confusion_matrices


def get_accuracy(model, imgs, labels, cuda_available):
    """Score ``model`` on one batch and return its accuracy and probabilities.

    The model output is exponentiated before taking the per-row maximum, so
    the model presumably emits log-probabilities (e.g. a log-softmax head);
    confirm against the model definitions.

    Args:
        model: Callable module applied to ``imgs``.
        imgs: Batch of inputs.
        labels: Tensor of target class indices, one per row of the output.
        cuda_available: When truthy, ``imgs`` and ``labels`` are moved to
            ``"cuda"`` before the forward pass.

    Returns:
        Tuple ``(accuracy, probabilities)``: ``accuracy`` is a 0-dim CPU float
        tensor with the fraction of rows whose arg-max class equals the label;
        ``probabilities`` is ``torch.exp`` of the model output.
    """
    if cuda_available:
        imgs, labels = imgs.to("cuda"), labels.to("cuda")

    # Invoke the module itself rather than .forward() so that any registered
    # forward hooks / pre-hooks are executed.
    output = model(imgs)
    probabilities = torch.exp(output)

    predictions = probabilities.max(dim=1)[1]
    equalities = labels == predictions
    # .type(torch.FloatTensor) also brings the result back to the CPU.
    accuracy = equalities.type(torch.FloatTensor).mean()
    return accuracy, probabilities

def cross_validation_svm(models, learning_rate, gamma_val, block_imgs, block_labels,
                         criterion, cuda_available, epochs, num_classes=8):
    """Leave-one-block-out cross-validation with an SVM on extracted features.

    For each model and each block index, a deep copy of the model is trained
    with Adam on all the other blocks. An ``svm.SVC`` is then fit on
    ``feature_extraction`` outputs of those training blocks and evaluated on
    the held-out block's features.

    At least two blocks are needed: the SVM is fit on the concatenation of the
    non-held-out blocks, which is empty when there is only one block.

    Args:
        models: Iterable of models. Each must support ``copy.deepcopy``,
            ``parameters()``, ``train()``, ``eval()``, ``to()``, being called
            on a batch, and a ``feature_extraction(images)`` method.
        learning_rate: Learning rate handed to ``optim.Adam``.
        gamma_val: ``gamma`` argument for ``svm.SVC``.
        block_imgs: Sequence of image tensors, one per block.
        block_labels: Sequence of label tensors aligned with ``block_imgs``.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
        cuda_available: When truthy, the model copy and batches are moved to
            ``"cuda"``; features are brought back to the CPU for the SVM.
        epochs: Number of passes over the training blocks per fold.
        num_classes: Number of classes tracked by each confusion matrix
            (labels ``0 .. num_classes - 1``). Defaults to 8, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(loss_matrix, confusion_matrices)``. Despite its name,
        ``loss_matrix[m][k]`` is the SVM accuracy (0-dim float tensor) of
        model ``m`` on held-out block ``k``; ``confusion_matrices[m][k]`` is
        the matching ``cm.Confusion_Matrix``.
    """
    loss_matrix = []
    confusion_matrices = []

    for model in models:
        print("Evaluating model: {}".format(model))
        accuracies = []
        con_mats = []

        for test_holdout in range(len(block_imgs)):
            temp_model = copy.deepcopy(model)
            # Move the model first and build the optimizer afterwards, so the
            # optimizer is guaranteed to track the on-device parameters.
            if cuda_available:
                temp_model.to("cuda")
            optimizer = optim.Adam(temp_model.parameters(), lr=learning_rate)

            for _ in range(epochs):
                temp_model.train()
                for block, (images, labels) in enumerate(zip(block_imgs, block_labels)):
                    if block == test_holdout:
                        continue
                    if cuda_available:
                        images, labels = images.to("cuda"), labels.to("cuda")

                    optimizer.zero_grad()
                    # Call the module (not .forward) so registered hooks run.
                    loss = criterion(temp_model(images), labels)
                    loss.backward()
                    optimizer.step()

            temp_model.eval()
            with torch.no_grad():
                train_blocks = [k for k in range(len(block_imgs)) if k != test_holdout]
                # One concatenation instead of repeated pairwise torch.cat.
                svm_train = torch.cat([block_imgs[k] for k in train_blocks])
                svm_lbls = torch.cat([block_labels[k] for k in train_blocks])
                test_imgs = block_imgs[test_holdout]
                test_lbls = block_labels[test_holdout].cpu()

                if cuda_available:
                    svm_train, test_imgs = svm_train.to("cuda"), test_imgs.to("cuda")

                # scikit-learn works on NumPy arrays, so features go to the CPU.
                train_features = temp_model.feature_extraction(svm_train).cpu().numpy()
                classifier = svm.SVC(gamma=gamma_val)
                classifier.fit(train_features, svm_lbls.cpu().numpy())

                test_features = temp_model.feature_extraction(test_imgs).cpu().numpy()
                y_pred = torch.as_tensor(classifier.predict(test_features), dtype=torch.long)

                equalities = test_lbls == y_pred
                accuracy = equalities.type(torch.FloatTensor).mean()
                accuracies.append(accuracy)

                con_mat = cm.Confusion_Matrix(range(num_classes))
                con_mat.update_matrix(y_pred.tolist(), test_lbls.tolist())
                con_mats.append(con_mat)
                print("Validation test block: {}\tAccuracy: {}".format(test_holdout, accuracies[-1]))

            # The optimizer also references the parameters (and holds Adam
            # state), so drop it together with the model before emptying the
            # CUDA cache; otherwise that memory cannot be released yet.
            del temp_model, optimizer
            if cuda_available:
                torch.cuda.empty_cache()

        loss_matrix.append(accuracies)
        confusion_matrices.append(con_mats)

    return loss_matrix, confusion_matrices

def get_accuracy_svm(model, imgs, labels, cuda_available):
    """Return the accuracy of ``model.classifier`` on features of ``imgs``.

    Features come from ``model.feature_extraction(imgs)`` and are converted to
    a NumPy array before being passed to ``model.classifier.predict``.

    Args:
        model: Object exposing ``feature_extraction(imgs)`` (returns a tensor)
            and ``classifier.predict(features)`` (returns array-like class
            predictions, one per row).
        imgs: Batch of inputs.
        labels: Tensor of target class indices.
        cuda_available: When truthy, ``imgs`` is moved to ``"cuda"`` before
            feature extraction.

    Returns:
        0-dim CPU float tensor: the fraction of predictions equal to
        ``labels``.
    """
    if cuda_available:
        imgs = imgs.to("cuda")

    # Extract features without autograd tracking: .numpy() raises on tensors
    # that require grad, which happens whenever the caller is not already
    # inside torch.no_grad().
    with torch.no_grad():
        features = model.feature_extraction(imgs).cpu().numpy()

    y_pred = model.classifier.predict(features)
    # Build the prediction tensor on the labels' device so the comparison
    # works wherever the labels live.
    predictions = torch.as_tensor(y_pred, dtype=torch.long, device=labels.device)

    equalities = labels == predictions
    return equalities.type(torch.FloatTensor).mean()
