Commit 76514ac7 authored by Yuxin Wu's avatar Yuxin Wu

avoid consuming data for summary, in SimpleTrainer

parent 4cb898a4
......@@ -161,6 +161,8 @@ class ModelFromMetaGraph(ModelDesc):
Only useful for inference.
"""
# TODO can this be really used for inference?
def __init__(self, filename):
"""
Args:
......
......@@ -25,8 +25,8 @@ class FeedfreeTrainerBase(Trainer):
"""
def _trigger_epoch(self):
# need to run summary_op every epoch
# note that summary_op will take a data from the queue
# run summary_op every epoch
# TODO summary_op will take a data! This is not good for TensorInput.
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self.add_summary(summary_str)
......
......@@ -45,8 +45,12 @@ class FeedInput(InputData):
def next_feed(self):
data = next(self.data_producer)
feed = dict(zip(self.input_vars, data))
self._last_feed = feed
return feed
def last_feed(self):
return self._last_feed
class FeedfreeInput(InputData):
""" Abstract base for input without feed,
......
......@@ -66,7 +66,7 @@ class SimpleTrainer(Trainer):
self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
if config.dataflow is None:
self._input_method = config.data
assert isinstance(self._input_method, FeedInput)
assert isinstance(self._input_method, FeedInput), type(self._input_method)
else:
self._input_method = FeedInput(config.dataflow)
......@@ -93,7 +93,7 @@ class SimpleTrainer(Trainer):
def _trigger_epoch(self):
if self.summary_op is not None:
feed = self._input_method.next_feed()
feed = self._input_method.last_feed()
summary_str = self.summary_op.eval(feed_dict=feed)
self.add_summary(summary_str)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment