Commit f5d5d4c2 authored by Yuxin Wu

fix SimpleTrainer and use longer test survival limit

parent 3b7e7c55
...@@ -64,7 +64,7 @@ class Trainer(object): ...@@ -64,7 +64,7 @@ class Trainer(object):
self.monitors = [] self.monitors = []
self._epoch_num = None self._epoch_num = None
self._setup() # subclass will setup the graph self._setup() # subclass will setup the graph and InputSource
@property @property
def epoch_num(self): def epoch_num(self):
......
...@@ -63,7 +63,6 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase): ...@@ -63,7 +63,6 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self._input_source = config.data self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker') self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
super(DistributedReplicatedTrainer, self).__init__(config)
worker_prefix = '/job:worker/task:%s' % self.task_index worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter( self.param_server_device = tf.train.replica_device_setter(
...@@ -79,6 +78,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase): ...@@ -79,6 +78,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)] self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0 self.sync_queue_counter = 0
super(DistributedReplicatedTrainer, self).__init__(config)
@staticmethod @staticmethod
def _average_grads(tower_grads, devices): def _average_grads(tower_grads, devices):
""" """
......
...@@ -23,11 +23,9 @@ class SimpleTrainer(Trainer): ...@@ -23,11 +23,9 @@ class SimpleTrainer(Trainer):
Args: Args:
config (TrainConfig): the training config. config (TrainConfig): the training config.
""" """
super(SimpleTrainer, self).__init__(config) assert len(config.tower) == 1, \
assert len(self.config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \ "Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(self.config.tower)) " Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
if config.dataflow is None: if config.dataflow is None:
self._input_source = config.data self._input_source = config.data
...@@ -35,6 +33,7 @@ class SimpleTrainer(Trainer): ...@@ -35,6 +33,7 @@ class SimpleTrainer(Trainer):
self._input_source = FeedInput(config.dataflow) self._input_source = FeedInput(config.dataflow)
logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). " logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead.") "Consider QueueInput or other InputSource instead.")
super(SimpleTrainer, self).__init__(config)
def run_step(self): def run_step(self):
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
......
...@@ -21,12 +21,12 @@ class PythonScript(threading.Thread): ...@@ -21,12 +21,12 @@ class PythonScript(threading.Thread):
p: process handle p: process handle
timeout (int): timeout in seconds timeout (int): timeout in seconds
""" """
def __init__(self, cmd, timeout=10): def __init__(self, cmd, timeout):
"""Prepare a python script """Prepare a python script
Args: Args:
cmd (TYPE): command to execute the example with all flags (including python) cmd (str): command to execute the example with all flags (including python)
timeout (int, optional): time in seconds the script has to survive timeout (int): time in seconds the script has to survive
""" """
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.cmd = cmd self.cmd = cmd
...@@ -51,7 +51,7 @@ class PythonScript(threading.Thread): ...@@ -51,7 +51,7 @@ class PythonScript(threading.Thread):
self.join() self.join()
else: else:
# something unexpected happend here, this script was supposed to survive at least the timeout # something unexpected happend here, this script was supposed to survive at least the timeout
if len(self.err) is not 0: if len(self.err) > 0:
output = u"STDOUT: \n\n\n" + self.out.decode('utf-8') output = u"STDOUT: \n\n\n" + self.out.decode('utf-8')
output += u"\n\n\n STDERR: \n\n\n" + self.err.decode('utf-8') output += u"\n\n\n STDERR: \n\n\n" + self.err.decode('utf-8')
raise AssertionError(output) raise AssertionError(output)
...@@ -70,7 +70,7 @@ class TestPythonScript(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestPythonScript(unittest.TestCase):
if os.path.isdir(os.path.join("train_log", script)): if os.path.isdir(os.path.join("train_log", script)):
shutil.rmtree(os.path.join("train_log", script)) shutil.rmtree(os.path.join("train_log", script))
def assertSurvive(self, script, args=None, timeout=10): # noqa def assertSurvive(self, script, args=None, timeout=20): # noqa
cmd = "python{} {}".format(sys.version_info.major, script) cmd = "python{} {}".format(sys.version_info.major, script)
if args: if args:
cmd += " " + " ".join(args) cmd += " " + " ".join(args)
......
...@@ -20,7 +20,7 @@ class CharRNNTest(TestPythonScript): ...@@ -20,7 +20,7 @@ class CharRNNTest(TestPythonScript):
f.write(random_content()) f.write(random_content())
def test(self): def test(self):
self.assertSurvive(self.script, args=['--gpu 0', 'train'], timeout=10) self.assertSurvive(self.script, args=['train'])
def tearDown(self): def tearDown(self):
super(CharRNNTest, self).tearDown() super(CharRNNTest, self).tearDown()
......
...@@ -8,4 +8,4 @@ class InfoGANTest(TestPythonScript): ...@@ -8,4 +8,4 @@ class InfoGANTest(TestPythonScript):
return '../examples/GAN/InfoGAN-mnist.py' return '../examples/GAN/InfoGAN-mnist.py'
def test(self): def test(self):
self.assertSurvive(self.script, args=None, timeout=10) self.assertSurvive(self.script, args=None)
...@@ -8,4 +8,4 @@ class MnistTest(TestPythonScript): ...@@ -8,4 +8,4 @@ class MnistTest(TestPythonScript):
return '../examples/mnist-convnet.py' return '../examples/mnist-convnet.py'
def test(self): def test(self):
self.assertSurvive(self.script, args=None, timeout=10) self.assertSurvive(self.script, args=None)
from case_script import TestPythonScript from case_script import TestPythonScript
import os
import shutil
class ResnetTest(TestPythonScript): class ResnetTest(TestPythonScript):
@property @property
def script(self): def script(self):
return '../examples/ResNet/imagenet-resnet.py' return '../examples/ResNet/imagenet-resnet.py'
def test(self): def test(self):
self.assertSurvive(self.script, args=['--data .', self.assertSurvive(
'--gpu 0', '--fake', '--data_format NHWC'], timeout=10) self.script,
args=['--fake', '--data_format NHWC'], timeout=20)
def tearDown(self):
super(ResnetTest, self).tearDown()
if os.path.isdir('ilsvrc'):
shutil.rmtree('ilsvrc')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment