Commit 2b095948 authored by Yuxin Wu's avatar Yuxin Wu

fix travis build (fix #146)

parent 7a008efe
......@@ -27,7 +27,7 @@ matrix:
python: 3.5
env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly
allow_failures:
- env: TF_TYPE=nightly
- env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly
install:
- pip install -U pip # the pip version on travis is too old
......
......@@ -3,31 +3,26 @@
# File: DCGAN-CelebA.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import tensorflow as tf
import glob
import pickle
import os
import sys
import os, sys
import argparse
import cv2
from tensorpack import *
from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, GANModelDesc
"""
DCGAN on CelebA dataset.
1. Download the 'aligned&cropped' version of CelebA dataset
from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
(or just use any directory of jpg files).
2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/
3. Visualize samples of a trained model:
./DCGAN-CelebA.py --load path/to/model --sample
You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing steps in `get_data()`.
"""
SHAPE = 64
......@@ -94,9 +89,7 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_data():
global args
datadir = args.data
def get_data(datadir):
imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [imgaug.CenterCrop(140), imgaug.Resize(64)]
......@@ -108,10 +101,10 @@ def get_data():
def get_config():
return TrainConfig(
dataflow=get_data(),
model=Model(),
dataflow=get_data(args.data),
callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5),
model=Model(),
steps_per_epoch=300,
max_epoch=200,
)
......@@ -135,8 +128,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset')
parser.add_argument('--sample', action='store_true', help='view generated examples')
parser.add_argument('--data', help='a jpeg directory')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......
[flake8]
max-line-length = 120
ignore = F403,F401,F405,F841
ignore = F403,F401,F405,F841,E401
exclude = private
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from tensorpack.libinfo import __version__
from tensorpack.train import *
from tensorpack.models import *
from tensorpack.utils import *
......@@ -10,4 +11,3 @@ from tensorpack.tfutils import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.predict import *
from tensorpack.libinfo import __version__
# issue#523 may happen on old systems
import cv2 # noqa
# issue#7378 may happen with custom opencv. It doesn't hurt to disable opencl
import os
os.environ['OPENCV_OPENCL_RUNTIME'] = ''
__version__ = '0.1.5'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment