Commit 26e609f8 authored by Yuxin Wu

minor changes

parent c81ea087
...@@ -94,7 +94,7 @@ def get_config(model, fake=False): ...@@ -94,7 +94,7 @@ def get_config(model, fake=False):
model=model, model=model,
dataflow=dataset_train, dataflow=dataset_train,
callbacks=callbacks, callbacks=callbacks,
steps_per_epoch=5000, steps_per_epoch=100 if args.fake else 5000, # 5000 ~= 1.28M / TOTAL_BATCH_SIZE
max_epoch=110, max_epoch=110,
nr_tower=nr_tower nr_tower=nr_tower
) )
...@@ -124,8 +124,11 @@ if __name__ == '__main__': ...@@ -124,8 +124,11 @@ if __name__ == '__main__':
ds = get_data('val', batch) ds = get_data('val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds) eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
else: else:
logger.set_logger_dir( if args.fake:
os.path.join('train_log', 'imagenet-resnet-d' + str(args.depth))) logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
else:
logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-d' + str(args.depth)))
config = get_config(model, fake=args.fake) config = get_config(model, fake=args.fake)
if args.load: if args.load:
......
...@@ -92,7 +92,7 @@ def get_imagenet_dataflow( ...@@ -92,7 +92,7 @@ def get_imagenet_dataflow(
assert datadir is not None assert datadir is not None
assert isinstance(augmentors, list) assert isinstance(augmentors, list)
isTrain = name == 'train' isTrain = name == 'train'
cpu = min(30, multiprocessing.cpu_count()) cpu = min(40, multiprocessing.cpu_count())
if isTrain: if isTrain:
ds = dataset.ILSVRC12(datadir, name, shuffle=True) ds = dataset.ILSVRC12(datadir, name, shuffle=True)
ds = AugmentImageComponent(ds, augmentors, copy=False) ds = AugmentImageComponent(ds, augmentors, copy=False)
...@@ -213,3 +213,21 @@ class ImageNetModel(ModelDesc): ...@@ -213,3 +213,21 @@ class ImageNetModel(ModelDesc):
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5') wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5')) add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
return loss return loss
if __name__ == '__main__':
    # Stand-alone entry point: build the ImageNet 'train' dataflow and hand it
    # to TestDataSpeed so the input pipeline can be exercised without a model.
    import argparse
    from tensorpack.dataflow import TestDataSpeed

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True)
    parser.add_argument('--batch', type=int, default=32)
    args = parser.parse_args()

    # Only resize + center-crop is applied. An earlier
    # `augs = fbresnet_augmentor(False)` assignment was overwritten before
    # being used, so that dead call has been removed; swap it in here if the
    # full augmentation pipeline should be benchmarked instead.
    augs = [
        imgaug.ResizeShortestEdge(256),
        imgaug.CenterCrop(224),
    ]
    df = get_imagenet_dataflow(args.data, 'train', args.batch, augs)
    TestDataSpeed(df).start()
...@@ -17,7 +17,6 @@ from ..callbacks import ( ...@@ -17,7 +17,6 @@ from ..callbacks import (
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.scope_utils import cached_name_scope from ..tfutils.scope_utils import cached_name_scope
# from ..tfutils.collection import freeze_collection # TODO freeze UPDATE_OPS in replicated
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu from ..utils.gpu import get_nr_gpu
...@@ -108,8 +107,6 @@ def setup_keras_trainer( ...@@ -108,8 +107,6 @@ def setup_keras_trainer(
nr_inputs = len(inputs_desc) nr_inputs = len(inputs_desc)
def get_cost(*inputs): def get_cost(*inputs):
assert len(inputs) == len(inputs_desc) + len(targets_desc), \
"Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(targets_desc))
ctx = get_current_tower_context() ctx = get_current_tower_context()
input_tensors = list(inputs[:nr_inputs]) input_tensors = list(inputs[:nr_inputs])
target_tensors = list(inputs[nr_inputs:]) target_tensors = list(inputs[nr_inputs:])
......
...@@ -263,7 +263,7 @@ class MapData(ProxyDataFlow): ...@@ -263,7 +263,7 @@ class MapData(ProxyDataFlow):
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
ret = self.func(dp) ret = self.func(copy(dp)) # shallow copy the list
if ret is not None: if ret is not None:
yield ret yield ret
...@@ -292,7 +292,7 @@ class MapDataComponent(MapData): ...@@ -292,7 +292,7 @@ class MapDataComponent(MapData):
r = func(dp[index]) r = func(dp[index])
if r is None: if r is None:
return None return None
dp = copy(dp) # avoid modifying the list dp = copy(dp) # shallow copy to avoid modifying the list
dp[index] = r dp[index] = r
return dp return dp
super(MapDataComponent, self).__init__(ds, f) super(MapDataComponent, self).__init__(ds, f)
......
...@@ -96,7 +96,8 @@ def cached_name_scope(name, top_level=True): ...@@ -96,7 +96,8 @@ def cached_name_scope(name, top_level=True):
""" """
if not top_level: if not top_level:
current_ns = tf.get_default_graph().get_name_scope() current_ns = tf.get_default_graph().get_name_scope()
name = current_ns + '/' + name if current_ns:
name = current_ns + '/' + name
ns = _get_cached_ns(name) ns = _get_cached_ns(name)
with tf.name_scope(ns): with tf.name_scope(ns):
yield ns yield ns
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment