Commit ee7dcd0d authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

More Tests (#200)

* some tweaks for travis

First commit for rewriting the unittests when using travis-ci. It
uses unitests from python instead of custom framework and adds
caching to the tensorpack_datadir which path is explicitly set.

* add infogan example to tests

* add flag to create data dir only if not exists

* add several tests including ResNet-18 on fakedata

* fix matplotlib import

* fix python3 bytes-like str

* override image paths in resnet-test

* try to fix ilsvrc paths

* python3 fix for decode str

* fix meta-data path

* next travis fix, arg in python call

* use fakedata in resnet

* fix lint, fix arg

* fix linting in test_resnet.py

* cosmetic changes

* add data_format for resnet-eval

* delete resnet directory

* fix linting :(
parent 7332ae78
...@@ -5,11 +5,15 @@ language: python ...@@ -5,11 +5,15 @@ language: python
cache: cache:
pip: true pip: true
apt: true apt: true
directories:
- $HOME/tensorpack_data
addons: addons:
apt: apt:
packages: packages:
- pandoc - pandoc
- libprotobuf-dev
- protobuf-compiler
matrix: matrix:
fast_finish: true fast_finish: true
...@@ -45,7 +49,9 @@ before_script: ...@@ -45,7 +49,9 @@ before_script:
script: script:
- flake8 . - flake8 .
- cd examples && flake8 . - cd examples && flake8 .
- cd $TRAVIS_BUILD_DIR && python tests/test_examples.py - mkdir -p $HOME/tensorpack_data
- export TENSORPACK_DATASET=$HOME/tensorpack_data
- cd $TRAVIS_BUILD_DIR/tests && python -m unittest discover -v
notifications: notifications:
- email: - email:
......
...@@ -23,6 +23,10 @@ DEPTH = None ...@@ -23,6 +23,10 @@ DEPTH = None
class Model(ModelDesc): class Model(ModelDesc):
def __init__(self, data_format='NCHW'):
self.data_format = data_format
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'), return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] InputDesc(tf.int32, [None], 'label')]
...@@ -36,6 +40,7 @@ class Model(ModelDesc): ...@@ -36,6 +40,7 @@ class Model(ModelDesc):
image_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32) image_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
image_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32) image_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)
image = (image - image_mean) / image_std image = (image - image_mean) / image_std
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2]) image = tf.transpose(image, [0, 3, 1, 2])
def shortcut(l, n_in, n_out, stride): def shortcut(l, n_in, n_out, stride):
...@@ -93,7 +98,7 @@ class Model(ModelDesc): ...@@ -93,7 +98,7 @@ class Model(ModelDesc):
with argscope(Conv2D, nl=tf.identity, use_bias=False, with argscope(Conv2D, nl=tf.identity, use_bias=False,
W_init=variance_scaling_initializer(mode='FAN_OUT')), \ W_init=variance_scaling_initializer(mode='FAN_OUT')), \
argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'): argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU) .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME') .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
...@@ -123,8 +128,9 @@ class Model(ModelDesc): ...@@ -123,8 +128,9 @@ class Model(ModelDesc):
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True) return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(train_or_test): def get_data(train_or_test, fake=False):
# return FakeData([[64, 224,224,3],[64]], 1000, random=False, dtype='uint8') if fake:
return FakeData([[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
datadir = args.data datadir = args.data
...@@ -187,9 +193,9 @@ def get_data(train_or_test): ...@@ -187,9 +193,9 @@ def get_data(train_or_test):
return ds return ds
def get_config(): def get_config(fake=False, data_format='NCHW'):
dataset_train = get_data('train') dataset_train = get_data('train', fake=fake)
dataset_val = get_data('val') dataset_val = get_data('val', fake=fake)
return TrainConfig( return TrainConfig(
dataflow=dataset_train, dataflow=dataset_train,
...@@ -202,7 +208,7 @@ def get_config(): ...@@ -202,7 +208,7 @@ def get_config():
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]), [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
HumanHyperParamSetter('learning_rate'), HumanHyperParamSetter('learning_rate'),
], ],
model=Model(), model=Model(data_format=data_format),
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=110, max_epoch=110,
) )
...@@ -231,6 +237,9 @@ if __name__ == '__main__': ...@@ -231,6 +237,9 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
parser.add_argument('--data_format', help='specify NCHW or NHWC',
type=str, default='NCHW')
parser.add_argument('-d', '--depth', help='resnet depth', parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101]) type=int, default=18, choices=[18, 34, 50, 101])
parser.add_argument('--eval', action='store_true') parser.add_argument('--eval', action='store_true')
...@@ -249,7 +258,7 @@ if __name__ == '__main__': ...@@ -249,7 +258,7 @@ if __name__ == '__main__':
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
logger.auto_set_dir() logger.auto_set_dir()
config = get_config() config = get_config(fake=args.fake, data_format=args.data_format)
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU config.nr_tower = NR_GPU
......
...@@ -6,10 +6,6 @@ ...@@ -6,10 +6,6 @@
import numpy as np import numpy as np
import os import os
import matplotlib
from matplotlib import offsetbox
import matplotlib.pyplot as plt
from tensorpack import * from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -20,6 +16,15 @@ import tensorflow.contrib.slim as slim ...@@ -20,6 +16,15 @@ import tensorflow.contrib.slim as slim
from embedding_data import get_test_data, MnistPairs, MnistTriplets from embedding_data import get_test_data, MnistPairs, MnistTriplets
MATPLOTLIB_AVAIBLABLE = False
try:
import matplotlib
from matplotlib import offsetbox
import matplotlib.pyplot as plt
MATPLOTLIB_AVAIBLABLE = True
except ImportError:
MATPLOTLIB_AVAIBLABLE = False
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
tf.app.flags.DEFINE_string('load', "", 'load model') tf.app.flags.DEFINE_string('load', "", 'load model')
...@@ -162,6 +167,9 @@ def get_config(model, algorithm_name): ...@@ -162,6 +167,9 @@ def get_config(model, algorithm_name):
def visualize(model_path, model): def visualize(model_path, model):
if not MATPLOTLIB_AVAIBLABLE:
logger.error("visualize requires matplotlib package ...")
return
pred = OfflinePredictor(PredictConfig( pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
model=model(), model=model(),
......
#!/usr/bin/env python from abc import abstractproperty
# -*- coding: UTF-8 -*- import unittest
# File: test_examples.py
# Author: Patrick Wieschollek <mail@patwie.com>
import sys
import subprocess import subprocess
import shlex import shlex
import sys
import threading import threading
from termcolor import colored, cprint import os
import shutil
COMMANDS_TO_TEST = ["python examples/mnist-convnet.py"]
class SurviveException(Exception):
"""Exception when process is already terminated
"""
pass
class PythonScript(threading.Thread): class PythonScript(threading.Thread):
...@@ -63,35 +52,31 @@ class PythonScript(threading.Thread): ...@@ -63,35 +52,31 @@ class PythonScript(threading.Thread):
else: else:
# something unexpected happend here, this script was supposed to survive at leat the timeout # something unexpected happend here, this script was supposed to survive at leat the timeout
if len(self.err) is not 0: if len(self.err) is not 0:
stderr = "\n".join([" " * 10 + v for v in self.err.split("\n")]) stderr = "\n".join([" " * 10 + v for v in self.err.decode("utf-8").split("\n")])
raise SurviveException(stderr) raise AssertionError(stderr)
examples_total = len(COMMANDS_TO_TEST) class TestPythonScript(unittest.TestCase):
examples_passed = 0
examples_failed = [] @abstractproperty
max_name_length = max([len(v) for v in COMMANDS_TO_TEST]) + 10 def script(self):
pass
cprint("Test %i python scripts with timeout" % (examples_total), 'yellow', attrs=['bold'])
@staticmethod
for example_name in COMMANDS_TO_TEST: def clear_trainlog(script):
string = "test: %s %s" % (example_name, " " * (max_name_length - len(example_name))) script = os.path.basename(script)
sys.stdout.write(colored(string, 'yellow', attrs=['bold'])) script = script[:-3]
try: if os.path.isdir(os.path.join("train_log", script)):
PythonScript(example_name).execute() shutil.rmtree(os.path.join("train_log", script))
cprint("... works", 'green', attrs=['bold'])
examples_passed += 1 def assertSurvive(self, script, args=None, timeout=10): # noqa
except Exception as stderr_message: cmd = "python{} {}".format(sys.version_info.major, script)
cprint("... examples_failed", 'red', attrs=['bold']) if args:
print(stderr_message) cmd += " " + " ".join(args)
examples_failed.append(example_name) PythonScript(cmd, timeout=timeout).execute()
print("\n\n") def setUp(self):
cprint("Summary: TEST examples_passed %i / %i" % (examples_passed, examples_total), 'yellow', attrs=['bold']) TestPythonScript.clear_trainlog(self.script)
if examples_total != examples_passed:
print("The following script examples_failed:") def tearDown(self):
for failed_script in examples_failed: TestPythonScript.clear_trainlog(self.script)
print(" - %s" % failed_script)
sys.exit(1)
else:
sys.exit(0)
from case_script import TestPythonScript
import os
def random_content():
return ('Lorem ipsum dolor sit amet\n'
'consetetur sadipscing elitr\n'
'sed diam nonumy eirmod tempor invidunt ut labore\n')
class CharRNNTest(TestPythonScript):
@property
def script(self):
return '../examples/Char-RNN/char-rnn.py'
def setUp(self):
super(CharRNNTest, self).setUp()
with open('input.txt', 'w') as f:
f.write(random_content())
def test(self):
self.assertSurvive(self.script, args=['--gpu 0', 'train'], timeout=10)
def tearDown(self):
super(CharRNNTest, self).tearDown()
os.remove('input.txt')
from case_script import TestPythonScript
class InfoGANTest(TestPythonScript):
@property
def script(self):
return '../examples/GAN/InfoGAN-mnist.py'
def test(self):
self.assertSurvive(self.script, args=None, timeout=10)
from case_script import TestPythonScript
class MnistTest(TestPythonScript):
@property
def script(self):
return '../examples/mnist-convnet.py'
def test(self):
self.assertSurvive(self.script, args=None, timeout=10)
from case_script import TestPythonScript
class SimilarityLearningTest(TestPythonScript):
@property
def script(self):
return '../examples/SimilarityLearning/mnist-embeddings.py'
def test(self):
self.assertSurvive(self.script, args=['--algorithm triplet'], timeout=10)
from case_script import TestPythonScript
import os
import shutil
class ResnetTest(TestPythonScript):
@property
def script(self):
return '../examples/ResNet/imagenet-resnet.py'
def test(self):
self.assertSurvive(self.script, args=['--data .',
'--gpu 0', '--fake', '--data_format NHWC'], timeout=10)
def tearDown(self):
super(ResnetTest, self).tearDown()
if os.path.isdir('ilsvrc'):
shutil.rmtree('ilsvrc')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment