Commit 646a3c6f authored by ppwwyyxx's avatar ppwwyyxx

timing for callback

parent 65a3052f
......@@ -12,9 +12,10 @@ import tensorflow as tf
import numpy as np
import os
from utils import logger
from layers import *
from utils import *
from utils.symbolic_functions import *
from utils.summary import *
from dataflow.dataset import Mnist
from dataflow import *
......
......@@ -77,6 +77,7 @@ def start_train(config):
keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
for epoch in xrange(1, max_epoch):
with timed_operation('epoch {}'.format(epoch)):
for dp in dataset_train.get_data():
feed = {keep_prob_var: 0.5}
feed.update(dict(zip(input_vars, dp)))
......
......@@ -5,7 +5,10 @@
from pkgutil import walk_packages
import os
import os.path
import time
import sys
from contextlib import contextmanager
import logger
def global_import(name):
p = __import__(name, globals(), locals())
......@@ -13,7 +16,16 @@ def global_import(name):
for k in lst:
globals()[k] = p.__dict__[k]
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
global_import('naming')
global_import('callback')
global_import('validation_callback')
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('finished {}, time={:.2f}sec.'.format(
msg, time.time() - start))
......@@ -7,9 +7,11 @@ import tensorflow as tf
import sys
import numpy as np
import os
import time
from abc import abstractmethod
from .naming import *
import logger
class Callback(object):
def before_train(self):
......@@ -107,7 +109,22 @@ class Callbacks(Callback):
cb.trigger_step(inputs, outputs, cost)
def trigger_epoch(self):
start = time.time()
times = []
for cb in self.callbacks:
s = time.time()
cb.trigger_epoch()
times.append(time.time() - s)
self.writer.flush()
tot = time.time() - start
# log the time of some heavy callbacks
if tot < 3:
return
msgs = []
for idx, t in enumerate(times):
if t / tot > 0.3 and t > 1:
msgs.append("{}:{}".format(
type(self.callbacks[idx]).__name__, t))
logger.info("Callbacks took {} sec. {}".format(tot, ' '.join(msgs)))
......@@ -4,8 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
__all__ = ['StatCounter', 'Accuracy']
class StatCounter(object):
def __init__(self):
self.reset()
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: utils.py
# File: summary.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
__all__ = ['create_summary', 'add_histogram_summary', 'add_activation_summary']
def create_summary(name, v):
"""
Return a tf.Summary object with name and simple value v
......
......@@ -65,5 +65,5 @@ class ValidationError(PeriodicCallback):
cost_avg),
self.epoch_num)
logger.info(
"{} validation after epoch {}: err={}, cost={}".format(
"{} validation after epoch {}: err={:.4f}, cost={:.3f}".format(
self.prefix, self.epoch_num, err_stat.accuracy, cost_avg))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment