Commit 353bb6c5 authored by Yuxin Wu's avatar Yuxin Wu

minor changes in examples and TrainConfig

parent 82afb459
...@@ -140,12 +140,26 @@ def get_data(train_or_test): ...@@ -140,12 +140,26 @@ def get_data(train_or_test):
return ds return ds
def get_config(): if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('-n', '--num_units',
help='number of units in each stage',
type=int, default=18)
parser.add_argument('--load', help='load model')
args = parser.parse_args()
NUM_UNITS = args.num_units
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.auto_set_dir() logger.auto_set_dir()
dataset_train = get_data('train') dataset_train = get_data('train')
dataset_test = get_data('test') dataset_test = get_data('test')
return TrainConfig( config = TrainConfig(
model=Model(n=NUM_UNITS),
dataflow=dataset_train, dataflow=dataset_train,
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
...@@ -154,26 +168,8 @@ def get_config(): ...@@ -154,26 +168,8 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
], ],
model=Model(n=NUM_UNITS),
max_epoch=400, max_epoch=400,
nr_tower=max(get_nr_gpu(), 1),
session_init=SaverRestore(args.load) if args.load else None
) )
SyncMultiGPUTrainerParameterServer(config).train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('-n', '--num_units',
help='number of units in each stage',
type=int, default=18)
parser.add_argument('--load', help='load model')
args = parser.parse_args()
NUM_UNITS = args.num_units
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
...@@ -58,15 +58,21 @@ def get_data(train_or_test): ...@@ -58,15 +58,21 @@ def get_data(train_or_test):
return ds return ds
def get_config(): if __name__ == '__main__':
logger.auto_set_dir() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
# prepare dataset if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
logger.auto_set_dir()
dataset_train = get_data('train') dataset_train = get_data('train')
steps_per_epoch = dataset_train.size()
dataset_test = get_data('test') dataset_test = get_data('test')
return TrainConfig( config = TrainConfig(
model=Model(n=18),
dataflow=dataset_train, dataflow=dataset_train,
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
...@@ -75,23 +81,8 @@ def get_config(): ...@@ -75,23 +81,8 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)]) [(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
], ],
model=Model(n=18), nr_tower=max(get_nr_gpu(), 1),
steps_per_epoch=steps_per_epoch, session_init=SaverRestore(args.load) if args.load else None,
max_epoch=500, max_epoch=500,
) )
SyncMultiGPUTrainerParameterServer(config).train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(get_nr_gpu(), 1)
SyncMultiGPUTrainer(config).train()
...@@ -40,7 +40,8 @@ class SVHNDigit(RNGDataFlow): ...@@ -40,7 +40,8 @@ class SVHNDigit(RNGDataFlow):
filename = os.path.join(data_dir, name + '_32x32.mat') filename = os.path.join(data_dir, name + '_32x32.mat')
if not os.path.isfile(filename): if not os.path.isfile(filename):
url = SVHN_URL + os.path.basename(filename) url = SVHN_URL + os.path.basename(filename)
logger.info("File {} not found! Downloading from {}.".format(filename, url)) logger.info("File {} not found!".format(filename))
logger.info("Downloading from {}.".format(url))
download(url, os.path.dirname(filename)) download(url, os.path.dirname(filename))
logger.info("Loading {} ...".format(filename)) logger.info("Loading {} ...".format(filename))
data = scipy.io.loadmat(filename) data = scipy.io.loadmat(filename)
......
...@@ -118,11 +118,13 @@ class TrainConfig(object): ...@@ -118,11 +118,13 @@ class TrainConfig(object):
if steps_per_epoch is None: if steps_per_epoch is None:
try: try:
if dataflow is not None: if dataflow is not None:
steps_per_epoch = self.dataflow.size() steps_per_epoch = dataflow.size()
elif data is not None:
steps_per_epoch = data.size()
else: else:
steps_per_epoch = self.data.size() raise NotImplementedError()
except NotImplementedError: except NotImplementedError:
logger.exception("You must set `steps_per_epoch` if dataset.size() is not implemented.") logger.exception("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.")
else: else:
steps_per_epoch = int(steps_per_epoch) steps_per_epoch = int(steps_per_epoch)
self.steps_per_epoch = steps_per_epoch self.steps_per_epoch = steps_per_epoch
...@@ -131,6 +133,7 @@ class TrainConfig(object): ...@@ -131,6 +133,7 @@ class TrainConfig(object):
self.max_epoch = int(max_epoch) self.max_epoch = int(max_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0 assert self.steps_per_epoch > 0 and self.max_epoch > 0
nr_tower = max(nr_tower, 1)
self.nr_tower = nr_tower self.nr_tower = nr_tower
if tower is not None: if tower is not None:
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!" assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment