Commit e79d74f8 authored by Yuxin Wu

Add tests about scheduler

parent 940a1636
......@@ -52,16 +52,12 @@ before_script:
- protoc --version
- python -c "import cv2; print('OpenCV '+ cv2.__version__)"
- python -c "import tensorflow as tf; print('TensorFlow '+ tf.__version__)"
# Check that these private names can be imported, because tensorpack uses them
- python -c "from tensorflow.python.client.session import _FetchHandler"
- python -c "from tensorflow.python.training.monitored_session import _HookedSession"
- python -c "import tensorflow as tf; tf.Operation._add_control_input"
- mkdir -p $HOME/tensorpack_data
- export TENSORPACK_DATASET=$HOME/tensorpack_data
script:
- flake8 .
- if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then cd examples && flake8 .; fi # some examples are py3 only
- mkdir -p $HOME/tensorpack_data
- export TENSORPACK_DATASET=$HOME/tensorpack_data
- $TRAVIS_BUILD_DIR/tests/run-tests.sh
- cd $TRAVIS_BUILD_DIR # go back to root so that deploy may work
......
### Design of Tensorpack's imgaug Module
#### Design of Tensorpack's imgaug Module
The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage:
......@@ -22,7 +22,7 @@ The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the
4. Reset random seed. Random seed can be reset by
[reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state).
This is important for multi-process data loading, and
it is called automatically if you use tensorpack's
the reset method is called automatically if you use tensorpack's
[image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent).
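For example, a minimal sketch of this automatic reset, using `FakeData` as a stand-in source (the specific augmentors here are only illustrative):

```python
from tensorpack.dataflow import AugmentImageComponent, FakeData, imgaug

# FakeData yields random [image, label]-like datapoints.
df = FakeData(shapes=[[64, 64, 3], []], size=100)
augmentors = [imgaug.Flip(horiz=True), imgaug.Brightness(delta=30)]
# AugmentImageComponent applies the augmentors to component 0 and calls
# reset_state() on each of them when the dataflow itself is reset.
df = AugmentImageComponent(df, augmentors, index=0)
df.reset_state()
dp = next(df.get_data())  # one augmented datapoint
```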
### Write an Image Augmentor
......
......@@ -47,5 +47,7 @@ for _, module_name, _ in iter_modules(
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.endswith('_test'):
continue
if not module_name.startswith('_'):
_global_import(module_name)
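For context, a self-contained approximation of this auto-import pattern; the `_global_import` body below is an assumption about the real helper defined earlier in the file:

```python
# This pattern must live inside a package's __init__.py (it uses a
# relative import via level=1).
import os
from pkgutil import iter_modules

_CURR_DIR = os.path.dirname(__file__)

def _global_import(name):
    # Import the sibling module and re-export its public names.
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        if not k.startswith('__'):
            globals()[k] = p.__dict__[k]

for _, module_name, _ in iter_modules([_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
    if not os.path.isfile(srcpath):
        continue
    if module_name.endswith('_test'):
        continue  # skip test modules such as param_test.py
    if not module_name.startswith('_'):
        _global_import(module_name)
```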
......@@ -103,7 +103,7 @@ class ObjAttrParam(HyperParam):
def set_value(self, v):
setattr(self.obj, self.attrname, v)
def get_value(self, v):
def get_value(self):
return getattr(self.obj, self.attrname)
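This signature fix matters because callers invoke `get_value()` with no argument. A minimal usage sketch, with a hypothetical `Config` holder object:

```python
from tensorpack.callbacks import ObjAttrParam, ScheduledHyperParamSetter

class Config(object):  # hypothetical object carrying a tunable attribute
    lr = 0.1

cfg = Config()
param = ObjAttrParam(cfg, 'lr')
assert param.get_value() == 0.1  # simply getattr(cfg, 'lr')
setter = ScheduledHyperParamSetter(param, [(0, 0.1), (10, 0.01)])
```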
......@@ -151,8 +151,7 @@ class HyperParamSetter(Callback):
"""
ret = self._get_value_to_set()
if ret is not None and ret != self._last_value:
if self.epoch_num != self._last_epoch_set:
# Print this message at most once every epoch
if self.epoch_num != self._last_epoch_set: # Print this message at most once every epoch
if self._last_value is None:
logger.info("[HyperParamSetter] At global_step={}, {} is set to {:.6f}".format(
self.global_step, self.param.readable_name, ret))
......@@ -261,13 +260,33 @@ class ScheduledHyperParamSetter(HyperParamSetter):
self._step = step_based
super(ScheduledHyperParamSetter, self).__init__(param)
def _get_value_to_set(self):
refnum = self.global_step if self._step else self.epoch_num
def _get_value_to_set(self): # override parent
return self._get_value_to_set_at_point(self._current_point())
def _current_point(self):
return self.global_step if self._step else self.epoch_num
def _check_value_at_beginning(self):
v = None
# We are at `before_train`, so the epoch/step associated with `current_point` has already finished.
for p in range(0, self._current_point() + 1):
v = self._get_value_to_set_at_point(p) or v
actual_value = self.param.get_value()
if v is not None and v != actual_value:
logger.warn("According to the schedule, parameter '{}' should become {} at the current point. "
"However its current value is {}. "
"You may want to check whether your initialization of the parameter is as expected".format(
self.param.readable_name, v, actual_value))
def _get_value_to_set_at_point(self, point):
"""
Using the schedule, compute the value to be set at a given point.
"""
laste, lastv = None, None
for e, v in self.schedule:
if e == refnum:
if e == point:
return v # hit the exact boundary, return directly
if e > refnum:
if e > point:
break
laste, lastv = e, v
if laste is None or laste == e:
......@@ -276,9 +295,13 @@ class ScheduledHyperParamSetter(HyperParamSetter):
if self.interp is None:
# If no interpolation, nothing to do.
return None
v = (refnum - laste) * 1. / (e - laste) * (v - lastv) + lastv
v = (point - laste) * 1. / (e - laste) * (v - lastv) + lastv
return v
def _before_train(self):
super(ScheduledHyperParamSetter, self)._before_train()
self._check_value_at_beginning()
def _trigger_epoch(self):
if not self._step:
self.trigger()
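To make the schedule semantics concrete, here is a standalone sketch of the lookup above (not the class's actual method), consistent with the `testInterpolation` expectations added below:

```python
def value_at_point(schedule, point, interp=None):
    """Exact value at a boundary, linear interpolation in between, else None."""
    laste, lastv = None, None
    for e, v in schedule:
        if e == point:
            return v                 # hit an exact boundary
        if e > point:
            break
        laste, lastv = e, v
    if laste is None or laste == e or interp is None:
        return None                  # before/after the schedule, or no interpolation
    return (point - laste) * 1. / (e - laste) * (v - lastv) + lastv

schedule = [(30, 0.3), (40, 0.4), (50, 0.5)]
assert value_at_point(schedule, 40) == 0.4
assert value_at_point(schedule, 45, interp='linear') == 0.45
assert value_at_point(schedule, 60, interp='linear') is None
```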
......
# -*- coding: utf-8 -*-
import unittest
import tensorflow as tf
import six
if six.PY3:
    import unittest.mock  # not loaded by a bare `import unittest`
from ..utils import logger
from ..train.trainers import NoOpTrainer
from .param import ScheduledHyperParamSetter, ObjAttrParam
class ParamObject(object):
"""
An object that holds the param to be set, for testing purposes.
"""
PARAM_NAME = 'param'
def __init__(self):
self.param_history = {}
self.__dict__[self.PARAM_NAME] = 1.0
def __setattr__(self, name, value):
if name == self.PARAM_NAME:
self._set_param(value)
super(ParamObject, self).__setattr__(name, value)
def _set_param(self, value):
self.param_history[self.trainer.global_step] = value
class ScheduledHyperParamSetterTest(unittest.TestCase):
def setUp(self):
self._param_obj = ParamObject()
def tearDown(self):
tf.reset_default_graph()
def _create_trainer_with_scheduler(self, scheduler,
steps_per_epoch, max_epoch, starting_epoch=1):
trainer = NoOpTrainer()
tf.get_variable(name='test_var', shape=[])
self._param_obj.trainer = trainer
trainer.train_with_defaults(
callbacks=[scheduler],
extra_callbacks=[],
monitors=[],
steps_per_epoch=steps_per_epoch,
max_epoch=max_epoch,
starting_epoch=starting_epoch
)
return self._param_obj.param_history
def testInterpolation(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(30, 0.3), (40, 0.4), (50, 0.5)], interp='linear', step_based=True)
history = self._create_trainer_with_scheduler(scheduler, 10, 50, starting_epoch=20)
self.assertEqual(min(history.keys()), 30)
self.assertEqual(history[30], 0.3)
self.assertEqual(history[40], 0.4)
self.assertEqual(history[45], 0.45)
def testSchedule(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
history = self._create_trainer_with_scheduler(scheduler, 1, 50)
self.assertEqual(min(history.keys()), 10)
self.assertEqual(len(history), 3)
def testStartAfterSchedule(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
self.assertEqual(len(history), 0)
@unittest.skipIf(six.PY2, "assertLogs not supported in Python 2")
def testWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 0.4), (30, 0.5)])
with self.assertLogs(logger=logger._logger, level='WARNING'):
self._create_trainer_with_scheduler(scheduler, 1, 21, starting_epoch=20)
@unittest.skipIf(six.PY2, "unittest.mock not available in Python 2")
def testNoWarningStartInTheMiddle(self):
scheduler = ScheduledHyperParamSetter(
ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
[(10, 0.3), (20, 1.0), (30, 1.5)])
with unittest.mock.patch('tensorpack.utils.logger.warning') as warning:
self._create_trainer_with_scheduler(scheduler, 1, 22, starting_epoch=21)
self.assertFalse(warning.called)
if __name__ == '__main__':
unittest.main()
......@@ -101,7 +101,7 @@ class ProgressBar(Callback):
class MaintainStepCounter(Callback):
"""
It maintains the global step in the graph, making sure it's increased by one.
It maintains the global step in the graph, making sure it's increased by one at every `hooked_sess.run`.
This callback is used internally by the trainer; you don't need to worry about it.
"""
......
......@@ -61,7 +61,7 @@ class NoOpTrainer(SimpleTrainer):
Note that `steps_per_epoch` and `max_epoch` are still valid options.
"""
def run_step(self):
pass
self.hooked_sess.run([])
# Only exists for type check & back-compatibility
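Routing the no-op step through `hooked_sess.run([])` instead of `pass` means the session hooks still fire, so the step counter and step-based callbacks advance. A minimal sketch mirroring the new tests; the dummy variable only keeps the graph non-empty:

```python
import tensorflow as tf
from tensorpack.train.trainers import NoOpTrainer

trainer = NoOpTrainer()
tf.get_variable(name='dummy_var', shape=[])  # graph must not be empty
trainer.train_with_defaults(
    callbacks=[], extra_callbacks=[], monitors=[],
    steps_per_epoch=5, max_epoch=2)          # runs only the callbacks
```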
......
......@@ -56,11 +56,14 @@ def _getlogger():
_logger = _getlogger()
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'setLevel']
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'exception', 'debug', 'setLevel']
# export logger functions
for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func)
__all__.append(func)
# 'warn' is deprecated in the logging module
warn = _logger.warning
__all__.append('warn')
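Both spellings keep working after this change; `warn` is now just an alias kept for backward compatibility:

```python
from tensorpack.utils import logger

logger.warning("preferred spelling")
logger.warn("same handler, kept for old callers")  # alias of logger.warning
```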
def _get_time_str():
......
......@@ -5,14 +5,17 @@ DIR=$(dirname $0)
cd $DIR
export TF_CPP_MIN_LOG_LEVEL=2
export TF_CPP_MIN_VLOG_LEVEL=2
# test import (#471)
python -c 'from tensorpack.dataflow.imgaug import transform'
# Check that these private names can be imported, because tensorpack uses them
python -c "from tensorflow.python.client.session import _FetchHandler"
python -c "from tensorflow.python.training.monitored_session import _HookedSession"
python -c "import tensorflow as tf; tf.Operation._add_control_input"
python -m unittest discover -v
# python -m tensorpack.models._test
# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
# python ../tensorpack/user_ops/test-recv-op.py
# run tests
python -m tensorpack.callbacks.param_test
TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
TENSORPACK_SERIALIZE=msgpack python test_serializer.py
python -m unittest discover -v