Commit 646a3c6f authored by ppwwyyxx's avatar ppwwyyxx

timing for callback

parent 65a3052f
...@@ -12,9 +12,10 @@ import tensorflow as tf ...@@ -12,9 +12,10 @@ import tensorflow as tf
import numpy as np import numpy as np
import os import os
from utils import logger
from layers import * from layers import *
from utils import * from utils import *
from utils.symbolic_functions import *
from utils.summary import *
from dataflow.dataset import Mnist from dataflow.dataset import Mnist
from dataflow import * from dataflow import *
......
...@@ -77,14 +77,15 @@ def start_train(config): ...@@ -77,14 +77,15 @@ def start_train(config):
keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME) keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
for epoch in xrange(1, max_epoch): for epoch in xrange(1, max_epoch):
for dp in dataset_train.get_data(): with timed_operation('epoch {}'.format(epoch)):
feed = {keep_prob_var: 0.5} for dp in dataset_train.get_data():
feed.update(dict(zip(input_vars, dp))) feed = {keep_prob_var: 0.5}
feed.update(dict(zip(input_vars, dp)))
results = sess.run(
[train_op, cost_var] + output_vars, feed_dict=feed) results = sess.run(
cost = results[1] [train_op, cost_var] + output_vars, feed_dict=feed)
outputs = results[2:] cost = results[1]
callbacks.trigger_step(feed, outputs, cost) outputs = results[2:]
callbacks.trigger_step(feed, outputs, cost)
callbacks.trigger_epoch()
callbacks.trigger_epoch()
...@@ -5,7 +5,10 @@ ...@@ -5,7 +5,10 @@
from pkgutil import walk_packages from pkgutil import walk_packages
import os import os
import os.path import time
import sys
from contextlib import contextmanager
import logger
def global_import(name): def global_import(name):
p = __import__(name, globals(), locals()) p = __import__(name, globals(), locals())
...@@ -13,7 +16,16 @@ def global_import(name): ...@@ -13,7 +16,16 @@ def global_import(name):
for k in lst: for k in lst:
globals()[k] = p.__dict__[k] globals()[k] = p.__dict__[k]
for _, module_name, _ in walk_packages( global_import('naming')
[os.path.dirname(__file__)]): global_import('callback')
if not module_name.startswith('_'): global_import('validation_callback')
global_import(module_name)
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('finished {}, time={:.2f}sec.'.format(
msg, time.time() - start))
...@@ -7,9 +7,11 @@ import tensorflow as tf ...@@ -7,9 +7,11 @@ import tensorflow as tf
import sys import sys
import numpy as np import numpy as np
import os import os
import time
from abc import abstractmethod from abc import abstractmethod
from .naming import * from .naming import *
import logger
class Callback(object): class Callback(object):
def before_train(self): def before_train(self):
...@@ -107,7 +109,22 @@ class Callbacks(Callback): ...@@ -107,7 +109,22 @@ class Callbacks(Callback):
cb.trigger_step(inputs, outputs, cost) cb.trigger_step(inputs, outputs, cost)
def trigger_epoch(self): def trigger_epoch(self):
start = time.time()
times = []
for cb in self.callbacks: for cb in self.callbacks:
s = time.time()
cb.trigger_epoch() cb.trigger_epoch()
times.append(time.time() - s)
self.writer.flush() self.writer.flush()
tot = time.time() - start
# log the time of some heavy callbacks
if tot < 3:
return
msgs = []
for idx, t in enumerate(times):
if t / tot > 0.3 and t > 1:
msgs.append("{}:{}".format(
type(self.callbacks[idx]).__name__, t))
logger.info("Callbacks took {} sec. {}".format(tot, ' '.join(msgs)))
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
__all__ = ['StatCounter', 'Accuracy']
class StatCounter(object): class StatCounter(object):
def __init__(self): def __init__(self):
self.reset() self.reset()
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: utils.py # File: summary.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
__all__ = ['create_summary', 'add_histogram_summary', 'add_activation_summary']
def create_summary(name, v): def create_summary(name, v):
""" """
Return a tf.Summary object with name and simple value v Return a tf.Summary object with name and simple value v
......
...@@ -65,5 +65,5 @@ class ValidationError(PeriodicCallback): ...@@ -65,5 +65,5 @@ class ValidationError(PeriodicCallback):
cost_avg), cost_avg),
self.epoch_num) self.epoch_num)
logger.info( logger.info(
"{} validation after epoch {}: err={}, cost={}".format( "{} validation after epoch {}: err={:.4f}, cost={:.3f}".format(
self.prefix, self.epoch_num, err_stat.accuracy, cost_avg)) self.prefix, self.epoch_num, err_stat.accuracy, cost_avg))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment