Commit 2bec1d62 authored by Yuxin Wu's avatar Yuxin Wu

remove Trainer._process_summary and use a general Trainer.add_summary

parent 6d67faf9
...@@ -114,5 +114,5 @@ class Evaluator(Callback): ...@@ -114,5 +114,5 @@ class Evaluator(Callback):
t = time.time() - t t = time.time() - t
if t > 10 * 60: # eval takes too long if t > 10 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.94) self.eval_episode = int(self.eval_episode * 0.94)
self.trainer.write_scalar_summary('mean_score', mean) self.trainer.add_scalar_summary('mean_score', mean)
self.trainer.write_scalar_summary('max_score', max) self.trainer.add_scalar_summary('max_score', max)
...@@ -135,7 +135,7 @@ def get_config(): ...@@ -135,7 +135,7 @@ def get_config():
FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]), FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]),
CallbackFactory( CallbackFactory(
trigger_epoch=lambda self: trigger_epoch=lambda self:
self.trainer.write_scalar_summary( self.trainer.add_scalar_summary(
'validation_perplexity', 'validation_perplexity',
np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN))), np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN))),
RunOp(lambda: M.reset_lstm_state()), RunOp(lambda: M.reset_lstm_state()),
......
...@@ -207,8 +207,8 @@ class ExpReplay(DataFlow, Callback): ...@@ -207,8 +207,8 @@ class ExpReplay(DataFlow, Callback):
for k, v in six.iteritems(stats): for k, v in six.iteritems(stats):
try: try:
mean, max = np.mean(v), np.max(v) mean, max = np.mean(v), np.max(v)
self.trainer.write_scalar_summary('expreplay/mean_' + k, mean) self.trainer.add_scalar_summary('expreplay/mean_' + k, mean)
self.trainer.write_scalar_summary('expreplay/max_' + k, max) self.trainer.add_scalar_summary('expreplay/max_' + k, max)
except: except:
pass pass
self.player.reset_stat() self.player.reset_stat()
......
...@@ -58,7 +58,7 @@ def summary_inferencer(trainer, infs): ...@@ -58,7 +58,7 @@ def summary_inferencer(trainer, infs):
except: except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__)) logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue continue
trainer.write_scalar_summary(k, v) trainer.add_scalar_summary(k, v)
class InferenceRunner(Callback): class InferenceRunner(Callback):
......
...@@ -13,11 +13,11 @@ from .tower import get_current_tower_context ...@@ -13,11 +13,11 @@ from .tower import get_current_tower_context
from . import get_global_step_var from . import get_global_step_var
from .symbolic_functions import rms from .symbolic_functions import rms
__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary', __all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary', 'summary_moving_average'] 'add_moving_summary', 'summary_moving_average']
def create_summary(name, v): def create_scalar_summary(name, v):
""" """
Returns: Returns:
tf.Summary: a tf.Summary object with name and simple scalar value v. tf.Summary: a tf.Summary object with name and simple scalar value v.
......
...@@ -16,7 +16,7 @@ from ..utils.timer import timed_operation ...@@ -16,7 +16,7 @@ from ..utils.timer import timed_operation
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..tfutils import get_global_step, get_global_step_var from ..tfutils import get_global_step, get_global_step_var
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..tfutils.summary import create_summary from ..tfutils.summary import create_scalar_summary
__all__ = ['Trainer', 'StopTraining'] __all__ = ['Trainer', 'StopTraining']
...@@ -96,8 +96,14 @@ class Trainer(object): ...@@ -96,8 +96,14 @@ class Trainer(object):
def _trigger_epoch(self): def _trigger_epoch(self):
pass pass
def _process_summary(self, summary_str): def add_summary(self, summary):
summary = tf.Summary.FromString(summary_str) """
Add summary to ``self.summary_writer``, and also
add scalar summary to ``self.stat_holder``.
Args:
summary (tf.Summary): a summary object.
"""
for val in summary.value: for val in summary.value:
if val.WhichOneof('value') == 'simple_value': if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
...@@ -107,17 +113,15 @@ class Trainer(object): ...@@ -107,17 +113,15 @@ class Trainer(object):
self.stat_holder.add_stat(val.tag, val.simple_value) self.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, get_global_step()) self.summary_writer.add_summary(summary, get_global_step())
def write_scalar_summary(self, name, val): def add_scalar_summary(self, name, val):
""" """
Write a scalar sumary to both TF events file and StatHolder. Add a scalar sumary to both TF events file and StatHolder.
Args: Args:
name(str) name(str)
val(float) val(float)
""" """
self.summary_writer.add_summary( self.add_summary(create_scalar_summary(name, val))
create_summary(name, val), get_global_step())
self.stat_holder.add_stat(name, val)
def setup(self): def setup(self):
""" """
......
...@@ -29,7 +29,7 @@ class FeedfreeTrainerBase(Trainer): ...@@ -29,7 +29,7 @@ class FeedfreeTrainerBase(Trainer):
# note that summary_op will take a data from the queue # note that summary_op will take a data from the queue
if self.summary_op is not None: if self.summary_op is not None:
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self._process_summary(summary_str) self.add_summary(tf.Summary.FromString(summary_str))
def _get_input_tensors(self): def _get_input_tensors(self):
return self._input_method.get_input_tensors() return self._input_method.get_input_tensors()
......
...@@ -209,7 +209,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -209,7 +209,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
if self.config.tower > 1: if self.config.tower > 1:
async_step_total_cnt = int(re.findall( async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0]) '[0-9]+', self.async_step_counter.__str__())[0])
self.write_scalar_summary( self.add_scalar_summary(
'async_global_step', async_step_total_cnt) 'async_global_step', async_step_total_cnt)
except: except:
logger.exception("Cannot log async_global_step") logger.exception("Cannot log async_global_step")
......
...@@ -95,7 +95,7 @@ class SimpleTrainer(Trainer): ...@@ -95,7 +95,7 @@ class SimpleTrainer(Trainer):
if self.summary_op is not None: if self.summary_op is not None:
feed = self._input_method.next_feed() feed = self._input_method.next_feed()
summary_str = self.summary_op.eval(feed_dict=feed) summary_str = self.summary_op.eval(feed_dict=feed)
self._process_summary(summary_str) self.add_summary(tf.Summary.FromString(summary_str))
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0) return self._predictor_factory.get_predictor(input_names, output_names, 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment