Commit 09d1e881 authored by Yuxin Wu's avatar Yuxin Wu

fix test script

parent 50859d25
......@@ -78,7 +78,7 @@ def get_model(inputs, is_training):
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir)
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
dataset_train = FakeData([(227,227,3), tuple()], 10)
dataset_train = BatchData(dataset_train, 10)
......@@ -158,7 +158,7 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#start_train(get_config())
start_train(get_config())
# run alexnet with given model (in npy format)
run_test('alexnet-tuned.npy')
......@@ -11,8 +11,6 @@ import imp
from tensorpack.utils import *
from tensorpack.utils import sessinit
from tensorpack.dataflow import *
from tensorpack.predict import DatasetPredictor
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
......@@ -30,6 +28,3 @@ with tf.Graph().as_default() as G:
init.init(sess)
with sess.as_default():
sessinit.dump_session_params(args.output)
......@@ -39,7 +39,7 @@ with tf.Graph().as_default() as G:
ds = ImageFromFile(args.images, 3, resize=(227, 227))
predictor = DatasetPredictor(config, ds, batch=128)
res = predictor.get_all_result()
res = [k[1] for k in res]
res = [k.output for k in res]
if args.output_type == 'label':
for r in res:
......
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from itertools import count
from itertools import count, izip
import argparse
from collections import namedtuple
import numpy as np
......@@ -93,7 +93,7 @@ def get_predict_func(config):
assert len(input_map) == len(dp), \
"Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp))
feed = dict(zip(input_map, dp))
feed = dict(izip(input_map, dp))
if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed)
return results
......
......@@ -56,6 +56,7 @@ class ParamRestore(SessionInit):
sess.run(var.assign(value))
def dump_session_params(path):
""" dump value of all trainable variables to a dict"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
result = {}
for v in var:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment