#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: load-alexnet.py
# Author: Yuxin Wu

from __future__ import print_function
import argparse
import numpy as np
import os
import cv2
import tensorflow as tf

from tensorpack import *
from tensorpack.dataflow.dataset import ILSVRCMeta
from tensorpack.tfutils.summary import *


def tower_func(image):
    """Build the AlexNet inference graph and expose its softmax output as ``prob``.

    The layer names used here (conv1..conv5, fc6..fc8) are presumably the keys
    the restored parameter dict must match -- keep them unchanged.

    Args:
        image: input tensor; the input signature built in ``run_test`` declares
            it as float32 with shape (None, 227, 227, 3).
    """
    def _lrn(x, name):
        # The same LRN hyper-parameters follow both conv1 and conv2.
        return tf.nn.lrn(x, depth_radius=2, bias=1.0, alpha=2e-5, beta=0.75, name=name)

    def _pool(name, x):
        # 3x3 max-pool with stride 2 and no padding (after norm1, norm2 and conv5).
        return MaxPooling(name, x, 3, strides=2, padding='VALID')

    # Every Conv2D / FullyConnected created inside this scope gets a ReLU.
    with argscope([Conv2D, FullyConnected], activation=tf.nn.relu):
        net = Conv2D('conv1', image, filters=96, kernel_size=11, strides=4, padding='VALID')
        net = _pool('pool1', _lrn(net, 'norm1'))

        # split=2: tensorpack's grouped-convolution option (two channel groups).
        net = Conv2D('conv2', net, filters=256, kernel_size=5, split=2)
        net = _pool('pool2', _lrn(net, 'norm2'))

        net = Conv2D('conv3', net, filters=384, kernel_size=3)
        net = Conv2D('conv4', net, filters=384, kernel_size=3, split=2)
        net = Conv2D('conv5', net, filters=256, kernel_size=3, split=2)
        net = _pool('pool3', net)

        # Inference only: the dropout layers of the training network are left out.
        net = FullyConnected('fc6', net, 4096)
        net = FullyConnected('fc7', net, 4096)

    # fc8 is built outside the argscope, so the scope's ReLU is not applied;
    # its output is used directly as the logits.
    logits = FullyConnected('fc8', net, 1000)
    tf.nn.softmax(logits, name='prob')


def run_test(path, input):
    """Load AlexNet weights and print the top-10 ImageNet predictions for one image.

    Args:
        path (str): ``.npz`` archive of model parameters. Every array in it is
            passed to ``DictRestore`` under its name in the archive.
        input (str): path of an image file readable by ``cv2.imread``.

    Raises:
        ValueError: if OpenCV cannot read ``input`` as an image.
    """
    # Copy every array into a plain dict, then close the archive. A bare
    # ``dict(np.load(path))`` would leave the NpzFile's file handle open.
    with np.load(path) as npz:
        param_dict = dict(npz)
    predictor = OfflinePredictor(PredictConfig(
        input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
        tower_func=tower_func,
        session_init=DictRestore(param_dict),
        input_names=['input'],
        output_names=['prob']
    ))

    im = cv2.imread(input)
    # cv2.imread reports failure by returning None instead of raising. Check
    # explicitly: an ``assert`` would be stripped under ``python -O``.
    if im is None:
        raise ValueError("Cannot read image: {}".format(input))
    # Resize first (still 8-bit, BGR order as loaded by OpenCV), then add a batch
    # axis, reverse the channel axis (BGR -> RGB), convert to float32 and
    # subtract a constant mean of 110.
    # NOTE(review): presumably this matches the preprocessing the converted
    # weights expect -- confirm.
    im = cv2.resize(im, (227, 227))
    im = im[None, :, :, ::-1].astype('float32') - 110

    # First (only) requested output is 'prob'; take the single image in the batch.
    prob = predictor(im)[0][0]
    # Indices of the 10 largest probabilities, highest first.
    ret = prob.argsort()[-10:][::-1]
    print("Top10 predictions:", ret)

    meta = ILSVRCMeta().get_synset_words_1000()
    print("Top10 class names:", [meta[k] for k in ret])


if __name__ == '__main__':
    # Command-line entry point: parse options, optionally pin GPUs, run inference.
    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    cli.add_argument(
        '--load', required=True,
        help='.npz model file generated by tensorpack.utils.loadcaffe')
    cli.add_argument('--input', required=True, help='an input image')
    opts = cli.parse_args()

    # Restrict the visible CUDA devices only when --gpu was given (and non-empty).
    if opts.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu

    # Run AlexNet with the npz weights on the single input image.
    run_test(path=opts.load, input=opts.input)
