Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e79d74f8
Commit
e79d74f8
authored
Feb 16, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add tests about scheduler
parent
940a1636
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
148 additions
and
25 deletions
+148
-25
.travis.yml
.travis.yml
+2
-6
docs/tutorial/extend/augmentor.md
docs/tutorial/extend/augmentor.md
+2
-2
tensorpack/callbacks/__init__.py
tensorpack/callbacks/__init__.py
+2
-0
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+31
-8
tensorpack/callbacks/param_test.py
tensorpack/callbacks/param_test.py
+96
-0
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+1
-1
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+1
-1
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+4
-1
tests/run-tests.sh
tests/run-tests.sh
+9
-6
No files found.
.travis.yml
View file @
e79d74f8
...
...
@@ -52,16 +52,12 @@ before_script:
-
protoc --version
-
python -c "import cv2; print('OpenCV '+ cv2.__version__)"
-
python -c "import tensorflow as tf; print('TensorFlow '+ tf.__version__)"
# Check that these private names can be imported because tensorpack is using them
-
python -c "from tensorflow.python.client.session import _FetchHandler"
-
python -c "from tensorflow.python.training.monitored_session import _HookedSession"
-
python -c "import tensorflow as tf; tf.Operation._add_control_input"
-
mkdir -p $HOME/tensorpack_data
-
export TENSORPACK_DATASET=$HOME/tensorpack_data
script
:
-
flake8 .
-
if [[ $TRAVIS_PYTHON_VERSION == '3.6' ]]; then cd examples && flake8 .; fi
# some examples are py3 only
-
mkdir -p $HOME/tensorpack_data
-
export TENSORPACK_DATASET=$HOME/tensorpack_data
-
$TRAVIS_BUILD_DIR/tests/run-tests.sh
-
cd $TRAVIS_BUILD_DIR
# go back to root so that deploy may work
...
...
docs/tutorial/extend/augmentor.md
View file @
e79d74f8
### Design of Tensorpack's imgaug Module
###
#
Design of Tensorpack's imgaug Module
The
[
imgaug module
](
../../modules/dataflow.imgaug.html
)
is designed to allow the following usage:
...
...
@@ -22,7 +22,7 @@ The [imgaug module](../../modules/dataflow.imgaug.html) is designed to allow the
4.
Reset random seed. Random seed can be reset by
[
reset_state
](
../../modules/dataflow.imgaug.html#tensorpack.dataflow.imgaug.Augmentor.reset_state
)
.
This is important for multi-process data loading, and
it
is called automatically if you use tensorpack's
the reset method
is called automatically if you use tensorpack's
[
image augmentation dataflow
](
../../modules/dataflow.html#tensorpack.dataflow.AugmentImageComponent
)
.
### Write an Image Augmentor
...
...
tensorpack/callbacks/__init__.py
View file @
e79d74f8
...
...
@@ -47,5 +47,7 @@ for _, module_name, _ in iter_modules(
srcpath
=
os
.
path
.
join
(
_CURR_DIR
,
module_name
+
'.py'
)
if
not
os
.
path
.
isfile
(
srcpath
):
continue
if
module_name
.
endswith
(
'_test'
):
continue
if
not
module_name
.
startswith
(
'_'
):
_global_import
(
module_name
)
tensorpack/callbacks/param.py
View file @
e79d74f8
...
...
@@ -103,7 +103,7 @@ class ObjAttrParam(HyperParam):
def
set_value
(
self
,
v
):
setattr
(
self
.
obj
,
self
.
attrname
,
v
)
def
get_value
(
self
,
v
):
def
get_value
(
self
):
return
getattr
(
self
.
obj
,
self
.
attrname
)
...
...
@@ -151,8 +151,7 @@ class HyperParamSetter(Callback):
"""
ret
=
self
.
_get_value_to_set
()
if
ret
is
not
None
and
ret
!=
self
.
_last_value
:
if
self
.
epoch_num
!=
self
.
_last_epoch_set
:
# Print this message at most once every epoch
if
self
.
epoch_num
!=
self
.
_last_epoch_set
:
# Print this message at most once every epoch
if
self
.
_last_value
is
None
:
logger
.
info
(
"[HyperParamSetter] At global_step={}, {} is set to {:.6f}"
.
format
(
self
.
global_step
,
self
.
param
.
readable_name
,
ret
))
...
...
@@ -261,13 +260,33 @@ class ScheduledHyperParamSetter(HyperParamSetter):
self
.
_step
=
step_based
super
(
ScheduledHyperParamSetter
,
self
)
.
__init__
(
param
)
def
_get_value_to_set
(
self
):
refnum
=
self
.
global_step
if
self
.
_step
else
self
.
epoch_num
def
_get_value_to_set
(
self
):
# override parent
return
self
.
_get_value_to_set_at_point
(
self
.
_current_point
())
def
_current_point
(
self
):
return
self
.
global_step
if
self
.
_step
else
self
.
epoch_num
def
_check_value_at_beginning
(
self
):
v
=
None
# we are at `before_train`, therefore the epoch/step associated with `current_point` has finished.
for
p
in
range
(
0
,
self
.
_current_point
()
+
1
):
v
=
self
.
_get_value_to_set_at_point
(
p
)
or
v
actual_value
=
self
.
param
.
get_value
()
if
v
is
not
None
and
v
!=
actual_value
:
logger
.
warn
(
"According to the schedule, parameter '{}' should become {} at the current point. "
"However its current value is {}. "
"You may want to check whether your initialization of the parameter is as expected"
.
format
(
self
.
param
.
readable_name
,
v
,
actual_value
))
def
_get_value_to_set_at_point
(
self
,
point
):
"""
Using schedule, compute the value to be set at a given point.
"""
laste
,
lastv
=
None
,
None
for
e
,
v
in
self
.
schedule
:
if
e
==
refnum
:
if
e
==
point
:
return
v
# meet the exact boundary, return directly
if
e
>
refnum
:
if
e
>
point
:
break
laste
,
lastv
=
e
,
v
if
laste
is
None
or
laste
==
e
:
...
...
@@ -276,9 +295,13 @@ class ScheduledHyperParamSetter(HyperParamSetter):
if
self
.
interp
is
None
:
# If no interpolation, nothing to do.
return
None
v
=
(
refnum
-
laste
)
*
1.
/
(
e
-
laste
)
*
(
v
-
lastv
)
+
lastv
v
=
(
point
-
laste
)
*
1.
/
(
e
-
laste
)
*
(
v
-
lastv
)
+
lastv
return
v
def
_before_train
(
self
):
super
(
ScheduledHyperParamSetter
,
self
)
.
_before_train
()
self
.
_check_value_at_beginning
()
def
_trigger_epoch
(
self
):
if
not
self
.
_step
:
self
.
trigger
()
...
...
tensorpack/callbacks/param_test.py
0 → 100644
View file @
e79d74f8
# -*- coding: utf-8 -*-
import
unittest
import
tensorflow
as
tf
import
six
from
..utils
import
logger
from
..train.trainers
import
NoOpTrainer
from
.param
import
ScheduledHyperParamSetter
,
ObjAttrParam
class
ParamObject
(
object
):
"""
An object that holds the param to be set, for testing purposes.
"""
PARAM_NAME
=
'param'
def
__init__
(
self
):
self
.
param_history
=
{}
self
.
__dict__
[
self
.
PARAM_NAME
]
=
1.0
def
__setattr__
(
self
,
name
,
value
):
if
name
==
self
.
PARAM_NAME
:
self
.
_set_param
(
value
)
super
(
ParamObject
,
self
)
.
__setattr__
(
name
,
value
)
def
_set_param
(
self
,
value
):
self
.
param_history
[
self
.
trainer
.
global_step
]
=
value
class
ScheduledHyperParamSetterTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
_param_obj
=
ParamObject
()
def
tearDown
(
self
):
tf
.
reset_default_graph
()
def
_create_trainer_with_scheduler
(
self
,
scheduler
,
steps_per_epoch
,
max_epoch
,
starting_epoch
=
1
):
trainer
=
NoOpTrainer
()
tf
.
get_variable
(
name
=
'test_var'
,
shape
=
[])
self
.
_param_obj
.
trainer
=
trainer
trainer
.
train_with_defaults
(
callbacks
=
[
scheduler
],
extra_callbacks
=
[],
monitors
=
[],
steps_per_epoch
=
steps_per_epoch
,
max_epoch
=
max_epoch
,
starting_epoch
=
starting_epoch
)
return
self
.
_param_obj
.
param_history
def
testInterpolation
(
self
):
scheduler
=
ScheduledHyperParamSetter
(
ObjAttrParam
(
self
.
_param_obj
,
ParamObject
.
PARAM_NAME
),
[(
30
,
0.3
),
(
40
,
0.4
),
(
50
,
0.5
)],
interp
=
'linear'
,
step_based
=
True
)
history
=
self
.
_create_trainer_with_scheduler
(
scheduler
,
10
,
50
,
starting_epoch
=
20
)
self
.
assertEqual
(
min
(
history
.
keys
()),
30
)
self
.
assertEqual
(
history
[
30
],
0.3
)
self
.
assertEqual
(
history
[
40
],
0.4
)
self
.
assertEqual
(
history
[
45
],
0.45
)
def
testSchedule
(
self
):
scheduler
=
ScheduledHyperParamSetter
(
ObjAttrParam
(
self
.
_param_obj
,
ParamObject
.
PARAM_NAME
),
[(
10
,
0.3
),
(
20
,
0.4
),
(
30
,
0.5
)])
history
=
self
.
_create_trainer_with_scheduler
(
scheduler
,
1
,
50
)
self
.
assertEqual
(
min
(
history
.
keys
()),
10
)
self
.
assertEqual
(
len
(
history
),
3
)
def
testStartAfterSchedule
(
self
):
scheduler
=
ScheduledHyperParamSetter
(
ObjAttrParam
(
self
.
_param_obj
,
ParamObject
.
PARAM_NAME
),
[(
10
,
0.3
),
(
20
,
0.4
),
(
30
,
0.5
)])
history
=
self
.
_create_trainer_with_scheduler
(
scheduler
,
1
,
92
,
starting_epoch
=
90
)
self
.
assertEqual
(
len
(
history
),
0
)
@
unittest
.
skipIf
(
six
.
PY2
,
"assertLogs not supported in Python 2"
)
def
testWarningStartInTheMiddle
(
self
):
scheduler
=
ScheduledHyperParamSetter
(
ObjAttrParam
(
self
.
_param_obj
,
ParamObject
.
PARAM_NAME
),
[(
10
,
0.3
),
(
20
,
0.4
),
(
30
,
0.5
)])
with
self
.
assertLogs
(
logger
=
logger
.
_logger
,
level
=
'WARNING'
):
self
.
_create_trainer_with_scheduler
(
scheduler
,
1
,
21
,
starting_epoch
=
20
)
@
unittest
.
skipIf
(
six
.
PY2
,
"unittest.mock not available in Python 2"
)
def
testNoWarningStartInTheMiddle
(
self
):
scheduler
=
ScheduledHyperParamSetter
(
ObjAttrParam
(
self
.
_param_obj
,
ParamObject
.
PARAM_NAME
),
[(
10
,
0.3
),
(
20
,
1.0
),
(
30
,
1.5
)])
with
unittest
.
mock
.
patch
(
'tensorpack.utils.logger.warning'
)
as
warning
:
self
.
_create_trainer_with_scheduler
(
scheduler
,
1
,
22
,
starting_epoch
=
21
)
self
.
assertFalse
(
warning
.
called
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tensorpack/callbacks/steps.py
View file @
e79d74f8
...
...
@@ -101,7 +101,7 @@ class ProgressBar(Callback):
class
MaintainStepCounter
(
Callback
):
"""
It maintains the global step in the graph, making sure it's increased by one.
It maintains the global step in the graph, making sure it's increased by one
at every `hooked_sess.run`
.
This callback is used internally by the trainer, you don't need to worry about it.
"""
...
...
tensorpack/train/trainers.py
View file @
e79d74f8
...
...
@@ -61,7 +61,7 @@ class NoOpTrainer(SimpleTrainer):
Note that `steps_per_epoch` and `max_epochs` are still valid options.
"""
def
run_step
(
self
):
pass
self
.
hooked_sess
.
run
([])
# Only exists for type check & back-compatibility
...
...
tensorpack/utils/logger.py
View file @
e79d74f8
...
...
@@ -56,11 +56,14 @@ def _getlogger():
_logger
=
_getlogger
()
_LOGGING_METHOD
=
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'
warn'
,
'
exception'
,
'debug'
,
'setLevel'
]
_LOGGING_METHOD
=
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'exception'
,
'debug'
,
'setLevel'
]
# export logger functions
for
func
in
_LOGGING_METHOD
:
locals
()[
func
]
=
getattr
(
_logger
,
func
)
__all__
.
append
(
func
)
# 'warn' is deprecated in logging module
warn
=
_logger
.
warning
__all__
.
append
(
'warn'
)
def
_get_time_str
():
...
...
tests/run-tests.sh
View file @
e79d74f8
...
...
@@ -5,14 +5,17 @@ DIR=$(dirname $0)
cd
$DIR
export
TF_CPP_MIN_LOG_LEVEL
=
2
export
TF_CPP_MIN_VLOG_LEVEL
=
2
# test import (#471)
python
-c
'from tensorpack.dataflow.imgaug import transform'
# Check that these private names can be imported because tensorpack is using them
python
-c
"from tensorflow.python.client.session import _FetchHandler"
python
-c
"from tensorflow.python.training.monitored_session import _HookedSession"
python
-c
"import tensorflow as tf; tf.Operation._add_control_input"
python
-m
unittest discover
-v
# python -m tensorpack.models._test
# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
# python ../tensorpack/user_ops/test-recv-op.py
# run tests
python
-m
tensorpack.callbacks.param_test
TENSORPACK_SERIALIZE
=
pyarrow python test_serializer.py
TENSORPACK_SERIALIZE
=
msgpack python test_serializer.py
python
-m
unittest discover
-v
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment