Shashank Suhas / seminar-breakout / Commits / e79d74f8

Commit e79d74f8, authored Feb 16, 2019 by Yuxin Wu
    Add tests about scheduler

Parent: 940a1636
Showing 9 changed files with 148 additions and 25 deletions (+148, -25)
Files changed:

  .travis.yml                         +2   -6
  docs/tutorial/extend/augmentor.md   +2   -2
  tensorpack/callbacks/__init__.py    +2   -0
  tensorpack/callbacks/param.py       +31  -8
  tensorpack/callbacks/param_test.py  +96  -0   (new file)
  tensorpack/callbacks/steps.py       +1   -1
  tensorpack/train/trainers.py        +1   -1
  tensorpack/utils/logger.py          +4   -1
  tests/run-tests.sh                  +9   -6
.travis.yml  (view file @ e79d74f8)

@@ -52,16 +52,12 @@ before_script:
   - protoc --version
   - python -c "import cv2; print('OpenCV '+ cv2.__version__)"
   - python -c "import tensorflow as tf; print('TensorFlow '+ tf.__version__)"
-  # Check that these private names can be imported because tensorpack is using them
-  - python -c "from tensorflow.python.client.session import _FetchHandler"
-  - python -c "from tensorflow.python.training.monitored_session import _HookedSession"
-  - python -c "import tensorflow as tf; tf.Operation._add_control_input"
+  - mkdir -p $HOME/tensorpack_data
+  - export TENSORPACK_DATASET=$HOME/tensorpack_data

 script:
   - flake8 .
   - if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then cd examples && flake8 .; fi  # some examples are py3 only
-  - mkdir -p $HOME/tensorpack_data
-  - export TENSORPACK_DATASET=$HOME/tensorpack_data
   - $TRAVIS_BUILD_DIR/tests/run-tests.sh
   - cd $TRAVIS_BUILD_DIR  # go back to root so that deploy may work
docs/tutorial/extend/augmentor.md  (view file @ e79d74f8)

 ### Design of Tensorpack's imgaug Module

 The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the following usage:

@@ -22,7 +22,7 @@ The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the
 4. Reset random seed. Random seed can be reset by
    [reset_state](../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state).
    This is important for multi-process data loading, and
-   it is called automatically if you use tensorpack's
+   the reset method is called automatically if you use tensorpack's
    [image augmentation dataflow](../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent).

 ### Write an Image Augmentor
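The reworded sentence concerns `reset_state`: after forking data-loading worker processes, each worker should reseed its augmentors, and `AugmentImageComponent` does this automatically. A minimal sketch of the manual case, assuming a couple of standard tensorpack augmentors (the specific augmentor choices here are only illustrative):

    from tensorpack.dataflow import imgaug

    # Compose augmentors as described in the tutorial; Flip/Resize are just examples.
    augs = imgaug.AugmentorList([
        imgaug.Flip(horiz=True),
        imgaug.Resize((224, 224)),
    ])

    # In a freshly forked worker process, reset the random seed before use,
    # otherwise every worker would produce the same augmentation sequence.
    augs.reset_state()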
tensorpack/callbacks/__init__.py  (view file @ e79d74f8)

@@ -47,5 +47,7 @@ for _, module_name, _ in iter_modules(
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
     if not os.path.isfile(srcpath):
         continue
+    if module_name.endswith('_test'):
+        continue
     if not module_name.startswith('_'):
         _global_import(module_name)
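The two added lines keep the new `param_test` module (added below) out of the package's automatic import loop, so importing `tensorpack.callbacks` does not pull in test-only code. A standalone sketch of the same filtering pattern, using a hypothetical package directory and a print in place of tensorpack's `_global_import`:

    import os
    from pkgutil import iter_modules

    # Hypothetical package directory; tensorpack uses the callbacks package's own directory.
    _CURR_DIR = os.path.dirname(os.path.abspath(__file__))

    for _, module_name, _ in iter_modules([_CURR_DIR]):
        srcpath = os.path.join(_CURR_DIR, module_name + '.py')
        if not os.path.isfile(srcpath):
            continue
        if module_name.endswith('_test'):
            # skip unit-test modules such as param_test
            continue
        if not module_name.startswith('_'):
            print('would import', module_name)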
tensorpack/callbacks/param.py  (view file @ e79d74f8)

@@ -103,7 +103,7 @@ class ObjAttrParam(HyperParam):
     def set_value(self, v):
         setattr(self.obj, self.attrname, v)

-    def get_value(self, v):
+    def get_value(self):
         return getattr(self.obj, self.attrname)

@@ -151,8 +151,7 @@ class HyperParamSetter(Callback):
         """
         ret = self._get_value_to_set()
         if ret is not None and ret != self._last_value:
             if self.epoch_num != self._last_epoch_set:
                 # Print this message at most once every epoch
                 if self._last_value is None:
                     logger.info("[HyperParamSetter] At global_step={}, {} is set to {:.6f}".format(
                         self.global_step, self.param.readable_name, ret))

@@ -261,13 +260,33 @@ class ScheduledHyperParamSetter(HyperParamSetter):
         self._step = step_based
         super(ScheduledHyperParamSetter, self).__init__(param)

-    def _get_value_to_set(self):
-        refnum = self.global_step if self._step else self.epoch_num
+    def _get_value_to_set(self):  # override parent
+        return self._get_value_to_set_at_point(self._current_point())
+
+    def _current_point(self):
+        return self.global_step if self._step else self.epoch_num
+
+    def _check_value_at_beginning(self):
+        v = None
+        # we are at `before_train`, therefore the epoch/step associated with `current_point` has finished.
+        for p in range(0, self._current_point() + 1):
+            v = self._get_value_to_set_at_point(p) or v
+        actual_value = self.param.get_value()
+        if v is not None and v != actual_value:
+            logger.warn("According to the schedule, parameter '{}' should become {} at the current point. "
+                        "However its current value is {}. "
+                        "You may want to check whether your initialization of the parameter is as expected".format(
+                            self.param.readable_name, v, actual_value))
+
+    def _get_value_to_set_at_point(self, point):
+        """
+        Using schedule, compute the value to be set at a given point.
+        """
         laste, lastv = None, None
         for e, v in self.schedule:
-            if e == refnum:
+            if e == point:
                 return v  # meet the exact boundary, return directly
-            if e > refnum:
+            if e > point:
                 break
             laste, lastv = e, v
         if laste is None or laste == e:

@@ -276,9 +295,13 @@ class ScheduledHyperParamSetter(HyperParamSetter):
         if self.interp is None:
             # If no interpolation, nothing to do.
             return None
-        v = (refnum - laste) * 1. / (e - laste) * (v - lastv) + lastv
+        v = (point - laste) * 1. / (e - laste) * (v - lastv) + lastv
         return v

+    def _before_train(self):
+        super(ScheduledHyperParamSetter, self)._before_train()
+        self._check_value_at_beginning()
+
     def _trigger_epoch(self):
         if not self._step:
             self.trigger()
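The refactoring splits the schedule lookup out of `_get_value_to_set` into `_get_value_to_set_at_point(point)`, so the same logic can be evaluated for an arbitrary epoch or step; that is what `_check_value_at_beginning` and the new tests rely on. A standalone sketch of that lookup (not tensorpack's API, just the same algorithm in isolation):

    def value_at_point(schedule, point, interp=None):
        """Standalone sketch of ScheduledHyperParamSetter's schedule lookup.

        schedule: sorted list of (point, value) pairs, e.g. [(30, 0.3), (40, 0.4), (50, 0.5)]
        point:    current epoch or global step
        interp:   None for a stepwise schedule, 'linear' for linear interpolation
        """
        laste, lastv = None, None
        for e, v in schedule:
            if e == point:
                return v                # exact boundary: return the scheduled value
            if e > point:
                break
            laste, lastv = e, v
        if laste is None or laste == e:
            return None                 # before the first point, or past the last one
        if interp is None:
            return None                 # no interpolation: only change at boundaries
        # linear interpolation between the two surrounding schedule points
        return (point - laste) * 1. / (e - laste) * (v - lastv) + lastv

    # With the schedule used in testInterpolation below:
    # value_at_point([(30, 0.3), (40, 0.4), (50, 0.5)], 45, interp='linear') == 0.45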
tensorpack/callbacks/param_test.py  (new file, mode 100644, view file @ e79d74f8)

# -*- coding: utf-8 -*-

import unittest
import tensorflow as tf
import six

from ..utils import logger
from ..train.trainers import NoOpTrainer
from .param import ScheduledHyperParamSetter, ObjAttrParam


class ParamObject(object):
    """
    An object that holds the param to be set, for testing purposes.
    """
    PARAM_NAME = 'param'

    def __init__(self):
        self.param_history = {}
        self.__dict__[self.PARAM_NAME] = 1.0

    def __setattr__(self, name, value):
        if name == self.PARAM_NAME:
            self._set_param(value)
        super(ParamObject, self).__setattr__(name, value)

    def _set_param(self, value):
        self.param_history[self.trainer.global_step] = value


class ScheduledHyperParamSetterTest(unittest.TestCase):
    def setUp(self):
        self._param_obj = ParamObject()

    def tearDown(self):
        tf.reset_default_graph()

    def _create_trainer_with_scheduler(self, scheduler,
                                       steps_per_epoch, max_epoch, starting_epoch=1):
        trainer = NoOpTrainer()
        tf.get_variable(name='test_var', shape=[])
        self._param_obj.trainer = trainer
        trainer.train_with_defaults(
            callbacks=[scheduler],
            extra_callbacks=[],
            monitors=[],
            steps_per_epoch=steps_per_epoch,
            max_epoch=max_epoch,
            starting_epoch=starting_epoch)
        return self._param_obj.param_history

    def testInterpolation(self):
        scheduler = ScheduledHyperParamSetter(
            ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
            [(30, 0.3), (40, 0.4), (50, 0.5)], interp='linear', step_based=True)
        history = self._create_trainer_with_scheduler(scheduler, 10, 50, starting_epoch=20)
        self.assertEqual(min(history.keys()), 30)
        self.assertEqual(history[30], 0.3)
        self.assertEqual(history[40], 0.4)
        self.assertEqual(history[45], 0.45)

    def testSchedule(self):
        scheduler = ScheduledHyperParamSetter(
            ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
            [(10, 0.3), (20, 0.4), (30, 0.5)])
        history = self._create_trainer_with_scheduler(scheduler, 1, 50)
        self.assertEqual(min(history.keys()), 10)
        self.assertEqual(len(history), 3)

    def testStartAfterSchedule(self):
        scheduler = ScheduledHyperParamSetter(
            ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
            [(10, 0.3), (20, 0.4), (30, 0.5)])
        history = self._create_trainer_with_scheduler(scheduler, 1, 92, starting_epoch=90)
        self.assertEqual(len(history), 0)

    @unittest.skipIf(six.PY2, "assertLogs not supported in Python 2")
    def testWarningStartInTheMiddle(self):
        scheduler = ScheduledHyperParamSetter(
            ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
            [(10, 0.3), (20, 0.4), (30, 0.5)])
        with self.assertLogs(logger=logger._logger, level='WARNING'):
            self._create_trainer_with_scheduler(scheduler, 1, 21, starting_epoch=20)

    @unittest.skipIf(six.PY2, "unittest.mock not available in Python 2")
    def testNoWarningStartInTheMiddle(self):
        scheduler = ScheduledHyperParamSetter(
            ObjAttrParam(self._param_obj, ParamObject.PARAM_NAME),
            [(10, 0.3), (20, 1.0), (30, 1.5)])
        with unittest.mock.patch('tensorpack.utils.logger.warning') as warning:
            self._create_trainer_with_scheduler(scheduler, 1, 22, starting_epoch=21)
            self.assertFalse(warning.called)


if __name__ == '__main__':
    unittest.main()
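`ParamObject` works by intercepting attribute assignment: whenever the scheduler writes `obj.param = value` through `ObjAttrParam`, `__setattr__` records the value under the trainer's current `global_step`. A minimal sketch of that mechanism outside a trainer, where `FakeTrainer` is a hypothetical stand-in providing only the `global_step` attribute that `_set_param` reads:

    # Uses the ParamObject class defined above.
    class FakeTrainer(object):
        global_step = 7

    obj = ParamObject()
    obj.trainer = FakeTrainer()
    obj.param = 0.3                       # routed through __setattr__ -> _set_param
    assert obj.param_history == {7: 0.3}  # value recorded at the (fake) current step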
tensorpack/callbacks/steps.py  (view file @ e79d74f8)

@@ -101,7 +101,7 @@ class ProgressBar(Callback):
 class MaintainStepCounter(Callback):
     """
-    It maintains the global step in the graph, making sure it's increased by one.
+    It maintains the global step in the graph, making sure it's increased by one at every `hooked_sess.run`.
     This callback is used internally by the trainer, you don't need to worry about it.
     """
tensorpack/train/trainers.py  (view file @ e79d74f8)

@@ -61,7 +61,7 @@ class NoOpTrainer(SimpleTrainer):
     Note that `steps_per_epoch` and `max_epochs` are still valid options.
     """
     def run_step(self):
-        pass
+        self.hooked_sess.run([])


 # Only exists for type check & back-compatibility
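Previously `run_step` was a plain `pass`, so a `NoOpTrainer` never touched the hooked session and the step counter (see `MaintainStepCounter` above) never advanced. Running the session with an empty fetch list keeps hooks and callbacks firing without computing anything, which is what the scheduler tests above depend on. A small sketch of that behaviour using plain TF1 APIs, with a hypothetical counting hook:

    import tensorflow as tf

    class CountingHook(tf.train.SessionRunHook):
        """Hypothetical hook that only counts how many times run() was called."""
        def __init__(self):
            self.count = 0

        def after_run(self, run_context, run_values):
            self.count += 1

    hook = CountingHook()
    with tf.train.MonitoredSession(hooks=[hook]) as sess:
        for _ in range(3):
            sess.run([])      # empty fetch list, but hooks still fire
    print(hook.count)         # 3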
tensorpack/utils/logger.py  (view file @ e79d74f8)

@@ -56,11 +56,14 @@ def _getlogger():
 _logger = _getlogger()
-_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'setLevel']
+_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'exception', 'debug', 'setLevel']
 # export logger functions
 for func in _LOGGING_METHOD:
     locals()[func] = getattr(_logger, func)
     __all__.append(func)
+# 'warn' is deprecated in logging module
+warn = _logger.warning
+__all__.append('warn')


 def _get_time_str():
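`logging.Logger.warn` is deprecated, so instead of re-exporting it via `getattr`, the module now exposes `warn` explicitly as an alias of `warning` (which is also what `param.py` above calls via `logger.warn`). A sketch of the resulting behaviour:

    from tensorpack.utils import logger

    logger.warning("canonical spelling")
    logger.warn("still supported, now an alias of warning()")  # no deprecated Logger.warn involved
    assert logger.warn == logger.warning                        # both bind the same underlying method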
tests/run-tests.sh  (view file @ e79d74f8)

@@ -5,14 +5,17 @@ DIR=$(dirname $0)
 cd $DIR

 export TF_CPP_MIN_LOG_LEVEL=2
+export TF_CPP_MIN_VLOG_LEVEL=2
 # test import (#471)
 python -c 'from tensorpack.dataflow.imgaug import transform'
+# Check that these private names can be imported because tensorpack is using them
+python -c "from tensorflow.python.client.session import _FetchHandler"
+python -c "from tensorflow.python.training.monitored_session import _HookedSession"
+python -c "import tensorflow as tf; tf.Operation._add_control_input"

-python -m unittest discover -v
-
-# run tests
-# python -m tensorpack.models._test
-# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
-# python ../tensorpack/user_ops/test-recv-op.py
+python -m tensorpack.callbacks.param_test
+
 TENSORPACK_SERIALIZE=pyarrow python test_serializer.py
 TENSORPACK_SERIALIZE=msgpack python test_serializer.py
+
+python -m unittest discover -v