Commit 09cc8662 authored by Yuxin Wu

stop using scaffold, so that custom session creation is possible (#191)

parent 13e0ec39
......@@ -39,7 +39,7 @@ Describe your training task with three components:
+ Allows you to process data in Python without blocking the training, by multiprocess prefetch & TF Queue prefetch.
+ All data producers share a unified interface, so you can compose and reuse them to perform complex preprocessing.
2. __Callbacks__, customizable, like `tf.train.SessionRunHook` but more than that. Includes everything you want to do apart from the training iterations, such as:
2. __Callbacks__, like `tf.train.SessionRunHook`, plugins, or extensions. Write a callback to implement everything you want to do apart from the training iterations (a sketch follows this list), such as:
+ Change hyperparameters during training
+ Print some tensors of interest
+ Run inference on a test dataset
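A rough sketch of such a callback, written against the tensorpack `Callback` interface of roughly this era (the class, the tensor name, and the hook points used here are illustrative and may differ between versions):
````python
import tensorflow as tf
from tensorpack.callbacks import Callback

class PrintLearningRate(Callback):
    """Illustrative only: print a tensor of interest once per epoch."""
    def _setup_graph(self):
        # the graph is fully built here; 'learning_rate:0' is an assumed tensor name
        self._lr = tf.get_default_graph().get_tensor_by_name('learning_rate:0')

    def _trigger_epoch(self):
        # runs between epochs, i.e. apart from the training iterations
        print('learning_rate =', self.trainer.sess.run(self._lr))
````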
......@@ -51,7 +51,6 @@ Describe your training task with three components:
`LinearWrap` and `argscope` simplify large models (e.g. [vgg example](https://github.com/ppwwyyxx/tensorpack/blob/master/examples/load-vgg16.py)).
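A minimal sketch of these two helpers in use, modeled on the tensorpack examples of this era (argument names such as `kernel_shape`, `out_channel` and `nl` are assumptions and may differ between versions):
````python
# sketch of a model body: argscope sets layer defaults, LinearWrap chains layers fluently
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu):
    logits = (LinearWrap(image)
              .Conv2D('conv0', out_channel=32)
              .MaxPooling('pool0', 2)
              .Conv2D('conv1', out_channel=32)
              .FullyConnected('fc0', 512, nl=tf.nn.relu)
              .FullyConnected('fc1', out_dim=10, nl=tf.identity)())  # trailing () returns the tensor
````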
With the above components defined, the tensorpack trainer runs the training iterations for you.
The trainer was written with performance in mind:
Even on a small CNN example, the training runs [2x faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6) than the equivalent Keras code.
Multi-GPU training is off-the-shelf by simply switching the trainer.
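A hedged sketch of what "simply switching the trainer" means (the trainer and config names follow what tensorpack exposed around this commit; the model and callback values are placeholders):
````python
config = TrainConfig(
    model=MyModel(),      # a ModelDesc defining the graph (placeholder name)
    dataflow=df,          # the Dataflow feeding the training
    callbacks=[...],      # callbacks as described above
    max_epoch=100)
SimpleTrainer(config).train()            # single-GPU training
# SyncMultiGPUTrainer(config).train()    # multi-GPU: only the trainer class changes
````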
......
# Dataflow
Dataflow is an interface to produce data.
Dataflow is a library to help you build Python iterators to load data.
A Dataflow has a `get_data()` generator method,
which yields a `datapoint` when called.
A datapoint must be a **list** of Python objects which I call the `components` of this datapoint.
which yields `datapoints`.
A datapoint must be a **list** of Python objects which I call the `components` of a datapoint.
For example, to train on MNIST dataset, you can build a Dataflow
For example, to train on the MNIST dataset, you can build a Dataflow with a `get_data()` method
that yields datapoints of two elements (components):
a numpy array of shape (64, 28, 28), and an array of shape (64,).
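For instance, a minimal sketch of such a Dataflow, using random arrays in place of the real MNIST data (the class name and the fixed number of batches are illustrative):
````python
import numpy as np
from tensorpack.dataflow import DataFlow

class RandomMNISTBatches(DataFlow):
    """Illustrative only: yield datapoints of [images, labels] with the shapes above."""
    def get_data(self):
        for _ in range(100):
            images = np.random.rand(64, 28, 28).astype('float32')
            labels = np.random.randint(0, 10, size=(64,)).astype('int32')
            yield [images, labels]   # a datapoint: a list of two components

    def size(self):
        return 100
````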
......@@ -17,7 +17,7 @@ the greatest code reusability.
There are a lot of existing modules in tensorpack which you can use to compose
complex Dataflow instances with a long pre-processing pipeline. A whole pipeline usually
would __read from disk (or other sources), apply augmentations, group into batches,
prefetch data__, etc. An example is as the following:
prefetch data__, etc. A simple example is as follows:
````python
# define a Dataflow which produces image-label pairs from a caffe lmdb database
......@@ -26,7 +26,7 @@ df = CaffeLMDB('/path/to/caffe/lmdb', shuffle=False)
df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128
df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel, and transfer data with ZeroMQ
# start 3 processes to run the dataflow in parallel, and communicate with ZeroMQ
df = PrefetchDataZMQ(df, 3)
````
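To sanity-check such a pipeline from plain Python you can simply iterate it; a sketch (calling `reset_state()` once before iterating is generally required, in particular for the prefetch dataflows):
````python
df.reset_state()                       # initialize RNGs / start the prefetch processes
for datapoint in df.get_data():
    images, labels = datapoint         # each datapoint is a list of components
    print(images.shape, labels.shape)  # e.g. (128, 225, 225, 3) (128,)
    break
````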
A more complicated example is the [ResNet training script](../examples/ResNet/imagenet-resnet.py)
......
......@@ -20,6 +20,7 @@ from ..utils.develop import deprecated, log_deprecated
from ..callbacks import Callback, Callbacks, MaintainStepCounter
from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.sesscreate import NewSessionCreator
__all__ = ['Trainer', 'StopTraining', 'MultiPredictorTowerTrainer']
......@@ -115,22 +116,20 @@ class Trainer(object):
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
self.config.session_init._setup_graph()
def after_init(scaffold, sess):
logger.info("Graph variables initialized.")
self.config.session_init._run_init(sess)
scaffold = tf.train.Scaffold(
init_op=tf.global_variables_initializer(),
init_fn=after_init)
# create session
sess_creator = NewSessionCreator(config=self.config.session_config)
logger.info("Finalize the graph, create the session ...")
self.monitored_sess = tf.train.MonitoredSession(
session_creator=tf.train.ChiefSessionCreator(
scaffold=scaffold, config=self.config.session_config),
hooks=None)
session_creator=sess_creator, hooks=None)
self.sess = self.monitored_sess._tf_sess() # expose the underlying session also
# init session
init_op = tf.global_variables_initializer()
self.sess.run(init_op)
logger.info("Graph variables initialized.")
self.config.session_init.init(self.sess)
self.sess.graph.finalize()
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
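The point of dropping the scaffold (per the commit message) is that session creation is now isolated behind a `session_creator` object. `tf.train.MonitoredSession` accepts any creator exposing a `create_session()` method, so a custom one could be sketched roughly as below; the debug-wrapper class is purely illustrative and not part of this commit:
````python
import tensorflow as tf
from tensorflow.python import debug as tf_debug

class DebugSessionCreator(tf.train.SessionCreator):
    """Illustrative only: create a tfdbg-wrapped session instead of a plain tf.Session."""
    def __init__(self, config=None):
        self._config = config

    def create_session(self):
        sess = tf.Session(config=self._config)
        return tf_debug.LocalCLIDebugWrapperSession(sess)

# it could then be used where NewSessionCreator is constructed above, e.g.:
# sess_creator = DebugSessionCreator(config=self.config.session_config)
````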
......