Commit 353bb6c5 authored by Yuxin Wu's avatar Yuxin Wu

minor changes in examples and TrainConfig

parent 82afb459
......@@ -140,12 +140,26 @@ def get_data(train_or_test):
return ds
def get_config():
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('-n', '--num_units',
help='number of units in each stage',
type=int, default=18)
parser.add_argument('--load', help='load model')
args = parser.parse_args()
NUM_UNITS = args.num_units
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.auto_set_dir()
dataset_train = get_data('train')
dataset_test = get_data('test')
return TrainConfig(
config = TrainConfig(
model=Model(n=NUM_UNITS),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
......@@ -154,26 +168,8 @@ def get_config():
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
],
model=Model(n=NUM_UNITS),
max_epoch=400,
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('-n', '--num_units',
help='number of units in each stage',
type=int, default=18)
parser.add_argument('--load', help='load model')
args = parser.parse_args()
NUM_UNITS = args.num_units
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
SyncMultiGPUTrainerParameterServer(config).train()
......@@ -58,15 +58,21 @@ def get_data(train_or_test):
return ds
def get_config():
logger.auto_set_dir()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
# prepare dataset
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.auto_set_dir()
dataset_train = get_data('train')
steps_per_epoch = dataset_train.size()
dataset_test = get_data('test')
return TrainConfig(
config = TrainConfig(
model=Model(n=18),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
......@@ -75,23 +81,8 @@ def get_config():
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
],
model=Model(n=18),
steps_per_epoch=steps_per_epoch,
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None,
max_epoch=500,
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
SyncMultiGPUTrainerParameterServer(config).train()
......@@ -40,7 +40,8 @@ class SVHNDigit(RNGDataFlow):
filename = os.path.join(data_dir, name + '_32x32.mat')
if not os.path.isfile(filename):
url = SVHN_URL + os.path.basename(filename)
logger.info("File {} not found! Downloading from {}.".format(filename, url))
logger.info("File {} not found!".format(filename))
logger.info("Downloading from {}.".format(url))
download(url, os.path.dirname(filename))
logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename)
......
......@@ -118,11 +118,13 @@ class TrainConfig(object):
if steps_per_epoch is None:
try:
if dataflow is not None:
steps_per_epoch = self.dataflow.size()
steps_per_epoch = dataflow.size()
elif data is not None:
steps_per_epoch = data.size()
else:
steps_per_epoch = self.data.size()
raise NotImplementedError()
except NotImplementedError:
logger.exception("You must set `steps_per_epoch` if dataset.size() is not implemented.")
logger.exception("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.")
else:
steps_per_epoch = int(steps_per_epoch)
self.steps_per_epoch = steps_per_epoch
......@@ -131,6 +133,7 @@ class TrainConfig(object):
self.max_epoch = int(max_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0
nr_tower = max(nr_tower, 1)
self.nr_tower = nr_tower
if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment