Commit ee0bca2d authored by ppwwyyxx's avatar ppwwyyxx

use 'dataset' as key

parent 3629d9ca
...@@ -113,7 +113,7 @@ def get_config(): ...@@ -113,7 +113,7 @@ def get_config():
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return dict( return dict(
dataset_train=dataset_train, dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callback=Callbacks([ callback=Callbacks([
SummaryWriter(), SummaryWriter(),
......
...@@ -126,7 +126,7 @@ def get_config(): ...@@ -126,7 +126,7 @@ def get_config():
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return dict( return dict(
dataset_train=dataset_train, dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callback=Callbacks([ callback=Callbacks([
SummaryWriter(), SummaryWriter(),
......
...@@ -18,8 +18,8 @@ def start_infer(config): ...@@ -18,8 +18,8 @@ def start_infer(config):
Args: Args:
config: a tensorpack config dictionary config: a tensorpack config dictionary
""" """
dataset_test = config['dataset_test'] dataset = config['dataset']
assert isinstance(dataset_test, DataFlow), dataset_test.__class__ assert isinstance(dataset, DataFlow), dataset.__class__
# a tf.ConfigProto instance # a tf.ConfigProto instance
sess_config = config.get('session_config', None) sess_config = config.get('session_config', None)
...@@ -53,7 +53,7 @@ def start_infer(config): ...@@ -53,7 +53,7 @@ def start_infer(config):
with sess.as_default(): with sess.as_default():
with timed_operation('running one batch'): with timed_operation('running one batch'):
for dp in dataset_test.get_data(): for dp in dataset.get_data():
feed = dict(zip(input_vars, dp)) feed = dict(zip(input_vars, dp))
fetches = [cost_var] + output_vars fetches = [cost_var] + output_vars
results = sess.run(fetches, feed_dict=feed) results = sess.run(fetches, feed_dict=feed)
......
...@@ -39,8 +39,8 @@ def start_train(config): ...@@ -39,8 +39,8 @@ def start_train(config):
Args: Args:
config: a tensorpack config dictionary config: a tensorpack config dictionary
""" """
dataset_train = config['dataset_train'] dataset = config['dataset']
assert isinstance(dataset_train, DataFlow), dataset_train.__class__ assert isinstance(dataset, DataFlow), dataset.__class__
# a tf.train.Optimizer instance # a tf.train.Optimizer instance
optimizer = config['optimizer'] optimizer = config['optimizer']
...@@ -90,7 +90,7 @@ def start_train(config): ...@@ -90,7 +90,7 @@ def start_train(config):
# start training: # start training:
coord = tf.train.Coordinator() coord = tf.train.Coordinator()
# a thread that keeps filling the queue # a thread that keeps filling the queue
input_th = EnqueueThread(sess, coord, enqueue_op, dataset_train) input_th = EnqueueThread(sess, coord, enqueue_op, dataset)
model_th = tf.train.start_queue_runners( model_th = tf.train.start_queue_runners(
sess=sess, coord=coord, daemon=True, start=True) sess=sess, coord=coord, daemon=True, start=True)
input_th.start() input_th.start()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment