Commit 26e609f8 authored by Yuxin Wu

minor changes

parent c81ea087
......@@ -94,7 +94,7 @@ def get_config(model, fake=False):
model=model,
dataflow=dataset_train,
callbacks=callbacks,
steps_per_epoch=5000,
steps_per_epoch=100 if args.fake else 5000, # 5000 ~= 1.28M / TOTAL_BATCH_SIZE
max_epoch=110,
nr_tower=nr_tower
)
......@@ -123,6 +123,9 @@ if __name__ == '__main__':
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
else:
if args.fake:
logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
else:
logger.set_logger_dir(
os.path.join('train_log', 'imagenet-resnet-d' + str(args.depth)))
......
......@@ -92,7 +92,7 @@ def get_imagenet_dataflow(
assert datadir is not None
assert isinstance(augmentors, list)
isTrain = name == 'train'
cpu = min(30, multiprocessing.cpu_count())
cpu = min(40, multiprocessing.cpu_count())
if isTrain:
ds = dataset.ILSVRC12(datadir, name, shuffle=True)
ds = AugmentImageComponent(ds, augmentors, copy=False)
......@@ -213,3 +213,21 @@ class ImageNetModel(ModelDesc):
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
return loss
if __name__ == '__main__':
    # Stand-alone entry point: build the ImageNet dataflow and measure how
    # fast it produces datapoints. No model is constructed here.
    import argparse
    from tensorpack.dataflow import TestDataSpeed

    parser = argparse.ArgumentParser()
    # Passed as the first argument of get_imagenet_dataflow
    # (presumably the dataset directory -- confirm against its signature).
    parser.add_argument('--data', required=True)
    parser.add_argument('--batch', type=int, default=32)
    args = parser.parse_args()

    # Augmentors used for the speed test: resize the shortest edge to 256,
    # then take a 224 center crop. This is the only augmentor list that
    # reaches the dataflow, so it is assigned exactly once.
    augs = [
        imgaug.ResizeShortestEdge(256),
        imgaug.CenterCrop(224),
    ]
    df = get_imagenet_dataflow(
        args.data, 'train', args.batch, augs)
    TestDataSpeed(df).start()
......@@ -17,7 +17,6 @@ from ..callbacks import (
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..tfutils.scope_utils import cached_name_scope
# from ..tfutils.collection import freeze_collection # TODO freeze UPDATE_OPS in replicated
from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
......@@ -108,8 +107,6 @@ def setup_keras_trainer(
nr_inputs = len(inputs_desc)
def get_cost(*inputs):
assert len(inputs) == len(inputs_desc) + len(targets_desc), \
"Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(targets_desc))
ctx = get_current_tower_context()
input_tensors = list(inputs[:nr_inputs])
target_tensors = list(inputs[nr_inputs:])
......
......@@ -263,7 +263,7 @@ class MapData(ProxyDataFlow):
def get_data(self):
for dp in self.ds.get_data():
ret = self.func(dp)
ret = self.func(copy(dp)) # shallow copy the list
if ret is not None:
yield ret
......@@ -292,7 +292,7 @@ class MapDataComponent(MapData):
r = func(dp[index])
if r is None:
return None
dp = copy(dp) # avoid modifying the list
dp = copy(dp) # shallow copy to avoid modifying the list
dp[index] = r
return dp
super(MapDataComponent, self).__init__(ds, f)
......
......@@ -96,6 +96,7 @@ def cached_name_scope(name, top_level=True):
"""
if not top_level:
current_ns = tf.get_default_graph().get_name_scope()
if current_ns:
name = current_ns + '/' + name
ns = _get_cached_ns(name)
with tf.name_scope(ns):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment