Commit e68ea2a0 authored by Yuxin Wu's avatar Yuxin Wu

Some improvements in logging; fixes #118.

parent 582cd482
...@@ -15,6 +15,3 @@ from tensorpack.predict import * ...@@ -15,6 +15,3 @@ from tensorpack.predict import *
if int(numpy.__version__.split('.')[1]) < 9: if int(numpy.__version__.split('.')[1]) < 9:
logger.warn("Numpy < 1.9 could be extremely slow on some tasks.") logger.warn("Numpy < 1.9 could be extremely slow on some tasks.")
if get_tf_version() < 10:
logger.error("tensorpack requires TensorFlow >= 0.10")
...@@ -168,7 +168,6 @@ class PeriodicCallback(ProxyCallback): ...@@ -168,7 +168,6 @@ class PeriodicCallback(ProxyCallback):
def _trigger_epoch(self): def _trigger_epoch(self):
if self.epoch_num % self.period == 0: if self.epoch_num % self.period == 0:
self.cb.epoch_num = self.epoch_num - 1
self.cb.trigger_epoch() self.cb.trigger_epoch()
def __str__(self): def __str__(self):
......
...@@ -27,6 +27,6 @@ class StartProcOrThread(Callback): ...@@ -27,6 +27,6 @@ class StartProcOrThread(Callback):
def _before_train(self): def _before_train(self):
logger.info("Starting " + logger.info("Starting " +
', '.join([k.name for k in self._procs_threads])) ', '.join([k.name for k in self._procs_threads]) + ' ...')
# avoid sigint get handled by other processes # avoid sigint get handled by other processes
start_proc_mask_signal(self._procs_threads) start_proc_mask_signal(self._procs_threads)
...@@ -13,6 +13,9 @@ __all__ = ['describe_model', 'get_shape_str'] ...@@ -13,6 +13,9 @@ __all__ = ['describe_model', 'get_shape_str']
def describe_model(): def describe_model():
""" Print a description of the current model parameters """ """ Print a description of the current model parameters """
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
if len(train_vars) == 0:
logger.info("No trainable variables in the graph!")
return
msg = [""] msg = [""]
total = 0 total = 0
for v in train_vars: for v in train_vars:
......
...@@ -68,25 +68,14 @@ class Trainer(object): ...@@ -68,25 +68,14 @@ class Trainer(object):
def run_step(self): def run_step(self):
""" Abstract method. Run one iteration. """ """ Abstract method. Run one iteration. """
def get_predict_func(self, input_names, output_names): def get_extra_fetches(self):
""" """
Args:
input_names (list), output_names(list): list of names
Returns: Returns:
an OnlinePredictor list: list of tensors/ops to fetch in each step.
"""
raise NotImplementedError()
def get_predict_funcs(self, input_names, output_names, n): This function should only get called after :meth:`setup()` has finished.
""" Return n predictors.
Can be overwritten by subclasses to exploit more
parallelism among predictors.
""" """
if len(self.config.predict_tower) > 1: return self._extra_fetches
logger.warn(
"[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation")
return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self): def trigger_epoch(self):
""" """
...@@ -129,28 +118,21 @@ class Trainer(object): ...@@ -129,28 +118,21 @@ class Trainer(object):
""" """
self.add_summary(create_scalar_summary(name, val)) self.add_summary(create_scalar_summary(name, val))
def get_extra_fetches(self):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
return self._extra_fetches
def setup(self): def setup(self):
""" """
Setup the trainer and be ready for the main loop. Setup the trainer and be ready for the main loop.
""" """
self._setup() if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("logger directory wasn't set!")
self._setup() # subclass will setup the graph
describe_model() describe_model()
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks ...") logger.info("Setup callbacks ...")
self.config.callbacks.setup_graph(weakref.proxy(self)) self.config.callbacks.setup_graph(weakref.proxy(self))
self._extra_fetches = self.config.callbacks.extra_fetches() self._extra_fetches = self.config.callbacks.extra_fetches()
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("logger directory wasn't set!")
self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=self.sess.graph) self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=self.sess.graph)
self.summary_op = tf.summary.merge_all() self.summary_op = tf.summary.merge_all()
# create an empty StatHolder # create an empty StatHolder
...@@ -206,3 +188,23 @@ class Trainer(object): ...@@ -206,3 +188,23 @@ class Trainer(object):
self.coord.request_stop() self.coord.request_stop()
self.summary_writer.close() self.summary_writer.close()
self.sess.close() self.sess.close()
def get_predict_func(self, input_names, output_names):
    """Build an online predictor from tensor names.

    Args:
        input_names (list): names of the input tensors.
        output_names (list): names of the output tensors.

    Returns:
        an OnlinePredictor

    Raises:
        NotImplementedError: always — subclasses must provide the
            actual predictor construction.
    """
    raise NotImplementedError()
def get_predict_funcs(self, input_names, output_names, n):
    """Return ``n`` predictors built from the same input/output names.

    This naive implementation simply calls :meth:`get_predict_func`
    ``n`` times; subclasses can override it to spread the predictors
    over multiple predict towers for more parallelism.

    Args:
        input_names (list): names of the input tensors.
        output_names (list): names of the output tensors.
        n (int): number of predictors to build.

    Returns:
        list: ``n`` predictors.
    """
    if len(self.config.predict_tower) > 1:
        logger.warn(
            "[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation")
    predictors = []
    for _ in range(n):
        predictors.append(self.get_predict_func(input_names, output_names))
    return predictors
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment