Commit 4cde005e authored by Yuxin Wu's avatar Yuxin Wu

update docs & small changes

parent b5ac2443
......@@ -87,6 +87,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
fetch()
# waiting is necessary, otherwise the estimated mean score is biased
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
......
......@@ -26,7 +26,7 @@ def maybe_freeze_affine(getter, *args, **kwargs):
def resnet_argscope():
with argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'), \
argscope(Conv2D, use_bias=False), \
argscope(BatchNorm, training=False, epsilon=0), \
argscope(BatchNorm, training=False), \
custom_getter_scope(maybe_freeze_affine):
yield
......
......@@ -150,8 +150,10 @@ class ImageNetModel(ModelDesc):
"""
weight_decay_on_bn = False
def __init__(self, data_format='NCHW'):
self.data_format = data_format
"""
Either 'NCHW' or 'NHWC'
"""
data_format = 'NCHW'
def inputs(self):
return [tf.placeholder(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'),
......
......@@ -25,9 +25,7 @@ from resnet_model import (
class Model(ImageNetModel):
def __init__(self, depth, data_format='NCHW', mode='resnet'):
super(Model, self).__init__(data_format)
def __init__(self, depth, mode='resnet'):
if mode == 'se':
assert depth >= 50
......@@ -64,17 +62,17 @@ def get_config(model, fake=False):
assert args.batch % nr_tower == 0
batch = args.batch // nr_tower
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
if fake:
logger.info("For benchmark, batch size is fixed to 64 per tower.")
dataset_train = FakeData(
[[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
[[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8')
callbacks = []
else:
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
dataset_train = get_data('train', batch)
dataset_val = get_data('val', batch)
BASE_LR = 0.1 * (args.batch / 256.0)
START_LR = 0.1
BASE_LR = START_LR * (args.batch / 256.0)
callbacks = [
ModelSaver(),
EstimatedTimeLeft(),
......@@ -82,10 +80,10 @@ def get_config(model, fake=False):
'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
(90, BASE_LR * 1e-3), (100, BASE_LR * 1e-4)]),
]
if BASE_LR > 0.1:
if BASE_LR > START_LR:
callbacks.append(
ScheduledHyperParamSetter(
'learning_rate', [(0, 0.1), (3, BASE_LR)], interp='linear'))
'learning_rate', [(0, START_LR), (5, BASE_LR)], interp='linear'))
infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]
......@@ -126,7 +124,8 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
model = Model(args.depth, args.data_format, args.mode)
model = Model(args.depth, args.mode)
model.data_format = args.data_format
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
......
......@@ -39,7 +39,7 @@ Usage:
./CAM-resnet.py --data /path/to/imagenet [--load ImageNet-ResNet18-Preact.npz] [--gpu 0,1,2,3]
```
Pretrained and fine-tuned ResNet can be downloaded
[here](http://models.tensorpack.com/ResNet/) and [here](http://models.tensorpack.com/Visualization/).
in the [model zoo](http://models.tensorpack.com/).
2. Generate CAM on ImageNet validation set:
```bash
......
......@@ -301,7 +301,7 @@ class MapDataComponent(MapData):
r = func(dp[index])
if r is None:
return None
dp = copy(dp) # shallow copy to avoid modifying the list
dp = list(dp) # shallow copy to avoid modifying the list
dp[index] = r
return dp
super(MapDataComponent, self).__init__(ds, f)
......@@ -606,6 +606,9 @@ class CacheData(ProxyDataFlow):
"""
Cache the first pass of a DataFlow completely in memory,
and produce from the cache thereafter.
NOTE: The user should not stop the iterator before it has reached the end.
Otherwise the cache may be incomplete.
"""
def __init__(self, ds, shuffle=False):
"""
......
......@@ -268,7 +268,8 @@ class Lighting(ImageAugmentor):
def _get_augment_params(self, img):
assert img.shape[2] == 3
return self.rng.randn(3) * self.std
ret = self.rng.randn(3) * self.std
return ret.astype('float32')
def _augment(self, img, v):
old_dtype = img.dtype
......
......@@ -403,12 +403,13 @@ class PlasmaPutData(ProxyDataFlow):
Experimental.
"""
def __init__(self, ds):
def __init__(self, ds, socket="/tmp/plasma"):
self._socket = socket
super(PlasmaPutData, self).__init__(ds)
def reset_state(self):
super(PlasmaPutData, self).reset_state()
self.client = plasma.connect("/tmp/plasma", "", 0)
self.client = plasma.connect(self._socket, "", 0)
def get_data(self):
for dp in self.ds.get_data():
......@@ -421,12 +422,13 @@ class PlasmaGetData(ProxyDataFlow):
Take plasma object id from a DataFlow, and retrieve it from plasma shared
memory object store.
"""
def __init__(self, ds):
def __init__(self, ds, socket="/tmp/plasma"):
self._socket = socket
super(PlasmaGetData, self).__init__(ds)
def reset_state(self):
super(PlasmaGetData, self).reset_state()
self.client = plasma.connect("/tmp/plasma", "", 0)
self.client = plasma.connect(self._socket, "", 0)
def get_data(self):
for dp in self.ds.get_data():
......
......@@ -38,8 +38,10 @@ def get_default_sess_config(mem_fraction=0.99):
# Didn't see much difference.
conf.gpu_options.per_process_gpu_memory_fraction = 0.99
if get_tf_version_number() >= 1.2:
conf.gpu_options.force_gpu_compatible = True
# This hurt performance of large data pipeline:
# https://github.com/tensorflow/benchmarks/commit/1528c46499cdcff669b5d7c006b7b971884ad0e6
# conf.gpu_options.force_gpu_compatible = True
conf.gpu_options.allow_growth = True
......@@ -47,7 +49,7 @@ def get_default_sess_config(mem_fraction=0.99):
# conf.graph_options.rewrite_options.memory_optimization = \
# rwc.RewriterConfig.HEURISTICS
# May hurt performance
# May hurt performance?
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
# conf.graph_options.place_pruned_graph = True
return conf
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment