Commit f5d5d4c2 authored by Yuxin Wu

fix SimpleTrainer and use longer test survival limit

parent 3b7e7c55
......@@ -64,7 +64,7 @@ class Trainer(object):
self.monitors = []
self._epoch_num = None
self._setup() # subclass will setup the graph
self._setup() # subclass will setup the graph and InputSource
@property
def epoch_num(self):
......
......@@ -63,7 +63,6 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
super(DistributedReplicatedTrainer, self).__init__(config)
worker_prefix = '/job:worker/task:%s' % self.task_index
self.param_server_device = tf.train.replica_device_setter(
......@@ -79,6 +78,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
self.sync_queue_counter = 0
super(DistributedReplicatedTrainer, self).__init__(config)
@staticmethod
def _average_grads(tower_grads, devices):
"""
......
......@@ -23,11 +23,9 @@ class SimpleTrainer(Trainer):
Args:
config (TrainConfig): the training config.
"""
super(SimpleTrainer, self).__init__(config)
assert len(self.config.tower) == 1, \
assert len(config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(self.config.tower))
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
if config.dataflow is None:
self._input_source = config.data
......@@ -35,6 +33,7 @@ class SimpleTrainer(Trainer):
self._input_source = FeedInput(config.dataflow)
logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead.")
super(SimpleTrainer, self).__init__(config)
def run_step(self):
self.hooked_sess.run(self.train_op)
......
......@@ -21,12 +21,12 @@ class PythonScript(threading.Thread):
p: process handle
timeout (int): timeout in seconds
"""
def __init__(self, cmd, timeout=10):
def __init__(self, cmd, timeout):
"""Prepare a python script
Args:
cmd (TYPE): command to execute the example with all flags (including python)
timeout (int, optional): time in seconds the script has to survive
cmd (str): command to execute the example with all flags (including python)
timeout (int): time in seconds the script has to survive
"""
threading.Thread.__init__(self)
self.cmd = cmd
......@@ -51,7 +51,7 @@ class PythonScript(threading.Thread):
self.join()
else:
# something unexpected happened here, this script was supposed to survive at least the timeout
if len(self.err) is not 0:
if len(self.err) > 0:
output = u"STDOUT: \n\n\n" + self.out.decode('utf-8')
output += u"\n\n\n STDERR: \n\n\n" + self.err.decode('utf-8')
raise AssertionError(output)
......@@ -70,7 +70,7 @@ class TestPythonScript(unittest.TestCase):
if os.path.isdir(os.path.join("train_log", script)):
shutil.rmtree(os.path.join("train_log", script))
def assertSurvive(self, script, args=None, timeout=10): # noqa
def assertSurvive(self, script, args=None, timeout=20): # noqa
cmd = "python{} {}".format(sys.version_info.major, script)
if args:
cmd += " " + " ".join(args)
......
......@@ -20,7 +20,7 @@ class CharRNNTest(TestPythonScript):
f.write(random_content())
def test(self):
self.assertSurvive(self.script, args=['--gpu 0', 'train'], timeout=10)
self.assertSurvive(self.script, args=['train'])
def tearDown(self):
super(CharRNNTest, self).tearDown()
......
......@@ -8,4 +8,4 @@ class InfoGANTest(TestPythonScript):
return '../examples/GAN/InfoGAN-mnist.py'
def test(self):
self.assertSurvive(self.script, args=None, timeout=10)
self.assertSurvive(self.script, args=None)
......@@ -8,4 +8,4 @@ class MnistTest(TestPythonScript):
return '../examples/mnist-convnet.py'
def test(self):
self.assertSurvive(self.script, args=None, timeout=10)
self.assertSurvive(self.script, args=None)
from case_script import TestPythonScript
import os
import shutil
class ResnetTest(TestPythonScript):
@property
def script(self):
return '../examples/ResNet/imagenet-resnet.py'
def test(self):
self.assertSurvive(self.script, args=['--data .',
'--gpu 0', '--fake', '--data_format NHWC'], timeout=10)
def tearDown(self):
super(ResnetTest, self).tearDown()
if os.path.isdir('ilsvrc'):
shutil.rmtree('ilsvrc')
self.assertSurvive(
self.script,
args=['--fake', '--data_format NHWC'], timeout=20)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment