import math
import torch
from torch.nn import Module


class Base_Model(Module):
    """Base class for a CNN feature extractor followed by a classifier.

    This class only computes ``self.out_dim`` (the length of the flattened
    feature vector produced by the CNN stack) and wires ``cnn_layers`` and
    ``classifier`` together in :meth:`forward`. Both attributes are left as
    ``None`` here; presumably subclasses assign them — TODO confirm.
    """

    def __init__(self, kernel, stride, padding, img_size: int,
                 num_cnn_layers: int, out_channels: int):
        """Store the layer hyper-parameters and compute ``self.out_dim``.

        Args:
            kernel: indexable; only ``kernel[0]`` is used for the size
                calculation (presumably the side of a square kernel —
                TODO confirm against subclasses).
            stride: indexable; only ``stride[0]`` is used.
            padding: indexable; only ``padding[0]`` is used.
            img_size: side length of the input image. The result is squared
                below, so a square input is assumed.
            num_cnn_layers: number of conv-conv-pool stages. Values below 1
                are counted as a single stage.
            out_channels: number of channels in the final feature map.
        """
        super().__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding
        self.img_size = img_size
        self.num_cnn = num_cnn_layers
        self.out_channels = out_channels

        k, s, p = self.kernel[0], self.stride[0], self.padding[0]
        side = self.img_size
        # Each stage is modelled as two padded windows (convolutions) followed
        # by an unpadded window with the same kernel/stride (presumably a
        # pooling layer — TODO confirm). At least one stage is always counted.
        for _ in range(max(self.num_cnn, 1)):
            side = self._conv_out_size(side, k, s, p)
            side = self._conv_out_size(side, k, s, p)
            side = self._conv_out_size(side, k, s, 0)
        # Square feature map flattened across all output channels.
        self.out_dim = (side ** 2) * self.out_channels

        # Placeholders; forward() refuses to run while either is None.
        self.cnn_layers = None
        self.classifier = None

    @staticmethod
    def _conv_out_size(size, kernel, stride, padding):
        """Return the output length along one axis of a sliding window op.

        Standard formula ``floor((size + 2*padding - kernel) / stride) + 1``.
        """
        return math.floor((size + 2 * padding - kernel) / stride) + 1

    def feature_extraction(self, x):
        """Run ``x`` through ``cnn_layers`` and flatten all but the first dim.

        Returns:
            A tensor of shape ``(x_features.size(0), -1)``, or ``None`` (after
            printing a message) when ``cnn_layers`` has not been set.
        """
        if self.cnn_layers is None:
            print("CNN layers not defined...")
            return None
        features = self.cnn_layers(x)
        # reshape (not view) so non-contiguous CNN outputs, e.g. after a
        # transpose/permute, are flattened instead of raising.
        return features.reshape(features.size(0), -1)

    def forward(self, x):
        """Extract flattened features from ``x`` and pass them to the classifier.

        Returns:
            The classifier output, or ``None`` (after printing a message) when
            ``cnn_layers`` or ``classifier`` has not been set.
        """
        if self.cnn_layers is None or self.classifier is None:
            print("Layers not defined...")
            return None
        features = self.feature_extraction(x)
        return self.classifier(features)
