Commit 582cd482 authored by Yuxin Wu

split tfutils.common because it's too large. improve logging

parent 596d8008
@@ -6,6 +6,7 @@ import tensorflow as tf
 from contextlib import contextmanager
 from collections import defaultdict
 import time
+import traceback

 from .base import Callback
 from .stats import StatPrinter
@@ -81,7 +82,11 @@ class Callbacks(Callback):

     def _after_train(self):
         for cb in self.cbs:
-            cb.after_train()
+            # make sure callbacks are properly finalized
+            try:
+                cb.after_train()
+            except Exception:
+                traceback.print_exc()

     def _extra_fetches(self):
         if self._extra_fetches_cache is not None:
...
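To illustrate the effect of the new error handling, here is a minimal standalone sketch (the `NoisyCallback`/`QuietCallback` classes are made up for illustration and are not part of tensorpack): a failing `after_train()` is reported but no longer prevents later callbacks from finalizing.

```python
import traceback

class NoisyCallback:
    """Hypothetical callback whose cleanup raises."""
    def after_train(self):
        raise RuntimeError("cleanup failed")

class QuietCallback:
    """Hypothetical callback that must still be finalized."""
    def after_train(self):
        print("QuietCallback finalized")

# Same pattern as the hunk above: each callback is finalized independently,
# so a failure in one no longer skips the remaining ones.
for cb in [NoisyCallback(), QuietCallback()]:
    try:
        cb.after_train()
    except Exception:
        traceback.print_exc()
```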
@@ -10,7 +10,8 @@ from six.moves import zip, range
 from ..dataflow import DataFlow
 from ..utils import logger, get_tqdm, SUMMARY_BACKUP_KEYS
-from ..tfutils.common import get_op_tensor_name, freeze_collection
+from ..tfutils.common import get_op_tensor_name
+from ..tfutils.collection import freeze_collection
 from ..tfutils import TowerContext
 from ..train.input_data import FeedfreeInput
 from ..predict import build_prediction_graph
...
@@ -77,7 +77,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
                 self.path,
                 global_step=tf.train.get_global_step(),
                 write_meta_graph=False)
-            logger.info("Model saved to %s" % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
+            logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
         except (OSError, IOError):   # disk error sometimes.. just ignore it
             logger.exception("Exception in ModelSaver.trigger_epoch!")
...
@@ -15,7 +15,7 @@ from ..utils.naming import (
     MOVING_SUMMARY_VARS_KEY,
     GLOBAL_STEP_INCR_VAR_NAME,
     LOCAL_STEP_OP_NAME)
-from ..tfutils.common import get_op_tensor_name, get_global_step_var
+from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
 from .base import Callback

 __all__ = ['StepStatPrinter', 'MaintainStepCounter',
@@ -59,6 +59,11 @@ class MaintainStepCounter(Callback):
             self.gs_incr_var, self.trainer.config.step_per_epoch,
             name=LOCAL_STEP_OP_NAME)

+    def _before_train(self):
+        gs_val = get_global_step_value()
+        if gs_val != 0:
+            logger.info("Start training with global_step={}".format(gs_val))
+
     def _extra_fetches(self):
         return [self.gs_incr_var.op]
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: collection.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager

__all__ = ['backup_collection',
           'restore_collection',
           'clear_collection',
           'freeze_collection']


def backup_collection(keys):
    """
    Args:
        keys (list): list of collection keys to backup
    Returns:
        dict: the backup
    """
    ret = {}
    for k in keys:
        ret[k] = copy(tf.get_collection(k))
    return ret


def restore_collection(backup):
    """
    Restore from a collection backup.
    Args:
        backup (dict):
    """
    for k, v in six.iteritems(backup):
        del tf.get_collection_ref(k)[:]
        tf.get_collection_ref(k).extend(v)


def clear_collection(keys):
    """
    Clear some collections.
    Args:
        keys (list): list of collection keys.
    """
    for k in keys:
        del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
    """
    Args:
        keys (list): list of collection keys to freeze.
    Returns:
        a context in which the collections are restored to their original state on exit.
    """
    backup = backup_collection(keys)
    yield
    restore_collection(backup)
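As a rough usage sketch of the new module (assuming a TensorFlow 1.x graph-mode API and a made-up collection key `"my_key"`), `freeze_collection` discards anything added to the listed collections inside the context, while `backup_collection`/`restore_collection` expose the same mechanism explicitly:

```python
import tensorflow as tf
from tensorpack.tfutils.collection import (
    backup_collection, restore_collection, freeze_collection)

tf.add_to_collection("my_key", tf.constant(1, name="before"))

# Changes made inside the frozen context are rolled back on exit.
with freeze_collection(["my_key"]):
    tf.add_to_collection("my_key", tf.constant(2, name="inside"))
print([t.name for t in tf.get_collection("my_key")])  # only 'before:0' remains

# The same effect, written out with an explicit backup/restore pair.
backup = backup_collection(["my_key"])
tf.add_to_collection("my_key", tf.constant(3, name="temporary"))
restore_collection(backup)
```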
@@ -4,9 +4,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-from copy import copy
-import six
-from contextlib import contextmanager

 from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME, GLOBAL_STEP_INCR_OP_NAME
 from ..utils.argtools import memoized
@@ -17,10 +14,6 @@ __all__ = ['get_default_sess_config',
            'get_op_tensor_name',
            'get_tensors_by_names',
            'get_op_or_tensor_by_name',
-           'backup_collection',
-           'restore_collection',
-           'clear_collection',
-           'freeze_collection',
            'get_tf_version',
            'get_name_scope_name'
            ]
@@ -117,56 +110,6 @@ def get_op_or_tensor_by_name(name):
         return G.get_operation_by_name(name)

-def backup_collection(keys):
-    """
-    Args:
-        keys (list): list of collection keys to backup
-    Returns:
-        dict: the backup
-    """
-    ret = {}
-    for k in keys:
-        ret[k] = copy(tf.get_collection(k))
-    return ret
-
-def restore_collection(backup):
-    """
-    Restore from a collection backup.
-    Args:
-        backup (dict):
-    """
-    for k, v in six.iteritems(backup):
-        del tf.get_collection_ref(k)[:]
-        tf.get_collection_ref(k).extend(v)
-
-def clear_collection(keys):
-    """
-    Clear some collections.
-    Args:
-        keys(list): list of collection keys.
-    """
-    for k in keys:
-        del tf.get_collection_ref(k)[:]
-
-@contextmanager
-def freeze_collection(keys):
-    """
-    Args:
-        keys(list): list of collection keys to freeze.
-    Returns:
-        a context where the collections are in the end restored to its initial state.
-    """
-    backup = backup_collection(keys)
-    yield
-    restore_collection(backup)
-
 def get_tf_version():
     """
     Returns:
...
@@ -23,7 +23,7 @@ def describe_model():
             v.name, shape.as_list(), ele))
     size_mb = total * 4 / 1024.0**2
     msg.append(colored(
-        "Total param={} ({:01f} MB assuming all float32)".format(total, size_mb), 'cyan'))
+        "Total #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan'))
     logger.info(colored("Model Parameters: ", 'cyan') + '\n'.join(msg))
...
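The format-spec fix above is worth a note: `{:01f}` is parsed as minimum field width 1 with the default six decimal places, whereas `{:.02f}` requests two decimal places. A quick check with a made-up value:

```python
size_mb = 23.4567891
print("{:01f}".format(size_mb))   # '23.456789' -- width 1, default precision
print("{:.02f}".format(size_mb))  # '23.46'     -- two decimal places
```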
@@ -4,6 +4,7 @@
 from abc import ABCMeta, abstractmethod
 import re
+import time
 import weakref
 import six
 from six.moves import range
@@ -11,7 +12,6 @@ from six.moves import range
 import tensorflow as tf

 from .config import TrainConfig
 from ..utils import logger
-from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step_value
 from ..tfutils.modelutils import describe_model
@@ -177,28 +177,28 @@ class Trainer(object):
         with self.sess.as_default():
             try:
                 callbacks.before_train()
-                logger.info("Start training with global_step={}".format(get_global_step_value()))
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch + 1):
-                    with timed_operation(
-                            'Epoch {} (global_step {})'.format(
-                                self.epoch_num, get_global_step_value() + self.config.step_per_epoch),
-                            log_start=True):
-                        for self.step_num in range(self.config.step_per_epoch):
-                            if self.coord.should_stop():
-                                return
-                            fetch_data = self.run_step()  # implemented by subclass
-                            if fetch_data is None:
-                                # the old Trainer
-                                callbacks.trigger_step()
-                            else:
-                                callbacks.trigger_step(*fetch_data)
+                    logger.info("Start Epoch {} ...".format(self.epoch_num))
+                    start_time = time.time()
+                    for self.step_num in range(self.config.step_per_epoch):
+                        if self.coord.should_stop():
+                            return
+                        fetch_data = self.run_step()  # implemented by subclass
+                        if fetch_data is None:
+                            # old trainer doesn't return fetch data
+                            callbacks.trigger_step()
+                        else:
+                            callbacks.trigger_step(*fetch_data)
+                    logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
+                        self.epoch_num, get_global_step_value(), time.time() - start_time))
                     # trigger epoch outside the timing region.
                     self.trigger_epoch()
             except StopTraining:
                 logger.info("Training was stopped.")
             except KeyboardInterrupt:
-                logger.info("Detected Ctrl+C and shutdown training.")
+                logger.info("Detected Ctrl-C and exiting main loop.")
             except:
                 raise
             finally:
...
@@ -11,7 +11,8 @@ from six.moves import zip, range
 from ..utils import logger
 from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
-from ..tfutils import (backup_collection, restore_collection, TowerContext)
+from ..tfutils.tower import TowerContext
+from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
 from .base import Trainer
...
@@ -7,9 +7,8 @@ import tensorflow as tf
 from .base import Trainer
 from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
-from ..tfutils import (get_tensors_by_names,
-                       freeze_collection,
-                       TowerContext)
+from ..tfutils import get_tensors_by_names, TowerContext
+from ..tfutils.collection import freeze_collection
 from ..predict import OnlinePredictor, build_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
 from .input_data import FeedInput
...