Commit 2bec1d62 authored by Yuxin Wu

remove Trainer._process_summary and use a general Trainer.add_summary

parent 6d67faf9
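At the call sites this rename is mechanical: every trainer.write_scalar_summary(name, val) becomes trainer.add_scalar_summary(name, val), and trainers that used to feed a serialized summary string into Trainer._process_summary now parse it themselves and pass a tf.Summary object to the new Trainer.add_summary. A minimal sketch of an affected callback after the commit (the callback class and its metric are made up for illustration):

from tensorpack.callbacks import Callback

class ScoreLogger(Callback):
    """Hypothetical callback showing the renamed scalar-summary call."""

    def _trigger_epoch(self):
        mean_score = self._compute_mean_score()  # placeholder metric
        # previously: self.trainer.write_scalar_summary('mean_score', mean_score)
        self.trainer.add_scalar_summary('mean_score', mean_score)

    def _compute_mean_score(self):
        return 0.0  # stand-in value for the sketch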
......@@ -114,5 +114,5 @@ class Evaluator(Callback):
t = time.time() - t
if t > 10 * 60: # eval takes too long
self.eval_episode = int(self.eval_episode * 0.94)
-self.trainer.write_scalar_summary('mean_score', mean)
-self.trainer.write_scalar_summary('max_score', max)
+self.trainer.add_scalar_summary('mean_score', mean)
+self.trainer.add_scalar_summary('max_score', max)
......@@ -135,7 +135,7 @@ def get_config():
FeedfreeInferenceRunner(val_data, [ScalarStats(['cost'])]),
CallbackFactory(
trigger_epoch=lambda self:
-self.trainer.write_scalar_summary(
+self.trainer.add_scalar_summary(
'validation_perplexity',
np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN))),
RunOp(lambda: M.reset_lstm_state()),
......
......@@ -207,8 +207,8 @@ class ExpReplay(DataFlow, Callback):
for k, v in six.iteritems(stats):
try:
mean, max = np.mean(v), np.max(v)
-self.trainer.write_scalar_summary('expreplay/mean_' + k, mean)
-self.trainer.write_scalar_summary('expreplay/max_' + k, max)
+self.trainer.add_scalar_summary('expreplay/mean_' + k, mean)
+self.trainer.add_scalar_summary('expreplay/max_' + k, max)
except:
pass
self.player.reset_stat()
......
......@@ -58,7 +58,7 @@ def summary_inferencer(trainer, infs):
except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue
-trainer.write_scalar_summary(k, v)
+trainer.add_scalar_summary(k, v)
class InferenceRunner(Callback):
......
......@@ -13,11 +13,11 @@ from .tower import get_current_tower_context
from . import get_global_step_var
from .symbolic_functions import rms
-__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
+__all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
'add_moving_summary', 'summary_moving_average']
-def create_summary(name, v):
+def create_scalar_summary(name, v):
"""
Returns:
tf.Summary: a tf.Summary object with name and simple scalar value v.
......
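Judging from the docstring above, the renamed create_scalar_summary simply wraps a name and a float into a one-element tf.Summary protobuf. A minimal sketch of equivalent behaviour (not the exact library code):

import tensorflow as tf

def create_scalar_summary(name, v):
    # Wrap a single scalar value into a tf.Summary protobuf.
    s = tf.Summary()
    s.value.add(tag=name, simple_value=float(v))
    return s

# Example: a summary whose only value has tag 'validation_perplexity'.
summary = create_scalar_summary('validation_perplexity', 87.3)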
......@@ -16,7 +16,7 @@ from ..utils.timer import timed_operation
from ..callbacks import StatHolder
from ..tfutils import get_global_step, get_global_step_var
from ..tfutils.modelutils import describe_model
-from ..tfutils.summary import create_summary
+from ..tfutils.summary import create_scalar_summary
__all__ = ['Trainer', 'StopTraining']
......@@ -96,8 +96,14 @@ class Trainer(object):
def _trigger_epoch(self):
pass
-def _process_summary(self, summary_str):
-summary = tf.Summary.FromString(summary_str)
+def add_summary(self, summary):
+"""
+Add summary to ``self.summary_writer``, and also
+add scalar summary to ``self.stat_holder``.
+Args:
+summary (tf.Summary): a summary object.
+"""
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
......@@ -107,17 +113,15 @@ class Trainer(object):
self.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, get_global_step())
-def write_scalar_summary(self, name, val):
+def add_scalar_summary(self, name, val):
"""
-Write a scalar summary to both TF events file and StatHolder.
+Add a scalar summary to both TF events file and StatHolder.
Args:
name(str)
val(float)
"""
-self.summary_writer.add_summary(
-create_summary(name, val), get_global_step())
-self.stat_holder.add_stat(name, val)
+self.add_summary(create_scalar_summary(name, val))
def setup(self):
"""
......
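With the generalized add_summary, a callback can hand the trainer a complete tf.Summary rather than only a named scalar; simple_value entries are still mirrored into the StatHolder while everything is written to the events file. A hedged usage sketch (callback name and tag invented for illustration):

import tensorflow as tf
from tensorpack.callbacks import Callback

class CustomSummaryLogger(Callback):
    """Hypothetical callback pushing a hand-built tf.Summary through add_summary."""

    def _trigger_epoch(self):
        s = tf.Summary()
        s.value.add(tag='custom/figure_of_merit', simple_value=0.5)
        # Goes to the events file; the simple_value is also recorded
        # in the trainer's StatHolder.
        self.trainer.add_summary(s)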
......@@ -29,7 +29,7 @@ class FeedfreeTrainerBase(Trainer):
# note that summary_op will take a data from the queue
if self.summary_op is not None:
summary_str = self.summary_op.eval()
-self._process_summary(summary_str)
+self.add_summary(tf.Summary.FromString(summary_str))
def _get_input_tensors(self):
return self._input_method.get_input_tensors()
......
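Here the trainer parses the serialized string returned by summary_op.eval() itself, via tf.Summary.FromString, before handing the protobuf to add_summary. A small standalone round trip illustrating that conversion (the tag and value are made up):

import tensorflow as tf

# A summary op evaluates to a serialized protobuf string...
serialized = tf.Summary(
    value=[tf.Summary.Value(tag='train/cost', simple_value=1.23)]).SerializeToString()

# ...which tf.Summary.FromString turns back into the object add_summary expects.
summary = tf.Summary.FromString(serialized)
assert summary.value[0].tag == 'train/cost'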
......@@ -209,7 +209,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
if self.config.tower > 1:
async_step_total_cnt = int(re.findall(
'[0-9]+', self.async_step_counter.__str__())[0])
-self.write_scalar_summary(
+self.add_scalar_summary(
'async_global_step', async_step_total_cnt)
except:
logger.exception("Cannot log async_global_step")
......
......@@ -95,7 +95,7 @@ class SimpleTrainer(Trainer):
if self.summary_op is not None:
feed = self._input_method.next_feed()
summary_str = self.summary_op.eval(feed_dict=feed)
-self._process_summary(summary_str)
+self.add_summary(tf.Summary.FromString(summary_str))
def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0)
......