Commit 09d1e881 authored by Yuxin Wu's avatar Yuxin Wu

fix test script

parent 50859d25
...@@ -78,7 +78,7 @@ def get_model(inputs, is_training): ...@@ -78,7 +78,7 @@ def get_model(inputs, is_training):
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')]) log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir) logger.set_logger_file(os.path.join(log_dir, 'training.log'))
dataset_train = FakeData([(227,227,3), tuple()], 10) dataset_train = FakeData([(227,227,3), tuple()], 10)
dataset_train = BatchData(dataset_train, 10) dataset_train = BatchData(dataset_train, 10)
...@@ -158,7 +158,7 @@ if __name__ == '__main__': ...@@ -158,7 +158,7 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
#start_train(get_config()) start_train(get_config())
# run alexnet with given model (in npy format) # run alexnet with given model (in npy format)
run_test('alexnet-tuned.npy') run_test('alexnet-tuned.npy')
...@@ -11,8 +11,6 @@ import imp ...@@ -11,8 +11,6 @@ import imp
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils import sessinit from tensorpack.utils import sessinit
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.predict import DatasetPredictor
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(dest='config') parser.add_argument(dest='config')
...@@ -30,6 +28,3 @@ with tf.Graph().as_default() as G: ...@@ -30,6 +28,3 @@ with tf.Graph().as_default() as G:
init.init(sess) init.init(sess)
with sess.as_default(): with sess.as_default():
sessinit.dump_session_params(args.output) sessinit.dump_session_params(args.output)
...@@ -39,7 +39,7 @@ with tf.Graph().as_default() as G: ...@@ -39,7 +39,7 @@ with tf.Graph().as_default() as G:
ds = ImageFromFile(args.images, 3, resize=(227, 227)) ds = ImageFromFile(args.images, 3, resize=(227, 227))
predictor = DatasetPredictor(config, ds, batch=128) predictor = DatasetPredictor(config, ds, batch=128)
res = predictor.get_all_result() res = predictor.get_all_result()
res = [k[1] for k in res] res = [k.output for k in res]
if args.output_type == 'label': if args.output_type == 'label':
for r in res: for r in res:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from itertools import count from itertools import count, izip
import argparse import argparse
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
...@@ -93,7 +93,7 @@ def get_predict_func(config): ...@@ -93,7 +93,7 @@ def get_predict_func(config):
assert len(input_map) == len(dp), \ assert len(input_map) == len(dp), \
"Graph has {} inputs but dataset only gives {} components!".format( "Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp)) len(input_map), len(dp))
feed = dict(zip(input_map, dp)) feed = dict(izip(input_map, dp))
if output_var_names is not None: if output_var_names is not None:
results = sess.run(output_vars, feed_dict=feed) results = sess.run(output_vars, feed_dict=feed)
return results return results
......
...@@ -56,6 +56,7 @@ class ParamRestore(SessionInit): ...@@ -56,6 +56,7 @@ class ParamRestore(SessionInit):
sess.run(var.assign(value)) sess.run(var.assign(value))
def dump_session_params(path): def dump_session_params(path):
""" dump value of all trainable variables to a dict"""
var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
result = {} result = {}
for v in var: for v in var:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment