Commit bbaf8d12 authored by Yuxin Wu

use Variable.load to avoid assign ops

parent d97c0081
# Efficient DataFlow
This tutorial gives an overview of how to build an efficient DataFlow, using ImageNet
dataset as an example.
Our end goal is a generator that yields ImageNet datapoints (after proper preprocessing) as fast as possible.
Since a DataFlow is simply a generator interface, you can use it in other frameworks (e.g. Keras)
or in your own code as well.
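The "generator interface" claim is easy to see in plain Python. Below is an illustrative sketch (not tensorpack's actual adapter; `FakeDataFlow` is a made-up stand-in) showing how anything with tensorpack's `reset_state()`/`get_data()` contract can be consumed as an ordinary endless generator:

```python
import itertools

def as_generator(df):
    """Adapt a tensorpack-style DataFlow into an endless Python generator."""
    df.reset_state()                      # tensorpack requires this before iterating
    while True:                           # restart every epoch, forever
        for datapoint in df.get_data():
            yield datapoint

class FakeDataFlow:
    """Stand-in exposing the two methods the adapter relies on."""
    def reset_state(self):
        pass
    def get_data(self):
        yield from [[0, 'a'], [1, 'b']]   # datapoints are lists of components

gen = as_generator(FakeDataFlow())
print(list(itertools.islice(gen, 5)))     # → [[0, 'a'], [1, 'b'], [0, 'a'], [1, 'b'], [0, 'a']]
```

Any training loop that accepts a Python iterator can then pull batches from `gen` directly.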
We use the ILSVRC12 training set, which contains 1.28 million images.
Following the [ResNet example](../examples/ResNet), our pre-processing needs images in their original resolution,
@@ -120,7 +121,7 @@ It will generate a database file of 140G. We build a DataFlow to read the LMDB f
```
from tensorpack import *
ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
-ds = BatchData(ds, 256, allow_list=True)
+ds = BatchData(ds, 256, use_list=True)
TestDataSpeed(ds).start_test()
```
Depending on whether the OS has cached the file for you (and how large the RAM is), the above script
@@ -134,7 +135,7 @@ As a reference, on Samsung SSD 850, the uncached speed is about 16it/s.
ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
ds = LocallyShuffleData(ds, 50000)
-ds = BatchData(ds, 256, allow_list=True)
+ds = BatchData(ds, 256, use_list=True)
```
Instead of shuffling all the training data in every epoch (which would require random reads),
the added line above maintains a buffer of datapoints and shuffles them once in a while.
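The buffer idea can be sketched in a few lines of plain Python. This is a simplified illustration of the technique, not tensorpack's actual `LocallyShuffleData` implementation: keep a fixed-size buffer, and for every incoming datapoint hand out a random buffered element in exchange.

```python
import random

def locally_shuffle(iterable, buffer_size, rng=None):
    """Approximate shuffling of a sequential stream using O(buffer_size) memory."""
    rng = rng or random.Random()
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= buffer_size:
            idx = rng.randrange(len(buf))        # pick a random buffered element
            buf[idx], buf[-1] = buf[-1], buf[idx]
            yield buf.pop()                      # emit it, keep the newcomer buffered
    rng.shuffle(buf)                             # drain the leftover buffer at the end
    yield from buf
```

A larger buffer gives shuffling closer to a true global shuffle, at the cost of more memory; every element still comes out exactly once.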
@@ -153,7 +154,7 @@ Then we add necessary transformations:
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = BatchData(ds, 256)
```
-1. `LMDBData` deserialized the datapoints (from string to [jpeg_string, label])
+1. `LMDBData` deserialize the datapoints (from string to [jpeg_string, label])
2. Use opencv to decode the first component into ndarray
3. Apply augmentations to the ndarray
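These stages compose naturally as generator transformations. The sketch below mirrors the pipeline shape with trivial stand-ins (string decoding and upper-casing instead of jpeg decoding and real augmentation); `map_component` and `batch` are simplified illustrations of what `MapDataComponent` and `BatchData` do, not tensorpack code:

```python
def map_component(stream, func, index=0):
    """Apply func to one component of every datapoint."""
    for dp in stream:
        dp = list(dp)
        dp[index] = func(dp[index])
        yield dp

def batch(stream, size):
    """Group consecutive datapoints into batches."""
    buf = []
    for dp in stream:
        buf.append(dp)
        if len(buf) == size:
            yield buf
            buf = []

# stand-in "jpeg decoding" and "augmentation" over [image, label] datapoints
raw = ([b'img%d' % i, i] for i in range(4))
decoded = map_component(raw, lambda s: s.decode())   # step 2: decode first component
augmented = map_component(decoded, str.upper)        # step 3: augment it
print(list(batch(augmented, 2)))
# → [[['IMG0', 0], ['IMG1', 1]], [['IMG2', 2], ['IMG3', 3]]]
```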
@@ -172,7 +173,7 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
ds = BatchData(ds, 256)
```
-Since we are reading the database sequentially, have multiple identical instances of the
+Since we are reading the database sequentially, having multiple identical instances of the
underlying DataFlow would result in a biased data distribution. Therefore we use `PrefetchData` to
launch the underlying DataFlow in one independent process, and only parallelize the transformations.
(`PrefetchDataZMQ` is faster but not fork-safe, so the first prefetch has to be `PrefetchData`. This is [issue#138])
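The "one sequential reader, parallel transformations" layout can be sketched with threads in plain Python. tensorpack actually uses separate processes (`PrefetchData`/`PrefetchDataZMQ`); this thread-based version only illustrates the structure, with a toy transform standing in for imdecode plus augmentation:

```python
from concurrent.futures import ThreadPoolExecutor

def slow_transform(x):
    return x * x          # stand-in for imdecode + augmentation

def parallel_map(stream, func, workers=4):
    """Keep a single sequential reader; fan the expensive transform out to workers.
    Executor.map preserves input order, so no datapoint is duplicated or dropped."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        yield from pool.map(func, stream)

print(list(parallel_map(range(8), slow_transform)))  # → [0, 1, 4, 9, 16, 25, 36, 49]
```

Because only the transformation is parallelized, the single reader keeps the data distribution unbiased, which is exactly the concern the paragraph above raises about running multiple copies of the reader.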
......
@@ -20,9 +20,9 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
"""
DCGAN on CelebA dataset.
The original code (dcgan.torch) uses kernel_shape=4, but I found the difference not significant.
-1. Download the 'aligned&cropped' version of CelebA dataset.
+1. Download the 'aligned&cropped' version of CelebA dataset
(or just use any directory of jpg files).
2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/
3. Visualize samples of a trained model:
......
@@ -132,7 +132,6 @@ def get_data(train_or_test):
crop 8%~100% of the original image
See `Going Deeper with Convolutions` by Google.
"""
def _augment(self, img, _):
h, w = img.shape[:2]
area = h * w
......
@@ -74,13 +74,9 @@ class GraphVarParam(HyperParam):
else:
raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))
-self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
-name=self._readable_name + '_feed')
-self.assign_op = self.var.assign(self.val_holder)
def set_value(self, v):
""" Assign the variable a new value. """
-self.assign_op.eval(feed_dict={self.val_holder: v})
+self.var.load(v)
def get_value(self):
""" Evaluate the variable. """
......
@@ -146,7 +146,8 @@ class ILSVRC12(RNGDataFlow):
mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
"""
-assert name in ['train', 'test', 'val']
+assert name in ['train', 'test', 'val'], name
assert os.path.isdir(dir), dir
self.full_dir = os.path.join(dir, name)
self.name = name
assert os.path.isdir(self.full_dir), self.full_dir
......
@@ -54,14 +54,10 @@ class SessionUpdate(object):
vars_to_update: a collection of variables to update
"""
self.sess = sess
-self.assign_ops = defaultdict(list)
+self.name_map = defaultdict(list)
for v in vars_to_update:
-# p = tf.placeholder(v.dtype, shape=v.get_shape())
-with tf.device('/cpu:0'):
-p = tf.placeholder(v.dtype)
-savename = get_savename_from_varname(v.name)
-# multiple vars might share one savename
-self.assign_ops[savename].append((p, v, v.assign(p)))
+savename = get_savename_from_varname(v.name)
+self.name_map[savename].append(v)
def update(self, prms):
"""
@@ -70,8 +66,8 @@ class SessionUpdate(object):
Any name in prms must be in the graph and in vars_to_update.
"""
for name, value in six.iteritems(prms):
-assert name in self.assign_ops
-for p, v, op in self.assign_ops[name]:
+assert name in self.name_map
+for v in self.name_map[name]:
varshape = tuple(v.get_shape().as_list())
if varshape != value.shape:
# TODO only allow reshape when shape different by empty axis
@@ -79,7 +75,7 @@ class SessionUpdate(object):
"{}: {}!={}".format(name, varshape, value.shape)
logger.warn("Param {} is reshaped during assigning".format(name))
value = value.reshape(varshape)
-self.sess.run(op, feed_dict={p: value})
+v.load(value, session=self.sess)
def dump_session_params(path):
......
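The shape of the `SessionUpdate` change above can be sketched without TensorFlow: instead of pre-building a placeholder and assign op per variable, the new code only keeps a savename-to-variables map and asks each variable to load the value itself. In the sketch below, `Var` and `get_savename` are made-up stand-ins for `tf.Variable` and `get_savename_from_varname`:

```python
from collections import defaultdict

class Var:
    """Minimal stand-in for a tf.Variable with a .load() method."""
    def __init__(self, name, value=None):
        self.name, self.value = name, value
    def load(self, value, session=None):   # mirrors tf.Variable.load's call shape
        self.value = value

def get_savename(varname):
    # stand-in: strip a "towerN/" replica prefix so replicas share one savename
    return varname.split('/', 1)[1] if varname.startswith('tower') else varname

class SessionUpdateSketch:
    def __init__(self, vars_to_update):
        self.name_map = defaultdict(list)   # multiple vars may share one savename
        for v in vars_to_update:
            self.name_map[get_savename(v.name)].append(v)

    def update(self, prms):
        for name, value in prms.items():
            assert name in self.name_map, name
            for v in self.name_map[name]:
                v.load(value)               # no cached assign op, no feed_dict

a, b = Var('tower0/W'), Var('tower1/W')
SessionUpdateSketch([a, b]).update({'W': 3})
print(a.value, b.value)  # → 3 3
```

This is the point of the commit: `Variable.load` pushes the value in directly, so the graph no longer accumulates one placeholder and one assign op per updated variable.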