Commit ccae46f4 authored by Yuxin Wu

add session_creator option in TrainConfig (#191)

parent 09cc8662
...@@ -8,6 +8,8 @@ so you won't need to look at here very often. ...@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version. Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here. TensorFlow itself also changes API and those are not listed here.
+ 2017/03/16. The `session_config` option in `TrainConfig` and `PredictConfig` is deprecated.
Use `session_creator` instead to define how the session is created.
+ 2017/02/20. The interface of step callbacks was changed to be the same as `tf.train.SessionRunHook`. + 2017/02/20. The interface of step callbacks was changed to be the same as `tf.train.SessionRunHook`.
If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer
to the [existing callbacks](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py). to the [existing callbacks](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py).
......
...@@ -222,7 +222,8 @@ def get_config(): ...@@ -222,7 +222,8 @@ def get_config():
StartProcOrThread(master), StartProcOrThread(master),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2), PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
], ],
session_config=get_default_sess_config(0.5), session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
model=M, model=M,
steps_per_epoch=STEPS_PER_EPOCH, steps_per_epoch=STEPS_PER_EPOCH,
max_epoch=1000, max_epoch=1000,
......
...@@ -108,7 +108,6 @@ def get_config(): ...@@ -108,7 +108,6 @@ def get_config():
return TrainConfig( return TrainConfig(
dataflow=dataset, dataflow=dataset,
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
model=Model(), model=Model(),
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=100, max_epoch=100,
......
...@@ -107,7 +107,6 @@ def get_config(): ...@@ -107,7 +107,6 @@ def get_config():
model=Model(), model=Model(),
dataflow=get_data(args.data), dataflow=get_data(args.data),
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
) )
......
...@@ -167,7 +167,6 @@ def get_config(): ...@@ -167,7 +167,6 @@ def get_config():
return TrainConfig( return TrainConfig(
dataflow=get_data(), dataflow=get_data(),
callbacks=[ModelSaver(keep_freq=0.1)], callbacks=[ModelSaver(keep_freq=0.1)],
session_config=get_default_sess_config(0.5),
model=Model(), model=Model(),
steps_per_epoch=500, steps_per_epoch=500,
max_epoch=100, max_epoch=100,
......
...@@ -61,7 +61,6 @@ def get_config(): ...@@ -61,7 +61,6 @@ def get_config():
# use the same data in the DCGAN example # use the same data in the DCGAN example
dataflow=DCGAN.get_data(args.data), dataflow=DCGAN.get_data(args.data),
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
) )
......
...@@ -171,7 +171,6 @@ def get_config(): ...@@ -171,7 +171,6 @@ def get_config():
(19, 3e-3), (24, 1e-3), (26, 2e-4), (19, 3e-3), (24, 1e-3), (26, 2e-4),
(30, 5e-5)]) (30, 5e-5)])
], ],
session_config=get_default_sess_config(0.99),
model=Model(), model=Model(),
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=80, max_epoch=80,
......
...@@ -277,7 +277,6 @@ def get_config(): ...@@ -277,7 +277,6 @@ def get_config():
(41, 8e-5), (48, 1e-5), (53, 2e-6)]), (41, 8e-5), (48, 1e-5), (53, 2e-6)]),
HumanHyperParamSetter('learning_rate') HumanHyperParamSetter('learning_rate')
], ],
session_config=get_default_sess_config(0.9),
model=Model(), model=Model(),
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=100, max_epoch=100,
......
...@@ -162,7 +162,8 @@ def get_config(): ...@@ -162,7 +162,8 @@ def get_config():
[ScalarStats('cost'), ClassificationError()]), [ScalarStats('cost'), ClassificationError()]),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
], ],
session_config=get_default_sess_config(0.5), session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
max_epoch=500, max_epoch=500,
) )
......
...@@ -114,8 +114,6 @@ def get_config(cifar_classnum): ...@@ -114,8 +114,6 @@ def get_config(cifar_classnum):
dataset_train = get_data('train', cifar_classnum) dataset_train = get_data('train', cifar_classnum)
dataset_test = get_data('test', cifar_classnum) dataset_test = get_data('test', cifar_classnum)
sess_config = get_default_sess_config(0.5)
def lr_func(lr): def lr_func(lr):
if lr < 3e-5: if lr < 3e-5:
raise StopTraining() raise StopTraining()
...@@ -129,7 +127,6 @@ def get_config(cifar_classnum): ...@@ -129,7 +127,6 @@ def get_config(cifar_classnum):
StatMonitorParamSetter('learning_rate', 'val_error', lr_func, StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
threshold=0.001, last_k=10), threshold=0.001, last_k=10),
], ],
session_config=sess_config,
max_epoch=150, max_epoch=150,
) )
......
...@@ -20,7 +20,6 @@ from ..utils.develop import deprecated, log_deprecated ...@@ -20,7 +20,6 @@ from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer'] __all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
...@@ -117,11 +116,11 @@ class Trainer(object): ...@@ -117,11 +116,11 @@ class Trainer(object):
self._callbacks.setup_graph(weakref.proxy(self)) self._callbacks.setup_graph(weakref.proxy(self))
# create session # create session
sess_creator = NewSessionCreator(config=self.config.session_config) sess_creator = self.config.session_creator
logger.info("Finalize the graph, create the session ...") logger.info("Finalize the graph, create the session ...")
self.monitored_sess = tf.train.MonitoredSession( self._monitored_sess = tf.train.MonitoredSession(
session_creator=sess_creator, hooks=None) session_creator=sess_creator, hooks=None)
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also self.sess = self._monitored_sess._tf_sess() # expose the underlying session also
# init session # init session
init_op = tf.global_variables_initializer() init_op = tf.global_variables_initializer()
...@@ -159,7 +158,7 @@ class Trainer(object): ...@@ -159,7 +158,7 @@ class Trainer(object):
logger.info("Start Epoch {} ...".format(self.epoch_num)) logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time() start_time = time.time()
for self.local_step in range(self.config.steps_per_epoch): for self.local_step in range(self.config.steps_per_epoch):
if self.monitored_sess.should_stop(): if self._monitored_sess.should_stop():
return return
self.run_step() # implemented by subclass self.run_step() # implemented by subclass
self._callbacks.trigger_step() self._callbacks.trigger_step()
...@@ -179,7 +178,7 @@ class Trainer(object): ...@@ -179,7 +178,7 @@ class Trainer(object):
finally: finally:
self._callbacks.after_train() self._callbacks.after_train()
self.monitors.close() self.monitors.close()
self.monitored_sess.close() self._monitored_sess.close()
# Predictor related methods: TODO # Predictor related methods: TODO
def get_predictor(self, input_names, output_names, tower=0): def get_predictor(self, input_names, output_names, tower=0):
......
...@@ -13,6 +13,7 @@ from ..utils import logger ...@@ -13,6 +13,7 @@ from ..utils import logger
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession, from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit) get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.optimizer import apply_grad_processors from ..tfutils.optimizer import apply_grad_processors
from .input_data import InputData from .input_data import InputData
from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
...@@ -30,7 +31,7 @@ class TrainConfig(object): ...@@ -30,7 +31,7 @@ class TrainConfig(object):
model=None, model=None,
callbacks=None, extra_callbacks=None, callbacks=None, extra_callbacks=None,
monitors=None, monitors=None,
session_config=get_default_sess_config(), session_init=None, session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999, starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0], nr_tower=1, tower=None, predict_tower=[0],
**kwargs): **kwargs):
...@@ -47,8 +48,12 @@ class TrainConfig(object): ...@@ -47,8 +48,12 @@ class TrainConfig(object):
callbacks that will be used in the end are ``callbacks + extra_callbacks``. callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`. monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``. Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
session_config (tf.ConfigProto): the config used to instantiate the session. session_creator (tf.train.SessionCreator): how to create the
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session. session. Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by
:func:`tfutils.get_default_sess_config()`.
session_init (SessionInit): how to initialize variables of a
session. Defaults to doing nothing.
starting_epoch (int): The index of the first epoch. starting_epoch (int): The index of the first epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch. steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size. Defaults to the input data size.
...@@ -99,13 +104,22 @@ class TrainConfig(object): ...@@ -99,13 +104,22 @@ class TrainConfig(object):
self.model = model self.model = model
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDesc)
self.session_config = session_config
assert_type(self.session_config, tf.ConfigProto)
if session_init is None: if session_init is None:
session_init = JustCurrentSession() session_init = JustCurrentSession()
self.session_init = session_init self.session_init = session_init
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
if session_creator is None:
if session_config is not None:
log_deprecated(
"TrainConfig(session_config=)",
"Use session_creator=NewSessionCreator(config=) instead!", "2017-05-20")
self.session_creator = NewSessionCreator(config=session_config)
else:
self.session_creator = NewSessionCreator(config=get_default_sess_config())
else:
self.session_creator = session_creator
if steps_per_epoch is None: if steps_per_epoch is None:
steps_per_epoch = kwargs.pop('step_per_epoch', None) steps_per_epoch = kwargs.pop('step_per_epoch', None)
if steps_per_epoch is not None: if steps_per_epoch is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment