#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: load-vgg16.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import cv2
import tensorflow as tf
import numpy as np
import os
import argparse
import cPickle as pkl

from tensorpack.train import TrainConfig, start_train
from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import ILSVRCMeta

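"""
Usage:
    # 1. Convert the Caffe VGG16 model into a .npy parameter file:
    python2 -m tensorpack.utils.loadcaffe PATH/TO/models/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy
    # 2. Classify an image with the converted parameters:
    ./load-vgg16.py --load vgg16.npy --input cat.png
"""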

class Model(ModelDesc):
    """VGG16 classifier graph.

    Thirteen 3x3 convolutions arranged in five stages (each stage ends in a
    2x2 max-pool), followed by fc6/fc7 (4096 units, with dropout) and fc8
    (1000 logits). The softmax over the logits is exposed as the tensor named
    'output'. Layer names mirror the Caffe VGG16 model so that a converted
    parameter dict can be restored by name.
    """

    def _get_input_vars(self):
        # A batch of 224x224x3 float images, plus an int label per image.
        # The label is declared but not used by the graph below.
        image_var = InputVar(tf.float32, (None, 224, 224, 3), 'input')
        label_var = InputVar(tf.int32, (None,), 'label')
        return [image_var, label_var]

    def _build_graph(self, inputs, is_training):
        # Dropout is active (keep 50%) only when building the training graph.
        keep_prob = tf.constant(0.5 if bool(is_training) else 1.0)

        image, _label = inputs

        # (number of 3x3 convs, output channels) for each stage. The spatial
        # size goes 224 -> 112 -> 56 -> 28 -> 14 -> 7 through the five pools.
        stages = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))

        feat = image
        with argscope(Conv2D, kernel_shape=3):
            for stage_idx, (num_convs, channels) in enumerate(stages, 1):
                for conv_idx in range(1, num_convs + 1):
                    feat = Conv2D('conv%d_%d' % (stage_idx, conv_idx),
                                  feat, channels)
                feat = MaxPooling('pool%d' % stage_idx, feat, 2)

        # Classifier head.
        feat = FullyConnected('fc6', feat, 4096)
        feat = tf.nn.dropout(feat, keep_prob)
        feat = FullyConnected('fc7', feat, 4096)
        feat = tf.nn.dropout(feat, keep_prob)
        logits = FullyConnected('fc8', feat, out_dim=1000, nl=tf.identity)
        # The name 'output' is what the predictor fetches ('output:0').
        tf.nn.softmax(logits, name='output')

def run_test(path, input):
    """Classify one image with restored VGG16 weights and print the top 10.

    Args:
        path: path to a .npy file whose single element is a parameter dict
            (as produced by ``tensorpack.utils.loadcaffe``). The dict is
            handed to ``ParamRestore``.
        input: path to an image file readable by ``cv2.imread``.

    Prints the indices of the 10 highest-probability classes (highest first),
    then their ILSVRC synset words.

    Raises:
        IOError: if ``input`` cannot be read as an image.
    """
    # The dict is stored inside a 0-d object array, so loading it requires
    # unpickling. numpy >= 1.16.3 refuses that unless allow_pickle=True.
    # NOTE(security): unpickling can execute arbitrary code -- only load
    # parameter files from a trusted source.
    param_dict = np.load(path, allow_pickle=True).item()

    pred_config = PredictConfig(
        model=Model(),
        # presumably selects which model input vars are fed: index 0 = 'input'
        input_data_mapping=[0],
        session_init=ParamRestore(param_dict),
        output_var_names=['output:0']   # output:0 is the probability distribution
    )
    predict_func = get_predict_func(pred_config)

    im = cv2.imread(input)
    # cv2.imread returns None rather than raising on a missing or unreadable
    # file. Fail explicitly (an assert would be stripped under `python -O`).
    if im is None:
        raise IOError("Cannot read image file: {}".format(input))
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, (224, 224))
    # Add the batch dimension: one image per prediction call.
    im = np.reshape(im, (1, 224, 224, 3)).astype('float32')
    im = im - 110   # subtract one constant from every channel
    outputs = predict_func([im])[0]   # the (only) requested output: 'output:0'
    prob = outputs[0]                 # the (only) image in the batch
    ret = prob.argsort()[-10:][::-1]  # top-10 class indices, most probable first
    print(ret)

    meta = ILSVRCMeta().get_synset_words_1000()
    print([meta[k] for k in ret])

if __name__ == '__main__':
    # Command-line entry point: restore VGG16 parameters and classify one image.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--gpu', default='0',
        help='comma separated list of GPU(s) to use.')
    parser.add_argument(
        '--load', required=True,
        help='.npy model file generated by tensorpack.utils.loadcaffe')
    parser.add_argument(
        '--input', required=True,
        help='an input image')
    args = parser.parse_args()

    # Restrict visible GPUs before run_test() builds the predictor. An empty
    # --gpu value leaves the existing environment untouched.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    run_test(args.load, args.input)
