Commit 68edaa0c authored by Yuxin Wu's avatar Yuxin Wu

fix #653, fix #655

parent 3398df09
...@@ -61,6 +61,7 @@ def get_data(name, batch): ...@@ -61,6 +61,7 @@ def get_data(name, batch):
def get_config(model, fake=False): def get_config(model, fake=False):
nr_tower = max(get_nr_gpu(), 1) nr_tower = max(get_nr_gpu(), 1)
assert args.batch % nr_tower == 0
batch = args.batch // nr_tower batch = args.batch // nr_tower
if fake: if fake:
...@@ -73,14 +74,14 @@ def get_config(model, fake=False): ...@@ -73,14 +74,14 @@ def get_config(model, fake=False):
dataset_train = get_data('train', batch) dataset_train = get_data('train', batch)
dataset_val = get_data('val', batch) dataset_val = get_data('val', batch)
BASE_LR = 0.1 * (args.batch // 256) BASE_LR = 0.1 * (args.batch / 256.0)
callbacks = [ callbacks = [
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2), 'learning_rate', [(30, BASE_LR * 1e-1), (60, BASE_LR * 1e-2),
(85, BASE_LR * 1e-3), (95, BASE_LR * 1e-4), (105, BASE_LR * 1e-5)]), (85, BASE_LR * 1e-3), (95, BASE_LR * 1e-4), (105, BASE_LR * 1e-5)]),
] ]
if BASE_LR != 0.1: if BASE_LR > 0.1:
callbacks.append( callbacks.append(
ScheduledHyperParamSetter( ScheduledHyperParamSetter(
'learning_rate', [(0, 0.1), (3, BASE_LR)], interp='linear')) 'learning_rate', [(0, 0.1), (3, BASE_LR)], interp='linear'))
...@@ -115,7 +116,7 @@ if __name__ == '__main__': ...@@ -115,7 +116,7 @@ if __name__ == '__main__':
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152]) type=int, default=18, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
parser.add_argument('--batch', help='total batch size. need to be multiple of 256 to get similar accuracy.', parser.add_argument('--batch', help='total batch size. 256 gives best accuracy.',
default=256, type=int) default=256, type=int)
parser.add_argument('--mode', choices=['resnet', 'preact', 'se'], parser.add_argument('--mode', choices=['resnet', 'preact', 'se'],
help='variants of resnet to use', default='resnet') help='variants of resnet to use', default='resnet')
......
...@@ -493,9 +493,10 @@ class StagingInput(FeedfreeInput): ...@@ -493,9 +493,10 @@ class StagingInput(FeedfreeInput):
fetches=[self.stage_op, unstage_op]) fetches=[self.stage_op, unstage_op])
def _prefill(self): def _prefill(self):
logger.info("Pre-filling staging area ...") logger.info("Pre-filling StagingArea ...")
for k in range(self.nr_stage): for k in range(self.nr_stage):
self.stage_op.run() self.stage_op.run()
logger.info("Put {} element(s) to StagingArea.")
def _before_run(self, ctx): def _before_run(self, ctx):
# This has to happen once, right before the first iteration. # This has to happen once, right before the first iteration.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import os import os
import tensorflow as tf import tensorflow as tf
import multiprocessing as mp
from ..callbacks import RunOp from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
...@@ -339,8 +340,10 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -339,8 +340,10 @@ class HorovodTrainer(SingleCostTrainer):
# NOTE It will fail if GPU was already detected before initializing the session # NOTE It will fail if GPU was already detected before initializing the session
# https://github.com/tensorflow/tensorflow/issues/8136 # https://github.com/tensorflow/tensorflow/issues/8136
session_creator.config.gpu_options.visible_device_list = str(self._local_rank) session_creator.config.gpu_options.visible_device_list = str(self._local_rank)
# TODO split #CPUs try:
# session_creator.config.inter_op_parallelism_threads = session_creator.config.inter_op_parallelism_threads = mp.cpu_count() // hvd.local_size()
except AttributeError:
pass
super(HorovodTrainer, self).initialize( super(HorovodTrainer, self).initialize(
session_creator, session_init) session_creator, session_init)
......
...@@ -8,10 +8,14 @@ import msgpack_numpy ...@@ -8,10 +8,14 @@ import msgpack_numpy
msgpack_numpy.patch() msgpack_numpy.patch()
try: try:
# https://github.com/apache/arrow/pull/1223#issuecomment-359895666
import sys import sys
old_mod = sys.modules.get('torch', None)
sys.modules['torch'] = None sys.modules['torch'] = None
# https://github.com/apache/arrow/pull/1223#issuecomment-359895666
import pyarrow as pa import pyarrow as pa
if old_mod is not None:
sys.modules['torch'] = old_mod
else:
del sys.modules['torch'] del sys.modules['torch']
except ImportError: except ImportError:
pa = None pa = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment