import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


def train_test_split(img_folder, img_transforms, k, total_imgs):
    """Load an image folder and cut it into shuffled blocks for k-fold validation.

    Images under ``img_folder`` are loaded with ``ImageFolder`` (applying
    ``img_transforms``), shuffled, and grouped into batches of
    ``total_imgs // k`` images each. ``DataLoader`` keeps the final partial
    batch, so if the dataset size is not a multiple of that block size the
    last block holds the remainder and more than ``k`` blocks are returned.

    Args:
        img_folder: root directory in the layout ``ImageFolder`` expects.
        img_transforms: transform applied to every image. This presumably must
            yield equally sized tensors so the default collation can stack
            them into one batch (TODO confirm with callers).
        k: desired number of folds; must be a positive integer <= total_imgs.
        total_imgs: number of images the caller expects in ``img_folder``;
            used only to size the blocks.

    Returns:
        Tuple ``(image_blocks, label_blocks, classes)``. ``image_blocks`` and
        ``label_blocks`` are parallel lists of batched tensors. ``classes`` is
        ``ImageFolder``'s list of class names.

    Raises:
        ValueError: if ``k`` or ``total_imgs`` is not positive, or if
            ``k > total_imgs``, because that would give an empty block size.
    """
    # Validate before touching the filesystem: ImageFolder scans the whole
    # directory, and a zero block size would otherwise only fail afterwards.
    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")
    if total_imgs <= 0:
        raise ValueError(f"total_imgs must be positive, got {total_imgs}")
    if k > total_imgs:
        raise ValueError(f"k ({k}) cannot exceed total_imgs ({total_imgs})")

    # int() keeps DataLoader happy if total_imgs arrives as a float.
    block_size = int(total_imgs // k)

    image_set = ImageFolder(img_folder, transform=img_transforms)
    image_loader = DataLoader(image_set, batch_size=block_size, shuffle=True)

    image_blocks = []
    label_blocks = []
    for images, labels in image_loader:
        image_blocks.append(images)
        label_blocks.append(labels)
    return (image_blocks, label_blocks, image_set.classes)


def evaluate(model, img, lbl, criterion, cuda_available):
    """Score ``model`` on a single batch.

    The model's raw output is passed to ``criterion`` together with the labels.
    ``torch.exp`` is then applied to the output, so the output is presumably
    log-probabilities (e.g. from a LogSoftmax head) — confirm with the model
    definition.

    Args:
        model: module whose ``forward`` is applied to the batch.
        img: batch of inputs.
        lbl: target class indices for the batch.
        criterion: loss function called as ``criterion(output, labels)``.
        cuda_available: when true, ``img`` and ``lbl`` are moved to ``'cuda'``
            before the forward pass.

    Returns:
        Tuple ``(loss, accuracy, probabilities)``:
        ``loss`` is the criterion value as a Python float;
        ``accuracy`` is the fraction of rows whose highest-scoring class
        equals the label, as a 0-dim CPU float tensor;
        ``probabilities`` is ``torch.exp(output)``.
    """
    if cuda_available:
        img, lbl = img.to('cuda'), lbl.to('cuda')

    output = model.forward(img)
    loss_value = criterion(output, lbl).item()
    probabilities = torch.exp(output)

    predicted_class = probabilities.max(dim=1).indices
    hits = lbl.data == predicted_class
    # Casting to torch.FloatTensor also brings the result back to the CPU.
    hit_rate = hits.type(torch.FloatTensor).mean()

    return loss_value, hit_rate, probabilities


def train_classifier(model, optimizer, criterion, epochs, image_blocks, image_labels, holdout, cuda_available):
    """Train ``model`` on every block except ``holdout``, validating on ``holdout``.

    This runs one fold of a k-fold scheme. In each epoch the training blocks
    are visited in order and the held-out block is skipped. The held-out block
    is then scored with ``evaluate`` under ``torch.no_grad()``, and a progress
    line is printed. The model is left in training mode afterwards.

    Args:
        model: ``torch.nn.Module`` trained in place.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(output, labels)``.
        epochs: number of passes over the training blocks.
        image_blocks: list of input batches.
        image_labels: list of label batches, parallel to ``image_blocks``.
        holdout: index of the validation block. That block is never trained on.
        cuda_available: when true, the model and every batch are moved to
            ``'cuda'``. Otherwise everything stays on its current device.

    Raises:
        ValueError: if the two block lists differ in length, or if
            ``holdout`` is not in ``[0, len(image_blocks))``.
    """
    n_blocks = len(image_blocks)
    if len(image_labels) != n_blocks:
        raise ValueError(
            f"image_blocks ({n_blocks}) and image_labels ({len(image_labels)}) differ in length"
        )
    # A negative index would never match the skip check below, so the
    # validation block would also be trained on (data leakage).
    if not 0 <= holdout < n_blocks:
        raise ValueError(f"holdout must be in [0, {n_blocks}), got {holdout}")

    # Move the model only when the batches are moved too. Otherwise model and
    # inputs land on different devices, and CPU-only machines crash here.
    if cuda_available:
        model.to('cuda')

    n_train_batches = n_blocks - 1

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for i, (images, labels) in enumerate(zip(image_blocks, image_labels)):
            if i == holdout:
                continue
            if cuda_available:
                images, labels = images.to('cuda'), labels.to('cuda')

            optimizer.zero_grad()
            # Calling the module (not .forward) ensures registered hooks run.
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        model.eval()
        with torch.no_grad():
            validation_loss, accuracy, _ = evaluate(model, image_blocks[holdout], image_labels[holdout], criterion, cuda_available)

        # train_loss is a sum of per-batch losses, so average over the number
        # of training batches. That puts it on the same scale as the
        # single-batch validation loss.
        avg_train_loss = train_loss / n_train_batches if n_train_batches else 0.0
        print("Epoch: [{}/{}]\tTraining Loss: {}\tValidation Loss: {}\t Validation Accuracy: {}".format(epoch+1, epochs, avg_train_loss, validation_loss, accuracy))

        # Restore training mode so the model is not left in eval mode.
        model.train()

