Commit 50859d25 authored by Yuxin Wu

get_model doesn't need to return the output variables (it now returns only the cost)

parent 27ff6a18
......@@ -15,7 +15,7 @@ from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
BATCH_SIZE = 10
......@@ -25,7 +25,7 @@ CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training):
# img: 227x227x3
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 0.0)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
......@@ -73,7 +73,7 @@ def get_model(inputs, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
return tf.add_n([wd_cost, cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
......
......@@ -83,7 +83,7 @@ def get_model(inputs, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*') # monitor all variables
return [prob, nr_wrong], tf.add_n([cost, wd_cost], name='cost')
return tf.add_n([cost, wd_cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
......
......@@ -31,12 +31,10 @@ def get_model(inputs, is_training):
label_var: bx1 integer
is_training: a python bool variable
Returns:
(outputs, cost)
outputs: a list of output variable
cost: the cost to minimize. scalar variable
the cost to minimize. scalar variable
"""
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 0.0)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
......@@ -83,7 +81,7 @@ def get_model(inputs, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
return tf.add_n([wd_cost, cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
......@@ -126,6 +124,7 @@ def get_config():
callbacks=Callbacks([
SummaryWriter(print_tag=['train_cost', 'train_error']),
PeriodicSaver(),
#ValidationCallback(dataset_test, 'test')
ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
......
......@@ -31,7 +31,7 @@ def create_test_graph():
))
for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v)
output_vars, cost = forward_func(input_vars, is_training=False)
cost = forward_func(input_vars, is_training=False)
yield Gtest
@contextmanager
......
......@@ -6,6 +6,7 @@
import tensorflow as tf
from itertools import count
import argparse
from collections import namedtuple
import numpy as np
from tqdm import tqdm
......@@ -43,8 +44,8 @@ class PredictConfig(object):
return a tuple of output list as well as the cost to minimize
output_var_names: a list of names of the output variable to predict, the
variables can be any computable tensor in the graph.
if None, will predict everything returned by `get_model_func`
(all outputs as well as the cost). Predict only specific output
if None, will only calculate the cost returned by `get_model_func`.
Predict only specific output (instead of the cost)
might be faster and might require only some of the input variables.
"""
def assert_type(v, tp):
......@@ -66,14 +67,13 @@ def get_predict_func(config):
A prediction function that takes a list of inputs value, and return
one/a list of output values.
If `output_var_names` is set, then the prediction function will
return a list of output values. If not, will return a list of output
values and a cost.
return a list of output values. If not, will return a cost.
"""
output_var_names = config.output_var_names
# input/output variables
input_vars = config.inputs
output_vars, cost_var = config.get_model_func(input_vars, is_training=False)
cost_var = config.get_model_func(input_vars, is_training=False)
input_map = config.input_dataset_mapping
if input_map is None:
input_map = input_vars
......@@ -81,6 +81,8 @@ def get_predict_func(config):
# check output_var_names against output_vars
if output_var_names is not None:
output_vars = [tf.get_default_graph().get_tensor_by_name(n) for n in output_var_names]
else:
output_vars = []
describe_model()
......@@ -96,12 +98,13 @@ def get_predict_func(config):
results = sess.run(output_vars, feed_dict=feed)
return results
else:
results = sess.run([cost_var] + output_vars, feed_dict=feed)
results = sess.run([cost_var], feed_dict=feed)
cost = results[0]
outputs = results[1:]
return outputs, cost
return cost
return run_input
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class DatasetPredictor(object):
def __init__(self, predict_config, dataset, batch=0):
"""
......@@ -118,7 +121,7 @@ class DatasetPredictor(object):
""" a generator to return prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
yield [dp, self.predict_func(dp)]
yield PredictResult(dp, self.predict_func(dp))
pbar.update()
def get_all_result(self):
......
......@@ -36,8 +36,8 @@ class TrainConfig(object):
the dataset
input_queue: the queue used for input. default to a FIFO queue
with capacity 5
get_model_func: a function taking `inputs` and `is_training` and
return a tuple of output list as well as the cost to minimize
get_model_func: a function taking `inputs` and `is_training`, and
return the cost to minimize
batched_model_input: boolean. If yes, `get_model_func` expected batched
input in training. Otherwise, expect single data point in
training, so that you may do pre-processing and batch them
......@@ -127,7 +127,7 @@ def start_train(config):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
model_inputs = get_model_inputs()
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
cost_var = config.get_model_func(model_inputs, is_training=True)
grads.append(
config.optimizer.compute_gradients(cost_var))
......@@ -141,7 +141,7 @@ def start_train(config):
grads = average_gradients(grads)
else:
model_inputs = get_model_inputs()
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
cost_var = config.get_model_func(model_inputs, is_training=True)
grads = config.optimizer.compute_gradients(cost_var)
summary_grads(grads)
avg_maintain_op = summary_moving_average(cost_var)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment