Commit e79d74f8 authored by Yuxin Wu's avatar Yuxin Wu

Add tests about scheduler

parent 940a1636
...@@ -52,16 +52,12 @@ before_script: ...@@ -52,16 +52,12 @@ before_script:
- protoc --version - protoc --version
- python -c "import cv2; print('OpenCV '+ cv2.__version__)" - python -c "import cv2; print('OpenCV '+ cv2.__version__)"
- python -c "import tensorflow as tf; print('TensorFlow '+ tf.__version__)" - python -c "import tensorflow as tf; print('TensorFlow '+ tf.__version__)"
# Check that these private names can be imported because tensorpack is using them - mkdir -p $HOME/tensorpack_data
- python -c "from tensorflow.python.client.session import _FetchHandler" - export TENSORPACK_DATASET=$HOME/tensorpack_data
- python -c "from tensorflow.python.training.monitored_session import _HookedSession"
- python -c "import tensorflow as tf; tf.Operation._add_control_input"
script: script:
- flake8 . - flake8 .
- if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then cd examples && flake8 .; fi # some examples are py3 only - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then cd examples && flake8 .; fi # some examples are py3 only
- mkdir -p $HOME/tensorpack_data
- export TENSORPACK_DATASET=$HOME/tensorpack_data
- $TRAVIS_BUILD_DIR/tests/run-tests.sh - $TRAVIS_BUILD_DIR/tests/run-tests.sh
- cd $TRAVIS_BUILD_DIR # go back to root so that deploy may work - cd $TRAVIS_BUILD_DIR # go back to root so that deploy may work
......
### Design of Tensorpack's imgaug Module #### Design of Tensorpack's imgaug Module
The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage: The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage:
...@@ -22,7 +22,7 @@ The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the ...@@ -22,7 +22,7 @@ The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the
4. Reset random seed. Random seed can be reset by 4. Reset random seed. Random seed can be reset by
[reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state). [reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state).
This is important for multi-process data loading, and This is important for multi-process data loading, and
it is called automatically if you use tensorpack's the reset method is called automatically if you use tensorpack's
[image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent). [image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent).
### Write an Image Augmentor ### Write an Image Augmentor
......
...@@ -47,5 +47,7 @@ for _, module_name, _ in iter_modules( ...@@ -47,5 +47,7 @@ for _, module_name, _ in iter_modules(
srcpath = os.path.join(_CURR_DIR, module_name + '.py') srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath): if not os.path.isfile(srcpath):
continue continue
if module_name.endswith('_test'):
continue
if not module_name.startswith('_'): if not module_name.startswith('_'):
_global_import(module_name) _global_import(module_name)
...@@ -103,7 +103,7 @@ class ObjAttrParam(HyperParam): ...@@ -103,7 +103,7 @@ class ObjAttrParam(HyperParam):
def set_value(self, v): def set_value(self, v):
setattr(self.obj, self.attrname, v) setattr(self.obj, self.attrname, v)
def get_value(self, v): def get_value(self):
return getattr(self.obj, self.attrname) return getattr(self.obj, self.attrname)
...@@ -151,8 +151,7 @@ class HyperParamSetter(Callback): ...@@ -151,8 +151,7 @@ class HyperParamSetter(Callback):
""" """
ret = self._get_value_to_set() ret = self._get_value_to_set()
if ret is not None and ret != self._last_value: if ret is not None and ret != self._last_value:
if self.epoch_num != self._last_epoch_set: if self.epoch_num != self._last_epoch_set: # Print this message at most once every epoch
# Print this message at most once every epoch
if self._last_value is None: if self._last_value is None:
logger.info("[HyperParamSetter] At global_step={}, {} is set to {:.6f}".format( logger.info("[HyperParamSetter] At global_step={}, {} is set to {:.6f}".format(
self.global_step, self.param.readable_name, ret)) self.global_step, self.param.readable_name, ret))
...@@ -261,13 +260,33 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -261,13 +260,33 @@ class ScheduledHyperParamSetter(HyperParamSetter):
self._step = step_based self._step = step_based
super(ScheduledHyperParamSetter, self).__init__(param) super(ScheduledHyperParamSetter, self).__init__(param)
def _get_value_to_set(self): def _get_value_to_set(self): # override parent
refnum = self.global_step if self._step else self.epoch_num return self._get_value_to_set_at_point(self._current_point())
def _current_point(self):
return self.global_step if self._step else self.epoch_num
def _check_value_at_beginning(self):
v = None
# we are at `before_train`, therefore the epoch/step associated with `current_point` has finished.
for p in range(0, self._current_point() + 1):
v = self._get_value_to_set_at_point(p) or v
actual_value = self.param.get_value()
if v is not None and v != actual_value:
logger.warn("According to the schedule, parameter '{}' should become {} at the current point. "
"However its current value is {}. "
"You may want to check whether your initialization of the parameter is as expected".format(
self.param.readable_name, v, actual_value))
def _get_value_to_set_at_point(self, point):
"""
Using schedule, compute the value to be set at a given point.
"""
laste, lastv = None, None laste, lastv = None, None
for e, v in self.schedule: for e, v in self.schedule:
if e == refnum: if e == point:
return v # meet the exact boundary, return directly return v # meet the exact boundary, return directly
if e > refnum: if e > point:
break break
laste, lastv = e, v laste, lastv = e, v
if laste is None or laste == e: if laste is None or laste == e:
...@@ -276,9 +295,13 @@ class ScheduledHyperParamSetter(HyperParamSetter): ...@@ -276,9 +295,13 @@ class ScheduledHyperParamSetter(HyperParamSetter):
if self.interp is None: if self.interp is None:
# If no interpolation, nothing to do. # If no interpolation, nothing to do.
return None return None
v = (refnum - laste) * 1. / (e - laste) * (v - lastv) + lastv v = (point - laste) * 1. / (e - laste) * (v - lastv) + lastv
return v return v
def _before_train(self):
super(ScheduledHyperParamSetter, self)._before_train()
self._check_value_at_beginning()
def _trigger_epoch(self): def _trigger_epoch(self):
if not self._step: if not self._step:
self.trigger() self.trigger()
......
# -*- coding: utf-8 -*-
import unittest
import tensorflow as tf
import six
from ..utils import logger
from ..train.trainers import NoOpTrainer
from .param import ScheduledHyperParamSetter, ObjAttrParam
class ParamObject(object):
"""
An object that holds the param to be set, for testing purposes.
"""
PARAM_NAME = 'param'
def __init__(self):
self.param_history = {}
self.__dict__[self.PARAM_NAME] = 1.0
def __setattr__(self, name, value):
if name == self.PARAM_NAME:
self._set_param(value)
super(ParamObject, self).__setattr__(name, value)
def _set_param(self, value):
self.param_history[self.trainer.global_step] = value
class ScheduledHyperParamSetterTest(unittest.TestCase):
def setUp(self):
self._param_obj = ParamObject()
def tearDown(self):
tf.reset_default_graph()
def _create_trainer_with_scheduler(self, scheduler,
steps_per_epoch, max_epoch, starting_epoch=1):
trainer = NoOpTrainer()
tf.get_variable(name='test_var', shape=[])
self._param_obj.trainer = trainer
trainer.train_with_defaults(
callbacks=[scheduler],
extra_callbacks=[],
monitors=[],
steps_per_epoch=steps_per_epoch,
max_epoch=max_epoch,
starting_epoch=starting_epoch
)
return self._param_obj.param_history
def testInterpolation(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(30, 0.3), (40, 0.4), (50, 0.5)], interp='linear', step_based=True)
history = self._create_trainer_with_scheduler(scheduler, 10, 50, starting_epoch=20)
self.assertEqual(min(history.keys()), 30)
self.assertEqual(history[30], 0.3)
self.assertEqual(history[40], 0.4)
self.assertEqual(history[45], 0.45)
def testSchedule(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
history = self._create_trainer_with_scheduler(scheduler, 1, 50)
self.assertEqual(min(history.keys()), 10)
self.assertEqual(len(history), 3)
def testStartAfterSchedule(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(len(history), 0)
@unittest.skipIf(six.PY2, "assertLogs not supported in Python 2")
def testWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
with self.assertLogs(logger=logger._logger, level='WARNING'):
self._create_trainer_with_scheduler(scheduler, 1, 21, starting_epoch=20)
@unittest.skipIf(six.PY2, "unittest.mock not available in Python 2")
def testNoWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 1.0), (30, 1.5)])
with unittest.mock.patch('tensorpack.utils.logger.warning') as warning:
self._create_trainer_with_scheduler(scheduler, 1, 22, starting_epoch=21)
self.assertFalse(warning.called)
if __name__ == '__main__':
unittest.main()
...@@ -101,7 +101,7 @@ class ProgressBar(Callback): ...@@ -101,7 +101,7 @@ class ProgressBar(Callback):
class MaintainStepCounter(Callback): class MaintainStepCounter(Callback):
""" """
It maintains the global step in the graph, making sure it's increased by one. It maintains the global step in the graph, making sure it's increased by one at every `hooked_sess.run`.
This callback is used internally by the trainer, you don't need to worry about it. This callback is used internally by the trainer, you don't need to worry about it.
""" """
......
...@@ -61,7 +61,7 @@ class NoOpTrainer(SimpleTrainer): ...@@ -61,7 +61,7 @@ class NoOpTrainer(SimpleTrainer):
Note that `steps_per_epoch` and `max_epochs` are still valid options. Note that `steps_per_epoch` and `max_epochs` are still valid options.
""" """
def run_step(self): def run_step(self):
pass self.hooked_sess.run([])
# Only exists for type check & back-compatibility # Only exists for type check & back-compatibility
......
...@@ -56,11 +56,14 @@ def _getlogger(): ...@@ -56,11 +56,14 @@ def _getlogger():
_logger = _getlogger() _logger = _getlogger()
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'setLevel'] _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'exception', 'debug', 'setLevel']
# export logger functions # export logger functions
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func) locals()[func] = getattr(_logger, func)
__all__.append(func) __all__.append(func)
# 'warn' is deprecated in logging module
warn = _logger.warning
__all__.append('warn')
def _get_time_str(): def _get_time_str():
......
...@@ -5,14 +5,17 @@ DIR=$(dirname $0) ...@@ -5,14 +5,17 @@ DIR=$(dirname $0)
cd $DIR cd $DIR
export TF_CPP_MIN_LOG_LEVEL=2 export TF_CPP_MIN_LOG_LEVEL=2
export TF_CPP_MIN_VLOG_LEVEL=2
# test import (#471) # test import (#471)
python -c 'from tensorpack.dataflow.imgaug import transform' python -c 'from tensorpack.dataflow.imgaug import transform'
# Check that these private names can be imported because tensorpack is using them
python -c "from tensorflow.python.client.session import _FetchHandler"
python -c "from tensorflow.python.training.monitored_session import _HookedSession"
python -c "import tensorflow as tf; tf.Operation._add_control_input"
python -m unittest discover -v # run tests
# python -m tensorpack.models._test python -m tensorpack.callbacks.param_test
# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
# python ../tensorpack/user_ops/test-recv-op.py
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py TENSORPACK_SERIALIZE=msgpack python test_serializer.py
python -m unittest discover -v
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment