Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ccae46f4
Commit
ccae46f4
authored
Mar 16, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add session_creator option in TrainConfig (#191)
parent
09cc8662
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
30 additions
and
22 deletions
+30
-22
CHANGES.md
CHANGES.md
+2
-0
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+2
-1
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+0
-1
examples/GAN/DCGAN-CelebA.py
examples/GAN/DCGAN-CelebA.py
+0
-1
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+0
-1
examples/GAN/WGAN-CelebA.py
examples/GAN/WGAN-CelebA.py
+0
-1
examples/Inception/inception-bn.py
examples/Inception/inception-bn.py
+0
-1
examples/Inception/inceptionv3.py
examples/Inception/inceptionv3.py
+0
-1
examples/SpatialTransformer/mnist-addition.py
examples/SpatialTransformer/mnist-addition.py
+2
-1
examples/cifar-convnet.py
examples/cifar-convnet.py
+0
-3
tensorpack/train/base.py
tensorpack/train/base.py
+5
-6
tensorpack/train/config.py
tensorpack/train/config.py
+19
-5
No files found.
CHANGES.md
View file @
ccae46f4
...
...
@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
+
2017/03/16.
`session_config`
option in
`TrainConfig`
and
`PredictConfig`
is deprecated.
Use
`session_creator`
to define how to create session instead.
+
2017/02/20. The interface of step callbacks are changed to be the same as
`tf.train.SessionRunHook`
.
If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer
to the
[
existing callbacks
](
https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py
)
.
...
...
examples/A3C-Gym/train-atari.py
View file @
ccae46f4
...
...
@@ -222,7 +222,8 @@ def get_config():
StartProcOrThread
(
master
),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'logits'
]),
2
),
],
session_config
=
get_default_sess_config
(
0.5
),
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
model
=
M
,
steps_per_epoch
=
STEPS_PER_EPOCH
,
max_epoch
=
1000
,
...
...
examples/GAN/ConditionalGAN-mnist.py
View file @
ccae46f4
...
...
@@ -108,7 +108,6 @@ def get_config():
return
TrainConfig
(
dataflow
=
dataset
,
callbacks
=
[
ModelSaver
()],
session_config
=
get_default_sess_config
(
0.5
),
model
=
Model
(),
steps_per_epoch
=
500
,
max_epoch
=
100
,
...
...
examples/GAN/DCGAN-CelebA.py
View file @
ccae46f4
...
...
@@ -107,7 +107,6 @@ def get_config():
model
=
Model
(),
dataflow
=
get_data
(
args
.
data
),
callbacks
=
[
ModelSaver
()],
session_config
=
get_default_sess_config
(
0.5
),
steps_per_epoch
=
300
,
max_epoch
=
200
,
)
...
...
examples/GAN/InfoGAN-mnist.py
View file @
ccae46f4
...
...
@@ -167,7 +167,6 @@ def get_config():
return
TrainConfig
(
dataflow
=
get_data
(),
callbacks
=
[
ModelSaver
(
keep_freq
=
0.1
)],
session_config
=
get_default_sess_config
(
0.5
),
model
=
Model
(),
steps_per_epoch
=
500
,
max_epoch
=
100
,
...
...
examples/GAN/WGAN-CelebA.py
View file @
ccae46f4
...
...
@@ -61,7 +61,6 @@ def get_config():
# use the same data in the DCGAN example
dataflow
=
DCGAN
.
get_data
(
args
.
data
),
callbacks
=
[
ModelSaver
()],
session_config
=
get_default_sess_config
(
0.5
),
steps_per_epoch
=
300
,
max_epoch
=
200
,
)
...
...
examples/Inception/inception-bn.py
View file @
ccae46f4
...
...
@@ -171,7 +171,6 @@ def get_config():
(
19
,
3e-3
),
(
24
,
1e-3
),
(
26
,
2e-4
),
(
30
,
5e-5
)])
],
session_config
=
get_default_sess_config
(
0.99
),
model
=
Model
(),
steps_per_epoch
=
5000
,
max_epoch
=
80
,
...
...
examples/Inception/inceptionv3.py
View file @
ccae46f4
...
...
@@ -277,7 +277,6 @@ def get_config():
(
41
,
8e-5
),
(
48
,
1e-5
),
(
53
,
2e-6
)]),
HumanHyperParamSetter
(
'learning_rate'
)
],
session_config
=
get_default_sess_config
(
0.9
),
model
=
Model
(),
steps_per_epoch
=
5000
,
max_epoch
=
100
,
...
...
examples/SpatialTransformer/mnist-addition.py
View file @
ccae46f4
...
...
@@ -162,7 +162,8 @@ def get_config():
[
ScalarStats
(
'cost'
),
ClassificationError
()]),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
200
,
1e-4
)])
],
session_config
=
get_default_sess_config
(
0.5
),
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
steps_per_epoch
=
steps_per_epoch
,
max_epoch
=
500
,
)
...
...
examples/cifar-convnet.py
View file @
ccae46f4
...
...
@@ -114,8 +114,6 @@ def get_config(cifar_classnum):
dataset_train
=
get_data
(
'train'
,
cifar_classnum
)
dataset_test
=
get_data
(
'test'
,
cifar_classnum
)
sess_config
=
get_default_sess_config
(
0.5
)
def
lr_func
(
lr
):
if
lr
<
3e-5
:
raise
StopTraining
()
...
...
@@ -129,7 +127,6 @@ def get_config(cifar_classnum):
StatMonitorParamSetter
(
'learning_rate'
,
'val_error'
,
lr_func
,
threshold
=
0.001
,
last_k
=
10
),
],
session_config
=
sess_config
,
max_epoch
=
150
,
)
...
...
tensorpack/train/base.py
View file @
ccae46f4
...
...
@@ -20,7 +20,6 @@ from ..utils.develop import deprecated, log_deprecated
from
..callbacks
import
Callback
,
Callbacks
,
MaintainStepCounter
from
..tfutils
import
get_global_step_value
from
..tfutils.modelutils
import
describe_model
from
..tfutils.sesscreate
import
NewSessionCreator
__all__
=
[
'Trainer'
,
'StopTraining'
,
'MultiPredictorTowerTrainer'
]
...
...
@@ -117,11 +116,11 @@ class Trainer(object):
self
.
_callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
# create session
sess_creator
=
NewSessionCreator
(
config
=
self
.
config
.
session_config
)
sess_creator
=
self
.
config
.
session_creator
logger
.
info
(
"Finalize the graph, create the session ..."
)
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
self
.
_
monitored_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
sess_creator
,
hooks
=
None
)
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
self
.
sess
=
self
.
_
monitored_sess
.
_tf_sess
()
# expose the underlying session also
# init session
init_op
=
tf
.
global_variables_initializer
()
...
...
@@ -159,7 +158,7 @@ class Trainer(object):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
start_time
=
time
.
time
()
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
if
self
.
monitored_sess
.
should_stop
():
if
self
.
_
monitored_sess
.
should_stop
():
return
self
.
run_step
()
# implemented by subclass
self
.
_callbacks
.
trigger_step
()
...
...
@@ -179,7 +178,7 @@ class Trainer(object):
finally
:
self
.
_callbacks
.
after_train
()
self
.
monitors
.
close
()
self
.
monitored_sess
.
close
()
self
.
_
monitored_sess
.
close
()
# Predictor related methods: TODO
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
=
0
):
...
...
tensorpack/train/config.py
View file @
ccae46f4
...
...
@@ -13,6 +13,7 @@ from ..utils import logger
from
..utils.develop
import
log_deprecated
from
..tfutils
import
(
JustCurrentSession
,
get_default_sess_config
,
SessionInit
)
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.optimizer
import
apply_grad_processors
from
.input_data
import
InputData
from
.monitor
import
TFSummaryWriter
,
JSONWriter
,
ScalarPrinter
...
...
@@ -30,7 +31,7 @@ class TrainConfig(object):
model
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
monitors
=
None
,
session_c
onfig
=
get_default_sess_config
()
,
session_init
=
None
,
session_c
reator
=
None
,
session_config
=
None
,
session_init
=
None
,
starting_epoch
=
1
,
steps_per_epoch
=
None
,
max_epoch
=
99999
,
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
[
0
],
**
kwargs
):
...
...
@@ -47,8 +48,12 @@ class TrainConfig(object):
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by
:func:`tfutils.get_default_sess_config()`.
session_init (SessionInit): how to initialize variables of a
session. Defaults to do nothing.
starting_epoch (int): The index of the first epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
...
...
@@ -99,13 +104,22 @@ class TrainConfig(object):
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
self
.
session_config
=
session_config
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
self
.
session_init
=
session_init
assert_type
(
self
.
session_init
,
SessionInit
)
if
session_creator
is
None
:
if
session_config
is
not
None
:
log_deprecated
(
"TrainConfig(session_config=)"
,
"Use session_creator=NewSessionCreator(config=) instead!"
,
"2017-05-20"
)
self
.
session_creator
=
NewSessionCreator
(
config
=
session_config
)
else
:
self
.
session_creator
=
NewSessionCreator
(
config
=
get_default_sess_config
())
else
:
self
.
session_creator
=
session_creator
if
steps_per_epoch
is
None
:
steps_per_epoch
=
kwargs
.
pop
(
'step_per_epoch'
,
None
)
if
steps_per_epoch
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment