#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: load-alexnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from __future__ import print_function
import tensorflow as tf
import numpy as np
import os
import cv2
import argparse

from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow.dataset import ILSVRCMeta

"""
Usage:
    python -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npy
    ./load-alexnet.py --load alexnet.npy --input cat.png
"""


class Model(ModelDesc):
    """AlexNet forward graph; exposes a softmax output tensor named 'prob'."""

    def _get_inputs(self):
        """Declare the single input: a float32 placeholder 'input' of shape (None, 227, 227, 3)."""
        image_var = InputVar(tf.float32, (None, 227, 227, 3), 'input')
        return [image_var]

    def _build_graph(self, inputs):
        """Build conv1..conv5, fc6..fc8 and the final softmax op named 'prob'.

        Dropout is intentionally left out: this script only loads a model.
        """
        def local_response_norm(tensor, name):
            # Both normalization layers use the same LRN hyper-parameters.
            return tf.nn.lrn(tensor, 2, bias=1.0, alpha=2e-5, beta=0.75, name=name)

        # Image batch, 227x227x3 per image.
        net = inputs[0]

        # Inside this scope Conv2D and FullyConnected use ReLU as nonlinearity.
        with argscope([Conv2D, FullyConnected], nl=tf.nn.relu):
            # Stage 1: conv -> LRN -> max-pool
            net = Conv2D('conv1', net, out_channel=96, kernel_shape=11,
                         stride=4, padding='VALID')
            net = local_response_norm(net, 'norm1')
            net = MaxPooling('pool1', net, 3, stride=2, padding='VALID')

            # Stage 2: split conv -> LRN -> max-pool
            net = Conv2D('conv2', net, out_channel=256, kernel_shape=5, split=2)
            net = local_response_norm(net, 'norm2')
            net = MaxPooling('pool2', net, 3, stride=2, padding='VALID')

            # Stage 3: three stacked 3x3 convolutions -> max-pool
            net = Conv2D('conv3', net, out_channel=384, kernel_shape=3)
            net = Conv2D('conv4', net, out_channel=384, kernel_shape=3, split=2)
            net = Conv2D('conv5', net, out_channel=256, kernel_shape=3, split=2)
            net = MaxPooling('pool3', net, 3, stride=2, padding='VALID')

            # Fully-connected head (no dropout, see docstring).
            net = FullyConnected('fc6', net, out_dim=4096)
            net = FullyConnected('fc7', net, out_dim=4096)

        # fc8 is built outside the argscope with an identity nonlinearity,
        # so it yields raw logits for the 1000 classes.
        logits = FullyConnected('fc8', net, out_dim=1000, nl=tf.identity)
        # The op is looked up later by its name 'prob'; no Python handle is needed.
        tf.nn.softmax(logits, name='prob')


def run_test(path, input):
    """Classify one image with AlexNet and print its top-10 predictions.

    Args:
        path: path to the ``.npy`` parameter file (a pickled dict, as
            generated by ``tensorpack.utils.loadcaffe``).
        input: path of the image file to classify.

    Raises:
        AssertionError: if OpenCV cannot read the image at ``input``.
    """
    # The file holds a pickled dict wrapped in an object array, hence `.item()`.
    # allow_pickle=True must be explicit: its default became False in
    # NumPy 1.16.3, which otherwise makes this call raise ValueError.
    # encoding='latin1' allows Python 3 to read pickles written by Python 2.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    param_dict = np.load(path, encoding='latin1', allow_pickle=True).item()
    predict_func = OfflinePredictor(PredictConfig(
        model=Model(),
        session_init=ParamRestore(param_dict),
        input_names=['input'],
        output_names=['prob']
    ))

    im = cv2.imread(input)
    # cv2.imread returns None (instead of raising) when the file is unreadable.
    assert im is not None, input
    # Resize to the network input size, reverse the channel axis
    # (OpenCV loads BGR, so this gives RGB), add a batch dimension of 1 and
    # subtract a constant 110 (presumably an approximate mean pixel value --
    # TODO confirm against the preprocessing used to train the weights).
    im = cv2.resize(im, (227, 227))[:, :, ::-1].reshape(
        (1, 227, 227, 3)).astype('float32') - 110
    # First (and only) requested output is 'prob'; take the single batch item.
    outputs = predict_func([im])[0]
    prob = outputs[0]
    # Indices of the 10 largest probabilities, most likely first.
    ret = prob.argsort()[-10:][::-1]
    print("Top10 predictions:", ret)

    meta = ILSVRCMeta().get_synset_words_1000()
    print("Top10 class names:", [meta[k] for k in ret])


if __name__ == '__main__':
    # Command-line entry point: parse options, optionally select GPUs,
    # then classify the given image.
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    cli.add_argument(
        '--load',
        help='.npy model file generated by tensorpack.utils.loadcaffe',
        required=True)
    cli.add_argument('--input', required=True, help='an input image')
    opts = cli.parse_args()

    gpu_list = opts.gpu
    if gpu_list:
        # Restrict the GPUs visible to CUDA for this process.
        os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

    # Run AlexNet with the given model (in npy format) on the input image.
    run_test(opts.load, opts.input)
