Commit ccae46f4, authored Mar 16, 2017 by Yuxin Wu
add session_creator option in TrainConfig (#191)
parent 09cc8662
Showing 12 changed files with 30 additions and 22 deletions

CHANGES.md  +2 -0
examples/A3C-Gym/train-atari.py  +2 -1
examples/GAN/ConditionalGAN-mnist.py  +0 -1
examples/GAN/DCGAN-CelebA.py  +0 -1
examples/GAN/InfoGAN-mnist.py  +0 -1
examples/GAN/WGAN-CelebA.py  +0 -1
examples/Inception/inception-bn.py  +0 -1
examples/Inception/inceptionv3.py  +0 -1
examples/SpatialTransformer/mnist-addition.py  +2 -1
examples/cifar-convnet.py  +0 -3
tensorpack/train/base.py  +5 -6
tensorpack/train/config.py  +19 -5

CHANGES.md
@@ -8,6 +8,8 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changes API and those are not listed here.
 
++ 2017/03/16. `session_config` option in `TrainConfig` and `PredictConfig` is deprecated.
+  Use `session_creator` to define how to create session instead.
 + 2017/02/20. The interface of step callbacks are changed to be the same as `tf.train.SessionRunHook`.
   If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer
   to the [existing callbacks](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py).

examples/A3C-Gym/train-atari.py
@@ -222,7 +222,8 @@ def get_config():
             StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
         ],
-        session_config=get_default_sess_config(0.5),
+        session_creator=sesscreate.NewSessionCreator(
+            config=get_default_sess_config(0.5)),
         model=M,
         steps_per_epoch=STEPS_PER_EPOCH,
         max_epoch=1000,

examples/GAN/ConditionalGAN-mnist.py
@@ -108,7 +108,6 @@ def get_config():
     return TrainConfig(
         dataflow=dataset,
         callbacks=[ModelSaver()],
-        session_config=get_default_sess_config(0.5),
         model=Model(),
         steps_per_epoch=500,
         max_epoch=100,

examples/GAN/DCGAN-CelebA.py
@@ -107,7 +107,6 @@ def get_config():
         model=Model(),
         dataflow=get_data(args.data),
         callbacks=[ModelSaver()],
-        session_config=get_default_sess_config(0.5),
         steps_per_epoch=300,
         max_epoch=200,
     )

examples/GAN/InfoGAN-mnist.py
@@ -167,7 +167,6 @@ def get_config():
     return TrainConfig(
         dataflow=get_data(),
         callbacks=[ModelSaver(keep_freq=0.1)],
-        session_config=get_default_sess_config(0.5),
         model=Model(),
         steps_per_epoch=500,
         max_epoch=100,

examples/GAN/WGAN-CelebA.py
@@ -61,7 +61,6 @@ def get_config():
         # use the same data in the DCGAN example
         dataflow=DCGAN.get_data(args.data),
         callbacks=[ModelSaver()],
-        session_config=get_default_sess_config(0.5),
         steps_per_epoch=300,
         max_epoch=200,
     )

examples/Inception/inception-bn.py
@@ -171,7 +171,6 @@ def get_config():
                                        (19, 3e-3), (24, 1e-3), (26, 2e-4),
                                        (30, 5e-5)])
         ],
-        session_config=get_default_sess_config(0.99),
         model=Model(),
         steps_per_epoch=5000,
         max_epoch=80,

examples/Inception/inceptionv3.py
@@ -277,7 +277,6 @@ def get_config():
                                        (41, 8e-5), (48, 1e-5), (53, 2e-6)]),
             HumanHyperParamSetter('learning_rate')
         ],
-        session_config=get_default_sess_config(0.9),
         model=Model(),
         steps_per_epoch=5000,
         max_epoch=100,

examples/SpatialTransformer/mnist-addition.py
@@ -162,7 +162,8 @@ def get_config():
                 [ScalarStats('cost'), ClassificationError()]),
             ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
         ],
-        session_config=get_default_sess_config(0.5),
+        session_creator=sesscreate.NewSessionCreator(
+            config=get_default_sess_config(0.5)),
         steps_per_epoch=steps_per_epoch,
         max_epoch=500,
     )

examples/cifar-convnet.py
@@ -114,8 +114,6 @@ def get_config(cifar_classnum):
     dataset_train = get_data('train', cifar_classnum)
     dataset_test = get_data('test', cifar_classnum)
 
-    sess_config = get_default_sess_config(0.5)
-
     def lr_func(lr):
         if lr < 3e-5:
             raise StopTraining()
@@ -129,7 +127,6 @@ def get_config(cifar_classnum):
             StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
                                    threshold=0.001, last_k=10),
         ],
-        session_config=sess_config,
         max_epoch=150,
     )

tensorpack/train/base.py
@@ -20,7 +20,6 @@ from ..utils.develop import deprecated, log_deprecated
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
 from ..tfutils import get_global_step_value
 from ..tfutils.modelutils import describe_model
-from ..tfutils.sesscreate import NewSessionCreator
 
 __all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
@@ -117,11 +116,11 @@ class Trainer(object):
         self._callbacks.setup_graph(weakref.proxy(self))
 
         # create session
-        sess_creator = NewSessionCreator(config=self.config.session_config)
+        sess_creator = self.config.session_creator
         logger.info("Finalize the graph, create the session ...")
-        self.monitored_sess = tf.train.MonitoredSession(
+        self._monitored_sess = tf.train.MonitoredSession(
             session_creator=sess_creator, hooks=None)
-        self.sess = self.monitored_sess._tf_sess()  # expose the underlying session also
+        self.sess = self._monitored_sess._tf_sess()  # expose the underlying session also
 
         # init session
         init_op = tf.global_variables_initializer()
@@ -159,7 +158,7 @@ class Trainer(object):
                     logger.info("Start Epoch {} ...".format(self.epoch_num))
                     start_time = time.time()
                     for self.local_step in range(self.config.steps_per_epoch):
-                        if self.monitored_sess.should_stop():
+                        if self._monitored_sess.should_stop():
                             return
                         self.run_step()  # implemented by subclass
                         self._callbacks.trigger_step()
@@ -179,7 +178,7 @@ class Trainer(object):
         finally:
             self._callbacks.after_train()
             self.monitors.close()
-            self.monitored_sess.close()
+            self._monitored_sess.close()
 
     # Predictor related methods: TODO
     def get_predictor(self, input_names, output_names, tower=0):
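
After this change the trainer no longer builds its own `NewSessionCreator`; it passes whatever object sits in `config.session_creator` to `tf.train.MonitoredSession`, which is expected to obtain the session through the creator's `create_session()` method (the method declared by `tf.train.SessionCreator`). Below is a minimal, TensorFlow-free sketch of that hand-off. `FakeSession`, `CountingSessionCreator` and `MonitoredSessionSketch` are hypothetical stand-ins written for illustration only; they are not tensorpack or TensorFlow classes, and the dict passed as `config` is a placeholder for a `tf.ConfigProto`.

```python
class FakeSession:
    """Hypothetical stand-in for tf.Session; only remembers if it was closed."""

    def __init__(self, config):
        self.config = config
        self.closed = False

    def close(self):
        self.closed = True


class CountingSessionCreator:
    """Hypothetical creator exposing create_session(), like tf.train.SessionCreator."""

    def __init__(self, config=None):
        self.config = config
        self.num_created = 0

    def create_session(self):
        # Count the calls so the hand-off can be observed from outside.
        self.num_created += 1
        return FakeSession(self.config)


class MonitoredSessionSketch:
    """Rough stand-in for tf.train.MonitoredSession: asks the creator for a
    session on construction and closes that session on close()."""

    def __init__(self, session_creator, hooks=None):
        self._sess = session_creator.create_session()

    def should_stop(self):
        return self._sess.closed

    def close(self):
        self._sess.close()


# Mirrors the call shape used in Trainer: the creator comes from the config.
creator = CountingSessionCreator(config={"gpu_memory_fraction": 0.5})
monitored = MonitoredSessionSketch(session_creator=creator, hooks=None)
stopped_before_close = monitored.should_stop()
monitored.close()
```

The point of the design is that the trainer only depends on the creator interface, so a user can supply a creator with different session-construction logic without touching `Trainer`.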

tensorpack/train/config.py
@@ -13,6 +13,7 @@ from ..utils import logger
 from ..utils.develop import log_deprecated
 from ..tfutils import (JustCurrentSession,
                        get_default_sess_config, SessionInit)
+from ..tfutils.sesscreate import NewSessionCreator
 from ..tfutils.optimizer import apply_grad_processors
 from .input_data import InputData
 from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
@@ -30,7 +31,7 @@ class TrainConfig(object):
                  model=None,
                  callbacks=None, extra_callbacks=None,
                  monitors=None,
-                 session_config=get_default_sess_config(), session_init=None,
+                 session_creator=None, session_config=None, session_init=None,
                  starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
                  nr_tower=1, tower=None, predict_tower=[0],
                  **kwargs):
@@ -47,8 +48,12 @@ class TrainConfig(object):
                 callbacks that will be used in the end are ``callbacks + extra_callbacks``.
             monitors (list): a list of :class:`TrainingMonitor`.
                 Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
-            session_config (tf.ConfigProto): the config used to instantiate the session.
-            session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
+            session_creator (tf.train.SessionCreator): how to create the
+                session. Defaults to :class:`sesscreate.NewSessionCreator()`
+                with the config returned by
+                :func:`tfutils.get_default_sess_config()`.
+            session_init (SessionInit): how to initialize variables of a
+                session. Defaults to do nothing.
             starting_epoch (int): The index of the first epoch.
             steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
                 Defaults to the input data size.
@@ -99,13 +104,22 @@ class TrainConfig(object):
         self.model = model
         assert_type(self.model, ModelDesc)
 
-        self.session_config = session_config
-        assert_type(self.session_config, tf.ConfigProto)
         if session_init is None:
             session_init = JustCurrentSession()
         self.session_init = session_init
         assert_type(self.session_init, SessionInit)
 
+        if session_creator is None:
+            if session_config is not None:
+                log_deprecated("TrainConfig(session_config=)",
+                               "Use session_creator=NewSessionCreator(config=) instead!",
+                               "2017-05-20")
+                self.session_creator = NewSessionCreator(config=session_config)
+            else:
+                self.session_creator = NewSessionCreator(config=get_default_sess_config())
+        else:
+            self.session_creator = session_creator
+
         if steps_per_epoch is None:
             steps_per_epoch = kwargs.pop('step_per_epoch', None)
             if steps_per_epoch is not None:
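
The last hunk of `config.py` sets a precedence: an explicit `session_creator` wins, a bare `session_config` is still accepted but logged as deprecated and wrapped in `NewSessionCreator`, and with neither argument a creator with the default config is used. The sketch below reproduces only that decision logic so it can be run without TensorFlow. The `NewSessionCreator` and `get_default_sess_config` defined here are hypothetical stand-ins (plain Python, dict instead of `tf.ConfigProto`, placeholder default value), `resolve_session_creator` is a name invented for this example, and `warnings.warn` is substituted for tensorpack's `log_deprecated`.

```python
import warnings


class NewSessionCreator:
    """Hypothetical stand-in for tensorpack's NewSessionCreator."""

    def __init__(self, config=None):
        self.config = config


def get_default_sess_config(mem_fraction=0.99):
    """Hypothetical stand-in: returns a dict rather than a tf.ConfigProto.
    The default value is a placeholder, not taken from the commit."""
    return {"gpu_memory_fraction": mem_fraction}


def resolve_session_creator(session_creator=None, session_config=None):
    """Same precedence as the hunk:
    explicit creator > deprecated session_config > default config."""
    if session_creator is not None:
        # As in the diff, session_config is ignored when a creator is given.
        return session_creator
    if session_config is not None:
        warnings.warn(
            "TrainConfig(session_config=) is deprecated; "
            "use session_creator=NewSessionCreator(config=) instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return NewSessionCreator(config=session_config)
    return NewSessionCreator(config=get_default_sess_config())


# The three cases: nothing given, explicit creator, legacy session_config.
default_creator = resolve_session_creator()

custom = NewSessionCreator(config={"gpu_memory_fraction": 0.5})
explicit_creator = resolve_session_creator(session_creator=custom)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_creator = resolve_session_creator(
        session_config={"gpu_memory_fraction": 0.5})
```

One consequence worth noticing in the diff itself: when both arguments are passed, `session_config` is dropped without a warning, so callers migrating to `session_creator` should remove the old argument rather than keep both.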