#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: param.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
from abc import abstractmethod, ABCMeta
import operator

from .base import Callback
from ..utils import logger
from ..tfutils import get_op_var_name

# Public API of this module: the abstract base setter and its two concrete
# strategies (file-driven and schedule-driven).
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
           'ScheduledHyperParamSetter']

class HyperParamSetter(Callback):
    """
    Base class to set hyperparameters after every epoch.

    Subclasses implement :meth:`_get_current_value`. Its result is applied to
    the target parameter before training starts and after every epoch; a
    result of ``None`` leaves the parameter unchanged.
    """
    __metaclass__ = ABCMeta

    # Kinds of target parameter, stored in ``self.param_type``.
    TF_VAR = 0      # a variable in the TensorFlow graph, looked up by name
    OBJ_ATTR = 1    # an attribute of a Python object, updated with setattr

    def __init__(self, param, shape=()):
        """
        :param param: either a name of the variable in the graph, or a (object, attribute) tuple
        :param shape: shape of the param, used for the placeholder that feeds
            a graph variable. Defaults to ``()`` (a scalar). The default is an
            immutable tuple so that no mutable object is shared between calls.
        """
        if isinstance(param, tuple):
            self.param_type = HyperParamSetter.OBJ_ATTR
            self.obj_attr = param
            self.readable_name = param[1]
        else:
            self.param_type = HyperParamSetter.TF_VAR
            # readable_name is used in log messages and to name the feed
            # placeholder; var_name is matched against graph variable names.
            self.readable_name, self.var_name = get_op_var_name(param)
        self.shape = shape
        # Last value returned by get_current_value, used to log only changes.
        self.last_value = None

    def _setup_graph(self):
        """
        Build the ops needed to update a graph-variable parameter.

        Finds the variable whose name equals ``self.var_name`` and creates a
        float32 placeholder plus an assign op that writes the fed value into
        it. Does nothing for an (object, attribute) parameter.

        :raises ValueError: if no variable with that name exists in the graph.
        """
        if self.param_type == HyperParamSetter.TF_VAR:
            all_vars = tf.all_variables()
            for v in all_vars:
                if v.name == self.var_name:
                    self.var = v
                    break
            else:
                raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))

            # New values are fed through this placeholder on every update.
            self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
                                             name=self.readable_name + '_feed')
            self.assign_op = self.var.assign(self.val_holder)

    def get_current_value(self):
        """
        Query the subclass for the current value and log when it changes.

        :returns: the value to assign to the variable now, or None to leave
            it unchanged.
        """
        ret = self._get_current_value()
        if ret is not None and ret != self.last_value:
            # self.epoch_num is presumably maintained by the Callback base
            # class (not visible here) -- TODO confirm.
            logger.info("{} at epoch {} will change to {}".format(
                self.readable_name, self.epoch_num + 1, ret))
        self.last_value = ret
        return ret

    @abstractmethod
    def _get_current_value(self):
        """
        :returns: the value the parameter should take now, or None to keep
            its current value.
        """
        pass

    def _trigger_epoch(self):
        # Callback hook: re-apply the parameter after each epoch.
        self._set_param()

    def _before_train(self):
        # Callback hook: apply the initial value before training starts.
        self._set_param()

    def _set_param(self):
        """Write the current value (if any) to the target parameter."""
        v = self.get_current_value()
        if v is not None:
            if self.param_type == HyperParamSetter.TF_VAR:
                # Tensor.eval runs in the default TensorFlow session.
                self.assign_op.eval(feed_dict={self.val_holder: v})
            else:
                setattr(self.obj_attr[0], self.obj_attr[1], v)

class HumanHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters manually by modifying a file.

    The file is re-read on every call to :meth:`_get_current_value`, so it can
    be edited while training runs.
    """
    def __init__(self, param, file_name):
        """
        :param param: either a name of the variable in the graph, or a (object, attribute) tuple
        :param file_name: a file containing the value of the variable. Each line in the file is a k:v pair
        """
        self.file_name = file_name
        super(HumanHyperParamSetter, self).__init__(param)

    def _get_current_value(self):
        """
        Read the value for this parameter from ``self.file_name``.

        Blank lines are ignored and whitespace around keys and values is
        stripped.

        :returns: the float value stored under ``self.readable_name``, or None
            (after logging a warning) if the file cannot be read, a line is
            malformed, or the key is missing.
        """
        try:
            with open(self.file_name) as f:
                lines = f.readlines()
            dic = {}
            for line in lines:
                line = line.strip()
                if not line:
                    # Tolerate blank lines such as an empty trailing line;
                    # previously they made the whole parse fail.
                    continue
                k, v = line.split(':')
                dic[k.strip()] = float(v)
            return dic[self.readable_name]
        except (IOError, OSError, ValueError, KeyError):
            # Best effort: a missing/unreadable file or a bad entry keeps the
            # current value instead of crashing training. Narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit still propagate.
            logger.warn(
                "Failed to parse {} in {}".format(
                    self.readable_name, self.file_name))
            return None

class ScheduledHyperParamSetter(HyperParamSetter):
    """
    Set hyperparameters by a predefined schedule.

    A value is produced only at epochs that appear in the schedule; at any
    other epoch None is returned and the parameter is left as it is.
    """
    def __init__(self, param, schedule):
        """
        :param param: either a name of the variable in the graph, or a (object, attribute) tuple
        :param schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
            Epochs are coerced to int and values to float, then the entries
            are stably sorted by epoch.
        """
        normalized = [(int(epoch), float(value)) for epoch, value in schedule]
        normalized.sort(key=operator.itemgetter(0))
        self.schedule = normalized
        super(ScheduledHyperParamSetter, self).__init__(param)

    def _get_current_value(self):
        # First scheduled value whose epoch matches the current one, else None.
        hits = (value for epoch, value in self.schedule
                if epoch == self.epoch_num)
        return next(hits, None)



