Commit 09cc8662 authored by Yuxin Wu

stop using scaffold, so that custom session creation is possible (#191)

parent 13e0ec39
@@ -39,7 +39,7 @@ Describe your training task with three components:
 + Allows you to process data in Python without blocking the training, by multiprocess prefetch & TF Queue prefetch.
 + All data producer has a unified interface, you can compose and reuse them to perform complex preprocessing.
-2. __Callbacks__, customizable, like `tf.train.SessionRunHook` but more than that. Includes everything you want to do apart from the training iterations, such as:
+2. __Callbacks__, like `tf.train.SessionRunHook`, plugins, or extensions. Write a callback to implement everything you want to do apart from the training iterations, such as:
 + Change hyperparameters during training
 + Print some tensors of interest
 + Run inference on a test dataset
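Such a callback can be a few lines of code. Below is a hedged sketch, not an example from this repo: it assumes the `Callback` base class and its `_trigger_epoch` hook from this era's tensorpack API, and reuses `get_global_step_value` (imported in the trainer diff further down) purely for illustration.

````python
from tensorpack.callbacks import Callback
from tensorpack.tfutils import get_global_step_value

class PrintGlobalStep(Callback):
    # Illustrative callback: runs between epochs,
    # i.e. apart from the training iterations.
    def _trigger_epoch(self):
        print("global step:", get_global_step_value())
````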
@@ -51,7 +51,6 @@ Describe your training task with three components:
 `LinearWrap` and `argscope` simplify large models (e.g. [vgg example](https://github.com/ppwwyyxx/tensorpack/blob/master/examples/load-vgg16.py)).
 With the above components defined, tensorpack trainer runs the training iterations for you.
-Trainer was written with performance in mind:
 Even on a small CNN example, the training runs [2x faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6) than the equivalent Keras code.
 Multi-GPU training is off-the-shelf by simply switching the trainer.
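As a concrete illustration of "switching the trainer", here is a minimal sketch assuming the tensorpack API of this era; `MyModel` and `my_dataflow` are placeholders, not names from this repo.

````python
from tensorpack import TrainConfig, SimpleTrainer, SyncMultiGPUTrainer

config = TrainConfig(
    model=MyModel(),       # placeholder: your ModelDesc subclass
    dataflow=my_dataflow,  # placeholder: any Dataflow
    max_epoch=100)

SimpleTrainer(config).train()
# Multi-GPU training: same config, different trainer:
# SyncMultiGPUTrainer(config).train()
````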
 # Dataflow
-Dataflow is an interface to produce data.
+Dataflow is a library to help you build Python iterators to load data.
 A Dataflow has a `get_data()` generator method,
-which yields a `datapoint` when called.
-A datapoint must be a **list** of Python objects which I called the `components` of this datapoint.
-For example, to train on MNIST dataset, you can build a Dataflow
+which yields `datapoints`.
+A datapoint must be a **list** of Python objects which I called the `components` of a datapoint.
+For example, to train on MNIST dataset, you can build a Dataflow with a `get_data()` method
 that yields datapoints of two elements (components):
 a numpy array of shape (64, 28, 28), and an array of shape (64,).
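A minimal sketch of such a Dataflow, assuming the `DataFlow` base class from `tensorpack.dataflow`; the class name and the random data are illustrative, and `size()` is this era's optional method reporting the number of datapoints.

````python
import numpy as np
from tensorpack.dataflow import DataFlow

class RandomMNISTLike(DataFlow):
    # Illustrative: yields datapoints with the two components described above.
    def __init__(self, num_batches=100):
        self.num_batches = num_batches

    def get_data(self):
        for _ in range(self.num_batches):
            images = np.random.rand(64, 28, 28).astype('float32')
            labels = np.random.randint(0, 10, size=(64,))
            yield [images, labels]  # a datapoint: a list of two components

    def size(self):
        return self.num_batches
````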
@@ -17,7 +17,7 @@ the greatest code reusablility.
 There are a lot of existing modules in tensorpack which you can use to compose
 complex Dataflow instances with a long pre-processing pipeline. A whole pipeline usually
 would __read from disk (or other sources), apply augmentations, group into batches,
-prefetch data__, etc. An example is as the following:
+prefetch data__, etc. A simple example is as the following:
 ````python
 # define a Dataflow which produces image-label pairs from a caffe lmdb database
@@ -26,7 +26,7 @@ df = CaffeLMDB('/path/to/caffe/lmdb', shuffle=False)
 df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
 # group data into batches of size 128
 df = BatchData(df, 128)
-# start 3 processes to run the dataflow in parallel, and transfer data with ZeroMQ
+# start 3 processes to run the dataflow in parallel, and communicate with ZeroMQ
 df = PrefetchDataZMQ(df, 3)
 ````
 A more complicated example is the [ResNet training script](../examples/ResNet/imagenet-resnet.py)
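To iterate such a pipeline by hand (e.g. for debugging), the usual tensorpack idiom is roughly the following sketch; `reset_state()` initializes per-process state such as RNGs and prefetch workers before iteration.

````python
df.reset_state()  # initialize state (e.g. RNG, worker processes) before iterating
for dp in df.get_data():
    images, labels = dp  # each datapoint is a list of components
    break                # inspect just the first batch
````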
@@ -20,6 +20,7 @@ from ..utils.develop import deprecated, log_deprecated
 from ..callbacks import Callback, Callbacks, MaintainStepCounter
 from ..tfutils import get_global_step_value
 from ..tfutils.modelutils import describe_model
+from ..tfutils.sesscreate import NewSessionCreator

 __all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
@@ -115,22 +116,20 @@ class Trainer(object):
         self._callbacks = Callbacks(self._callbacks)
         self._callbacks.setup_graph(weakref.proxy(self))

-        self.config.session_init._setup_graph()
-
-        def after_init(scaffold, sess):
-            logger.info("Graph variables initialized.")
-            self.config.session_init._run_init(sess)
-        scaffold = tf.train.Scaffold(
-            init_op=tf.global_variables_initializer(),
-            init_fn=after_init)
+        # create session
+        sess_creator = NewSessionCreator(config=self.config.session_config)
         logger.info("Finalize the graph, create the session ...")
         self.monitored_sess = tf.train.MonitoredSession(
-            session_creator=tf.train.ChiefSessionCreator(
-                scaffold=scaffold, config=self.config.session_config),
-            hooks=None)
+            session_creator=sess_creator, hooks=None)
         self.sess = self.monitored_sess._tf_sess()  # expose the underlying session also

+        # init session
+        init_op = tf.global_variables_initializer()
+        self.sess.run(init_op)
+        logger.info("Graph variables initialized.")
+        self.config.session_init.init(self.sess)
+        self.sess.graph.finalize()
+
         hooks = self._callbacks.get_hooks()
         self.hooked_sess = HookedSession(self.sess, hooks)
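The point of this change (per the commit message) is that `session_creator` no longer has to be a Scaffold-backed `ChiefSessionCreator`: any object implementing TF 1.x's `tf.train.SessionCreator` interface can be supplied. A hedged sketch of what a custom creator might look like, not code from this repo:

````python
import tensorflow as tf

class CustomSessionCreator(tf.train.SessionCreator):
    # Illustrative: e.g. connect to a remote target instead of a local session,
    # something the previous Scaffold-based path did not allow.
    def __init__(self, target='', config=None):
        self.target = target
        self.config = config

    def create_session(self):
        return tf.Session(target=self.target, config=self.config)

# usage sketch:
# sess = tf.train.MonitoredSession(session_creator=CustomSessionCreator(), hooks=None)
````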