Commit ee7dcd0d authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

More Tests (#200)

* some tweaks for travis

First commit for rewriting the unittests when using travis-ci. It
uses unitests from python instead of custom framework and adds
caching to the tensorpack_datadir which path is explicitly set.

* add infogan example to tests

* add flag to create data dir only if not exists

* add several tests including ResNet-18 on fakedata

* fix matplotlib import

* fix python3 bytes-like str

* override image paths in resnet-test

* try to fix ilsvrc paths

* python3 fix for decode str

* fix meta-data path

* next travis fix, arg in python call

* use fakedata in resnet

* fix lint, fix arg

* fix linting in test_resnet.py

* cosmetic changes

* add data_format for resnet-eval

* delete resnet directory

* fix linting :(
parent 7332ae78
......@@ -5,11 +5,15 @@ language: python
cache:
pip: true
apt: true
directories:
- $HOME/tensorpack_data
addons:
apt:
packages:
- pandoc
- libprotobuf-dev
- protobuf-compiler
matrix:
fast_finish: true
......@@ -45,7 +49,9 @@ before_script:
script:
- flake8 .
- cd examples && flake8 .
- cd $TRAVIS_BUILD_DIR && python tests/test_examples.py
- mkdir -p $HOME/tensorpack_data
- export TENSORPACK_DATASET=$HOME/tensorpack_data
- cd $TRAVIS_BUILD_DIR/tests && python -m unittest discover -v
notifications:
- email:
......
......@@ -23,6 +23,10 @@ DEPTH = None
class Model(ModelDesc):
def __init__(self, data_format='NCHW'):
self.data_format = data_format
def _get_inputs(self):
return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
......@@ -36,7 +40,8 @@ class Model(ModelDesc):
image_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
image_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)
image = (image - image_mean) / image_std
image = tf.transpose(image, [0, 3, 1, 2])
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
def shortcut(l, n_in, n_out, stride):
if n_in != n_out:
......@@ -93,7 +98,7 @@ class Model(ModelDesc):
with argscope(Conv2D, nl=tf.identity, use_bias=False,
W_init=variance_scaling_initializer(mode='FAN_OUT')), \
argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME')
......@@ -123,8 +128,9 @@ class Model(ModelDesc):
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
def get_data(train_or_test):
# return FakeData([[64, 224,224,3],[64]], 1000, random=False, dtype='uint8')
def get_data(train_or_test, fake=False):
if fake:
return FakeData([[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
isTrain = train_or_test == 'train'
datadir = args.data
......@@ -187,9 +193,9 @@ def get_data(train_or_test):
return ds
def get_config():
dataset_train = get_data('train')
dataset_val = get_data('val')
def get_config(fake=False, data_format='NCHW'):
dataset_train = get_data('train', fake=fake)
dataset_val = get_data('val', fake=fake)
return TrainConfig(
dataflow=dataset_train,
......@@ -202,7 +208,7 @@ def get_config():
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
HumanHyperParamSetter('learning_rate'),
],
model=Model(),
model=Model(data_format=data_format),
steps_per_epoch=5000,
max_epoch=110,
)
......@@ -231,6 +237,9 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('--fake', help='use fakedata to test or benchmark this model', action='store_true')
parser.add_argument('--data_format', help='specify NCHW or NHWC',
type=str, default='NCHW')
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101])
parser.add_argument('--eval', action='store_true')
......@@ -249,7 +258,7 @@ if __name__ == '__main__':
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
logger.auto_set_dir()
config = get_config()
config = get_config(fake=args.fake, data_format=args.data_format)
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = NR_GPU
......
......@@ -6,10 +6,6 @@
import numpy as np
import os
import matplotlib
from matplotlib import offsetbox
import matplotlib.pyplot as plt
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
......@@ -20,6 +16,15 @@ import tensorflow.contrib.slim as slim
from embedding_data import get_test_data, MnistPairs, MnistTriplets
MATPLOTLIB_AVAIBLABLE = False
try:
import matplotlib
from matplotlib import offsetbox
import matplotlib.pyplot as plt
MATPLOTLIB_AVAIBLABLE = True
except ImportError:
MATPLOTLIB_AVAIBLABLE = False
FLAGS = flags.FLAGS
tf.app.flags.DEFINE_string('load', "", 'load model')
......@@ -162,6 +167,9 @@ def get_config(model, algorithm_name):
def visualize(model_path, model):
if not MATPLOTLIB_AVAIBLABLE:
logger.error("visualize requires matplotlib package ...")
return
pred = OfflinePredictor(PredictConfig(
session_init=get_model_loader(model_path),
model=model(),
......
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: test_examples.py
# Author: Patrick Wieschollek <mail@patwie.com>
import sys
from abc import abstractproperty
import unittest
import subprocess
import shlex
import sys
import threading
from termcolor import colored, cprint
COMMANDS_TO_TEST = ["python examples/mnist-convnet.py"]
class SurviveException(Exception):
"""Exception when process is already terminated
"""
pass
import os
import shutil
class PythonScript(threading.Thread):
......@@ -63,35 +52,31 @@ class PythonScript(threading.Thread):
else:
# something unexpected happend here, this script was supposed to survive at leat the timeout
if len(self.err) is not 0:
stderr = "\n".join([" " * 10 + v for v in self.err.split("\n")])
raise SurviveException(stderr)
examples_total = len(COMMANDS_TO_TEST)
examples_passed = 0
examples_failed = []
max_name_length = max([len(v) for v in COMMANDS_TO_TEST]) + 10
cprint("Test %i python scripts with timeout" % (examples_total), 'yellow', attrs=['bold'])
for example_name in COMMANDS_TO_TEST:
string = "test: %s %s" % (example_name, " " * (max_name_length - len(example_name)))
sys.stdout.write(colored(string, 'yellow', attrs=['bold']))
try:
PythonScript(example_name).execute()
cprint("... works", 'green', attrs=['bold'])
examples_passed += 1
except Exception as stderr_message:
cprint("... examples_failed", 'red', attrs=['bold'])
print(stderr_message)
examples_failed.append(example_name)
print("\n\n")
cprint("Summary: TEST examples_passed %i / %i" % (examples_passed, examples_total), 'yellow', attrs=['bold'])
if examples_total != examples_passed:
print("The following script examples_failed:")
for failed_script in examples_failed:
print(" - %s" % failed_script)
sys.exit(1)
else:
sys.exit(0)
stderr = "\n".join([" " * 10 + v for v in self.err.decode("utf-8").split("\n")])
raise AssertionError(stderr)
class TestPythonScript(unittest.TestCase):
@abstractproperty
def script(self):
pass
@staticmethod
def clear_trainlog(script):
script = os.path.basename(script)
script = script[:-3]
if os.path.isdir(os.path.join("train_log", script)):
shutil.rmtree(os.path.join("train_log", script))
def assertSurvive(self, script, args=None, timeout=10): # noqa
cmd = "python{} {}".format(sys.version_info.major, script)
if args:
cmd += " " + " ".join(args)
PythonScript(cmd, timeout=timeout).execute()
def setUp(self):
TestPythonScript.clear_trainlog(self.script)
def tearDown(self):
TestPythonScript.clear_trainlog(self.script)
from case_script import TestPythonScript
import os
def random_content():
return ('Lorem ipsum dolor sit amet\n'
'consetetur sadipscing elitr\n'
'sed diam nonumy eirmod tempor invidunt ut labore\n')
class CharRNNTest(TestPythonScript):
@property
def script(self):
return '../examples/Char-RNN/char-rnn.py'
def setUp(self):
super(CharRNNTest, self).setUp()
with open('input.txt', 'w') as f:
f.write(random_content())
def test(self):
self.assertSurvive(self.script, args=['--gpu 0', 'train'], timeout=10)
def tearDown(self):
super(CharRNNTest, self).tearDown()
os.remove('input.txt')
from case_script import TestPythonScript
class InfoGANTest(TestPythonScript):
@property
def script(self):
return '../examples/GAN/InfoGAN-mnist.py'
def test(self):
self.assertSurvive(self.script, args=None, timeout=10)
from case_script import TestPythonScript
class MnistTest(TestPythonScript):
@property
def script(self):
return '../examples/mnist-convnet.py'
def test(self):
self.assertSurvive(self.script, args=None, timeout=10)
from case_script import TestPythonScript
class SimilarityLearningTest(TestPythonScript):
@property
def script(self):
return '../examples/SimilarityLearning/mnist-embeddings.py'
def test(self):
self.assertSurvive(self.script, args=['--algorithm triplet'], timeout=10)
from case_script import TestPythonScript
import os
import shutil
class ResnetTest(TestPythonScript):
@property
def script(self):
return '../examples/ResNet/imagenet-resnet.py'
def test(self):
self.assertSurvive(self.script, args=['--data .',
'--gpu 0', '--fake', '--data_format NHWC'], timeout=10)
def tearDown(self):
super(ResnetTest, self).tearDown()
if os.path.isdir('ilsvrc'):
shutil.rmtree('ilsvrc')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment