Commit 50859d25 authored by Yuxin Wu's avatar Yuxin Wu

don't need get_model to return output variable

parent 27ff6a18
...@@ -15,7 +15,7 @@ from tensorpack.models import * ...@@ -15,7 +15,7 @@ from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import * from tensorpack.utils.summary import *
from tensorpack.utils.callback import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
BATCH_SIZE = 10 BATCH_SIZE = 10
...@@ -25,7 +25,7 @@ CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE ...@@ -25,7 +25,7 @@ CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE
def get_model(inputs, is_training): def get_model(inputs, is_training):
# img: 227x227x3 # img: 227x227x3
is_training = bool(is_training) is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 0.0) keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs image, label = inputs
...@@ -73,7 +73,7 @@ def get_model(inputs, is_training): ...@@ -73,7 +73,7 @@ def get_model(inputs, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W add_param_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost') return tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
......
...@@ -83,7 +83,7 @@ def get_model(inputs, is_training): ...@@ -83,7 +83,7 @@ def get_model(inputs, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*') # monitor all variables add_param_summary('.*') # monitor all variables
return [prob, nr_wrong], tf.add_n([cost, wd_cost], name='cost') return tf.add_n([cost, wd_cost], name='cost')
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
......
...@@ -31,12 +31,10 @@ def get_model(inputs, is_training): ...@@ -31,12 +31,10 @@ def get_model(inputs, is_training):
label_var: bx1 integer label_var: bx1 integer
is_training: a python bool variable is_training: a python bool variable
Returns: Returns:
(outputs, cost) the cost to minimize. scalar variable
outputs: a list of output variable
cost: the cost to minimize. scalar variable
""" """
is_training = bool(is_training) is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 0.0) keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel image = tf.expand_dims(image, 3) # add a single channel
...@@ -83,7 +81,7 @@ def get_model(inputs, is_training): ...@@ -83,7 +81,7 @@ def get_model(inputs, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_param_summary('.*/W') # monitor histogram of all W add_param_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost') return tf.add_n([wd_cost, cost], name='cost')
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
...@@ -126,6 +124,7 @@ def get_config(): ...@@ -126,6 +124,7 @@ def get_config():
callbacks=Callbacks([ callbacks=Callbacks([
SummaryWriter(print_tag=['train_cost', 'train_error']), SummaryWriter(print_tag=['train_cost', 'train_error']),
PeriodicSaver(), PeriodicSaver(),
#ValidationCallback(dataset_test, 'test')
ValidationError(dataset_test, prefix='test'), ValidationError(dataset_test, prefix='test'),
]), ]),
session_config=sess_config, session_config=sess_config,
......
...@@ -31,7 +31,7 @@ def create_test_graph(): ...@@ -31,7 +31,7 @@ def create_test_graph():
)) ))
for v in input_vars: for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v) Gtest.add_to_collection(INPUT_VARS_KEY, v)
output_vars, cost = forward_func(input_vars, is_training=False) cost = forward_func(input_vars, is_training=False)
yield Gtest yield Gtest
@contextmanager @contextmanager
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
from itertools import count from itertools import count
import argparse import argparse
from collections import namedtuple
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
...@@ -43,8 +44,8 @@ class PredictConfig(object): ...@@ -43,8 +44,8 @@ class PredictConfig(object):
return a tuple of output list as well as the cost to minimize return a tuple of output list as well as the cost to minimize
output_var_names: a list of names of the output variable to predict, the output_var_names: a list of names of the output variable to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
if None, will predict everything returned by `get_model_func` if None, will only calculate the cost returned by `get_model_func`.
(all outputs as well as the cost). Predict only specific output Predict only specific output (instead of the cost)
might be faster and might require only some of the input variables. might be faster and might require only some of the input variables.
""" """
def assert_type(v, tp): def assert_type(v, tp):
...@@ -66,14 +67,13 @@ def get_predict_func(config): ...@@ -66,14 +67,13 @@ def get_predict_func(config):
A prediction function that takes a list of inputs value, and return A prediction function that takes a list of inputs value, and return
one/a list of output values. one/a list of output values.
If `output_var_names` is set, then the prediction function will If `output_var_names` is set, then the prediction function will
return a list of output values. If not, will return a list of output return a list of output values. If not, will return a cost.
values and a cost.
""" """
output_var_names = config.output_var_names output_var_names = config.output_var_names
# input/output variables # input/output variables
input_vars = config.inputs input_vars = config.inputs
output_vars, cost_var = config.get_model_func(input_vars, is_training=False) cost_var = config.get_model_func(input_vars, is_training=False)
input_map = config.input_dataset_mapping input_map = config.input_dataset_mapping
if input_map is None: if input_map is None:
input_map = input_vars input_map = input_vars
...@@ -81,6 +81,8 @@ def get_predict_func(config): ...@@ -81,6 +81,8 @@ def get_predict_func(config):
# check output_var_names against output_vars # check output_var_names against output_vars
if output_var_names is not None: if output_var_names is not None:
output_vars = [tf.get_default_graph().get_tensor_by_name(n) for n in output_var_names] output_vars = [tf.get_default_graph().get_tensor_by_name(n) for n in output_var_names]
else:
output_vars = []
describe_model() describe_model()
...@@ -96,12 +98,13 @@ def get_predict_func(config): ...@@ -96,12 +98,13 @@ def get_predict_func(config):
results = sess.run(output_vars, feed_dict=feed) results = sess.run(output_vars, feed_dict=feed)
return results return results
else: else:
results = sess.run([cost_var] + output_vars, feed_dict=feed) results = sess.run([cost_var], feed_dict=feed)
cost = results[0] cost = results[0]
outputs = results[1:] return cost
return outputs, cost
return run_input return run_input
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class DatasetPredictor(object): class DatasetPredictor(object):
def __init__(self, predict_config, dataset, batch=0): def __init__(self, predict_config, dataset, batch=0):
""" """
...@@ -118,7 +121,7 @@ class DatasetPredictor(object): ...@@ -118,7 +121,7 @@ class DatasetPredictor(object):
""" a generator to return prediction for each data""" """ a generator to return prediction for each data"""
with tqdm(total=self.ds.size()) as pbar: with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield [dp, self.predict_func(dp)] yield PredictResult(dp, self.predict_func(dp))
pbar.update() pbar.update()
def get_all_result(self): def get_all_result(self):
......
...@@ -36,8 +36,8 @@ class TrainConfig(object): ...@@ -36,8 +36,8 @@ class TrainConfig(object):
the dataset the dataset
input_queue: the queue used for input. default to a FIFO queue input_queue: the queue used for input. default to a FIFO queue
with capacity 5 with capacity 5
get_model_func: a function taking `inputs` and `is_training` and get_model_func: a function taking `inputs` and `is_training`, and
return a tuple of output list as well as the cost to minimize return the cost to minimize
batched_model_input: boolean. If yes, `get_model_func` expected batched batched_model_input: boolean. If yes, `get_model_func` expected batched
input in training. Otherwise, expect single data point in input in training. Otherwise, expect single data point in
training, so that you may do pre-processing and batch them training, so that you may do pre-processing and batch them
...@@ -127,7 +127,7 @@ def start_train(config): ...@@ -127,7 +127,7 @@ def start_train(config):
with tf.device('/gpu:{}'.format(i)), \ with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope: tf.name_scope('tower{}'.format(i)) as scope:
model_inputs = get_model_inputs() model_inputs = get_model_inputs()
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True) cost_var = config.get_model_func(model_inputs, is_training=True)
grads.append( grads.append(
config.optimizer.compute_gradients(cost_var)) config.optimizer.compute_gradients(cost_var))
...@@ -141,7 +141,7 @@ def start_train(config): ...@@ -141,7 +141,7 @@ def start_train(config):
grads = average_gradients(grads) grads = average_gradients(grads)
else: else:
model_inputs = get_model_inputs() model_inputs = get_model_inputs()
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True) cost_var = config.get_model_func(model_inputs, is_training=True)
grads = config.optimizer.compute_gradients(cost_var) grads = config.optimizer.compute_gradients(cost_var)
summary_grads(grads) summary_grads(grads)
avg_maintain_op = summary_moving_average(cost_var) avg_maintain_op = summary_moving_average(cost_var)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment