Commit 582cd482 authored by Yuxin Wu

Split tfutils.common because it's too large. Improve logging.

parent 596d8008
@@ -6,6 +6,7 @@ import tensorflow as tf
from contextlib import contextmanager
from collections import defaultdict
import time
import traceback
from .base import Callback
from .stats import StatPrinter
@@ -81,7 +82,11 @@ class Callbacks(Callback):
def _after_train(self):
for cb in self.cbs:
# make sure callbacks are properly finalized
try:
cb.after_train()
except Exception:
traceback.print_exc()
def _extra_fetches(self):
if self._extra_fetches_cache is not None:
......
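
For context, the hunk above makes Callbacks._after_train tolerate a failing callback, so one misbehaving callback can no longer prevent the rest from being finalized. A minimal standalone sketch of that pattern (the callback classes here are hypothetical, not tensorpack's):

    import traceback

    class NoisyCallback(object):
        def after_train(self):
            raise RuntimeError("cleanup failed")

    class QuietCallback(object):
        def __init__(self):
            self.finalized = False

        def after_train(self):
            self.finalized = True

    cbs = [NoisyCallback(), QuietCallback()]
    for cb in cbs:
        # swallow the exception so every callback still gets finalized,
        # but keep the stack trace visible in the output
        try:
            cb.after_train()
        except Exception:
            traceback.print_exc()

    assert cbs[1].finalized  # the second callback ran despite the first one failing
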
@@ -10,7 +10,8 @@ from six.moves import zip, range
from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, SUMMARY_BACKUP_KEYS
from ..tfutils.common import get_op_tensor_name, freeze_collection
from ..tfutils.common import get_op_tensor_name
from ..tfutils.collection import freeze_collection
from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph
......
@@ -77,7 +77,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
self.path,
global_step=tf.train.get_global_step(),
write_meta_graph=False)
logger.info("Model saved to %s" % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
except (OSError, IOError):  # disk errors happen sometimes; just ignore them
logger.exception("Exception in ModelSaver.trigger_epoch!")
......
@@ -15,7 +15,7 @@ from ..utils.naming import (
MOVING_SUMMARY_VARS_KEY,
GLOBAL_STEP_INCR_VAR_NAME,
LOCAL_STEP_OP_NAME)
from ..tfutils.common import get_op_tensor_name, get_global_step_var
from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
from .base import Callback
__all__ = ['StepStatPrinter', 'MaintainStepCounter',
@@ -59,6 +59,11 @@ class MaintainStepCounter(Callback):
self.gs_incr_var, self.trainer.config.step_per_epoch,
name=LOCAL_STEP_OP_NAME)
def _before_train(self):
gs_val = get_global_step_value()
if gs_val != 0:
logger.info("Start training with global_step={}".format(gs_val))
def _extra_fetches(self):
return [self.gs_incr_var.op]
......
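
The new _before_train above logs the global step only when it is nonzero, i.e. when training resumes from a checkpoint rather than starting fresh. A self-contained sketch of that check (the getter argument stands in for tensorpack's get_global_step_value):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def log_if_resuming(get_global_step_value):
        gs_val = get_global_step_value()
        if gs_val != 0:
            # a fresh run starts at step 0; only a resumed run is worth announcing
            logger.info("Start training with global_step=%d", gs_val)

    log_if_resuming(lambda: 0)      # fresh start: prints nothing
    log_if_resuming(lambda: 5000)   # resumed from a checkpoint: logs the step
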
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: collection.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager
__all__ = ['backup_collection',
'restore_collection',
'clear_collection',
'freeze_collection']
def backup_collection(keys):
"""
Args:
keys (list): list of collection keys to backup
Returns:
dict: the backup
"""
ret = {}
for k in keys:
ret[k] = copy(tf.get_collection(k))
return ret
def restore_collection(backup):
"""
Restore from a collection backup.
Args:
backup (dict):
"""
for k, v in six.iteritems(backup):
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(v)
def clear_collection(keys):
"""
Clear some collections.
Args:
keys(list): list of collection keys.
"""
for k in keys:
del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
"""
Args:
keys(list): list of collection keys to freeze.
Returns:
a context within which the collections may be modified, and are restored to their initial state on exit.
"""
backup = backup_collection(keys)
yield
restore_collection(backup)
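
A quick usage sketch of the new module, using the functions defined above and assuming a TF1-style graph as used throughout this codebase (the collection key 'my_key' is made up for the example):

    import tensorflow as tf

    tf.add_to_collection('my_key', 'a')

    with freeze_collection(['my_key']):
        # additions inside the context are visible...
        tf.add_to_collection('my_key', 'b')
        assert tf.get_collection('my_key') == ['a', 'b']

    # ...but discarded once the context exits and the backup is restored
    assert tf.get_collection('my_key') == ['a']
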
@@ -4,9 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager
from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME, GLOBAL_STEP_INCR_OP_NAME
from ..utils.argtools import memoized
@@ -17,10 +14,6 @@ __all__ = ['get_default_sess_config',
'get_op_tensor_name',
'get_tensors_by_names',
'get_op_or_tensor_by_name',
'backup_collection',
'restore_collection',
'clear_collection',
'freeze_collection',
'get_tf_version',
'get_name_scope_name'
]
@@ -117,56 +110,6 @@ def get_op_or_tensor_by_name(name):
return G.get_operation_by_name(name)
def backup_collection(keys):
"""
Args:
keys (list): list of collection keys to backup
Returns:
dict: the backup
"""
ret = {}
for k in keys:
ret[k] = copy(tf.get_collection(k))
return ret
def restore_collection(backup):
"""
Restore from a collection backup.
Args:
backup (dict):
"""
for k, v in six.iteritems(backup):
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(v)
def clear_collection(keys):
"""
Clear some collections.
Args:
keys(list): list of collection keys.
"""
for k in keys:
del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
"""
Args:
keys(list): list of collection keys to freeze.
Returns:
a context within which the collections may be modified, and are restored to their initial state on exit.
"""
backup = backup_collection(keys)
yield
restore_collection(backup)
def get_tf_version():
"""
Returns:
......
@@ -23,7 +23,7 @@ def describe_model():
v.name, shape.as_list(), ele))
size_mb = total * 4 / 1024.0**2
msg.append(colored(
"Total param={} ({:01f} MB assuming all float32)".format(total, size_mb), 'cyan'))
"Total #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan'))
logger.info(colored("Model Parameters: ", 'cyan') + '\n'.join(msg))
......
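
The format-spec fix above is more than cosmetic: in Python's format mini-language, 01f means zero-padded minimum width 1 with the default six decimal places, while .02f is an actual precision of two. A quick check:

    size_mb = 3.14159
    print("{:01f}".format(size_mb))   # 3.141590  (width 1, default precision)
    print("{:.02f}".format(size_mb))  # 3.14      (two decimal places, as intended)
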
@@ -4,6 +4,7 @@
from abc import ABCMeta, abstractmethod
import re
import time
import weakref
import six
from six.moves import range
@@ -11,7 +12,6 @@ from six.moves import range
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger
from ..utils.timer import timed_operation
from ..callbacks import StatHolder
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
@@ -177,28 +177,28 @@ class Trainer(object):
with self.sess.as_default():
try:
callbacks.before_train()
logger.info("Start training with global_step={}".format(get_global_step_value()))
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch + 1):
with timed_operation(
'Epoch {} (global_step {})'.format(
self.epoch_num, get_global_step_value() + self.config.step_per_epoch),
log_start=True):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
for self.step_num in range(self.config.step_per_epoch):
if self.coord.should_stop():
return
fetch_data = self.run_step() # implemented by subclass
if fetch_data is None:
# the old Trainer
# old trainer doesn't return fetch data
callbacks.trigger_step()
else:
callbacks.trigger_step(*fetch_data)
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, get_global_step_value(), time.time() - start_time))
# trigger epoch outside the timing region.
self.trigger_epoch()
except StopTraining:
logger.info("Training was stopped.")
except KeyboardInterrupt:
logger.info("Detected Ctrl+C and shutdown training.")
logger.info("Detected Ctrl-C and exiting main loop.")
except:
raise
finally:
......
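
The main-loop hunk replaces the timed_operation context manager with explicit start/finish log lines, so "Start Epoch N" is printed before the first step rather than only when the timing wrapper closes, and trigger_epoch stays outside the measured interval. A standalone sketch of that pattern (function names here are illustrative, not tensorpack's API):

    import time
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    def run_epoch(epoch_num, step_per_epoch, run_step):
        logger.info("Start Epoch %d ...", epoch_num)
        start_time = time.time()
        for _ in range(step_per_epoch):
            run_step()
        # only the steps are timed; epoch-end callbacks would run after this
        logger.info("Epoch %d finished, time: %.2f sec.",
                    epoch_num, time.time() - start_time)

    run_epoch(1, 3, lambda: time.sleep(0.01))
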
@@ -11,7 +11,8 @@ from six.moves import zip, range
from ..utils import logger
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread
from ..tfutils import (backup_collection, restore_collection, TowerContext)
from ..tfutils.tower import TowerContext
from ..tfutils.collection import backup_collection, restore_collection
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .base import Trainer
......
@@ -7,9 +7,8 @@ import tensorflow as tf
from .base import Trainer
from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names,
freeze_collection,
TowerContext)
from ..tfutils import get_tensors_by_names, TowerContext
from ..tfutils.collection import freeze_collection
from ..predict import OnlinePredictor, build_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput
......