Commit 76514ac7 authored by Yuxin Wu's avatar Yuxin Wu

avoid consuming data for summary, in SimpleTrainer

parent 4cb898a4
...@@ -161,6 +161,8 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -161,6 +161,8 @@ class ModelFromMetaGraph(ModelDesc):
Only useful for inference. Only useful for inference.
""" """
# TODO can this be really used for inference?
def __init__(self, filename): def __init__(self, filename):
""" """
Args: Args:
......
...@@ -25,8 +25,8 @@ class FeedfreeTrainerBase(Trainer): ...@@ -25,8 +25,8 @@ class FeedfreeTrainerBase(Trainer):
""" """
def _trigger_epoch(self): def _trigger_epoch(self):
# need to run summary_op every epoch # run summary_op every epoch
# note that summary_op will take a data from the queue # TODO summary_op will take a data! This is not good for TensorInput.
if self.summary_op is not None: if self.summary_op is not None:
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self.add_summary(summary_str) self.add_summary(summary_str)
......
...@@ -45,8 +45,12 @@ class FeedInput(InputData): ...@@ -45,8 +45,12 @@ class FeedInput(InputData):
def next_feed(self): def next_feed(self):
data = next(self.data_producer) data = next(self.data_producer)
feed = dict(zip(self.input_vars, data)) feed = dict(zip(self.input_vars, data))
self._last_feed = feed
return feed return feed
def last_feed(self):
return self._last_feed
class FeedfreeInput(InputData): class FeedfreeInput(InputData):
""" Abstract base for input without feed, """ Abstract base for input without feed,
......
...@@ -66,7 +66,7 @@ class SimpleTrainer(Trainer): ...@@ -66,7 +66,7 @@ class SimpleTrainer(Trainer):
self._predictor_factory = PredictorFactory(self.sess, self.model, [0]) self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
if config.dataflow is None: if config.dataflow is None:
self._input_method = config.data self._input_method = config.data
assert isinstance(self._input_method, FeedInput) assert isinstance(self._input_method, FeedInput), type(self._input_method)
else: else:
self._input_method = FeedInput(config.dataflow) self._input_method = FeedInput(config.dataflow)
...@@ -93,7 +93,7 @@ class SimpleTrainer(Trainer): ...@@ -93,7 +93,7 @@ class SimpleTrainer(Trainer):
def _trigger_epoch(self): def _trigger_epoch(self):
if self.summary_op is not None: if self.summary_op is not None:
feed = self._input_method.next_feed() feed = self._input_method.last_feed()
summary_str = self.summary_op.eval(feed_dict=feed) summary_str = self.summary_op.eval(feed_dict=feed)
self.add_summary(summary_str) self.add_summary(summary_str)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment