Commit ccae46f4 authored by Yuxin Wu's avatar Yuxin Wu

add session_creator option in TrainConfig (#191)

parent 09cc8662
......@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
+ 2017/03/16. The `session_config` option in `TrainConfig` and `PredictConfig` is deprecated.
Use `session_creator` instead to define how the session is created.
+ 2017/02/20. The interface of step callbacks has been changed to be the same as `tf.train.SessionRunHook`.
If you haven't written any custom step callbacks, there is nothing to do. Otherwise please refer
to the [existing callbacks](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/callbacks/steps.py).
......
......@@ -222,7 +222,8 @@ def get_config():
StartProcOrThread(master),
PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
],
session_config=get_default_sess_config(0.5),
session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
model=M,
steps_per_epoch=STEPS_PER_EPOCH,
max_epoch=1000,
......
......@@ -108,7 +108,6 @@ def get_config():
return TrainConfig(
dataflow=dataset,
callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
model=Model(),
steps_per_epoch=500,
max_epoch=100,
......
......@@ -107,7 +107,6 @@ def get_config():
model=Model(),
dataflow=get_data(args.data),
callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
steps_per_epoch=300,
max_epoch=200,
)
......
......@@ -167,7 +167,6 @@ def get_config():
return TrainConfig(
dataflow=get_data(),
callbacks=[ModelSaver(keep_freq=0.1)],
session_config=get_default_sess_config(0.5),
model=Model(),
steps_per_epoch=500,
max_epoch=100,
......
......@@ -61,7 +61,6 @@ def get_config():
# use the same data in the DCGAN example
dataflow=DCGAN.get_data(args.data),
callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
steps_per_epoch=300,
max_epoch=200,
)
......
......@@ -171,7 +171,6 @@ def get_config():
(19, 3e-3), (24, 1e-3), (26, 2e-4),
(30, 5e-5)])
],
session_config=get_default_sess_config(0.99),
model=Model(),
steps_per_epoch=5000,
max_epoch=80,
......
......@@ -277,7 +277,6 @@ def get_config():
(41, 8e-5), (48, 1e-5), (53, 2e-6)]),
HumanHyperParamSetter('learning_rate')
],
session_config=get_default_sess_config(0.9),
model=Model(),
steps_per_epoch=5000,
max_epoch=100,
......
......@@ -162,7 +162,8 @@ def get_config():
[ScalarStats('cost'), ClassificationError()]),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
],
session_config=get_default_sess_config(0.5),
session_creator=sesscreate.NewSessionCreator(
config=get_default_sess_config(0.5)),
steps_per_epoch=steps_per_epoch,
max_epoch=500,
)
......
......@@ -114,8 +114,6 @@ def get_config(cifar_classnum):
dataset_train = get_data('train', cifar_classnum)
dataset_test = get_data('test', cifar_classnum)
sess_config = get_default_sess_config(0.5)
def lr_func(lr):
if lr < 3e-5:
raise StopTraining()
......@@ -129,7 +127,6 @@ def get_config(cifar_classnum):
StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
threshold=0.001, last_k=10),
],
session_config=sess_config,
max_epoch=150,
)
......
......@@ -20,7 +20,6 @@ from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
......@@ -117,11 +116,11 @@ class Trainer(object):
self._callbacks.setup_graph(weakref.proxy(self))
# create session
sess_creator = NewSessionCreator(config=self.config.session_config)
sess_creator = self.config.session_creator
logger.info("Finalize the graph, create the session ...")
self.monitored_sess = tf.train.MonitoredSession(
self._monitored_sess = tf.train.MonitoredSession(
session_creator=sess_creator, hooks=None)
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
self.sess = self._monitored_sess._tf_sess() # expose the underlying session also
# init session
init_op = tf.global_variables_initializer()
......@@ -159,7 +158,7 @@ class Trainer(object):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
for self.local_step in range(self.config.steps_per_epoch):
if self.monitored_sess.should_stop():
if self._monitored_sess.should_stop():
return
self.run_step() # implemented by subclass
self._callbacks.trigger_step()
......@@ -179,7 +178,7 @@ class Trainer(object):
finally:
self._callbacks.after_train()
self.monitors.close()
self.monitored_sess.close()
self._monitored_sess.close()
# Predictor related methods: TODO
def get_predictor(self, input_names, output_names, tower=0):
......
......@@ -13,6 +13,7 @@ from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession,
get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.optimizer import apply_grad_processors
from .input_data import InputData
from .monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
......@@ -30,7 +31,7 @@ class TrainConfig(object):
model=None,
callbacks=None, extra_callbacks=None,
monitors=None,
session_config=get_default_sess_config(), session_init=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=[0],
**kwargs):
......@@ -47,8 +48,12 @@ class TrainConfig(object):
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
session_creator (tf.train.SessionCreator): how to create the
session. Defaults to :class:`sesscreate.NewSessionCreator()`
with the config returned by
:func:`tfutils.get_default_sess_config()`.
session_init (SessionInit): how to initialize variables of a
session. Defaults to do nothing.
starting_epoch (int): The index of the first epoch.
steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
Defaults to the input data size.
......@@ -99,13 +104,22 @@ class TrainConfig(object):
self.model = model
assert_type(self.model, ModelDesc)
self.session_config = session_config
assert_type(self.session_config, tf.ConfigProto)
if session_init is None:
session_init = JustCurrentSession()
self.session_init = session_init
assert_type(self.session_init, SessionInit)
if session_creator is None:
if session_config is not None:
log_deprecated(
"TrainConfig(session_config=)",
"Use session_creator=NewSessionCreator(config=) instead!", "2017-05-20")
self.session_creator = NewSessionCreator(config=session_config)
else:
self.session_creator = NewSessionCreator(config=get_default_sess_config())
else:
self.session_creator = session_creator
if steps_per_epoch is None:
steps_per_epoch = kwargs.pop('step_per_epoch', None)
if steps_per_epoch is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment