Commit 2b095948 authored by Yuxin Wu's avatar Yuxin Wu

fix travis build (fix #146)

parent 7a008efe
...@@ -27,7 +27,7 @@ matrix: ...@@ -27,7 +27,7 @@ matrix:
python: 3.5 python: 3.5
env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly
allow_failures: allow_failures:
- env: TF_TYPE=nightly - env: TF_VERSION=1.0.0rc2 TF_TYPE=nightly
install: install:
- pip install -U pip # the pip version on travis is too old - pip install -U pip # the pip version on travis is too old
......
...@@ -3,31 +3,26 @@ ...@@ -3,31 +3,26 @@
# File: DCGAN-CelebA.py # File: DCGAN-CelebA.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import tensorflow as tf import tensorflow as tf
import glob import glob
import pickle import os, sys
import os
import sys
import argparse import argparse
import cv2
from tensorpack import * from tensorpack import *
from tensorpack.utils.viz import * from tensorpack.utils.viz import *
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
import tensorpack.tfutils.symbolic_functions as symbf
from GAN import GANTrainer, RandomZData, GANModelDesc from GAN import GANTrainer, RandomZData, GANModelDesc
""" """
DCGAN on CelebA dataset.
1. Download the 'aligned&cropped' version of CelebA dataset 1. Download the 'aligned&cropped' version of CelebA dataset
from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
(or just use any directory of jpg files).
2. Start training: 2. Start training:
./DCGAN-CelebA.py --data /path/to/image_align_celeba/ ./DCGAN-CelebA.py --data /path/to/image_align_celeba/
3. Visualize samples of a trained model: 3. Visualize samples of a trained model:
./DCGAN-CelebA.py --load path/to/model --sample ./DCGAN-CelebA.py --load path/to/model --sample
You can also train on other images (just use any directory of jpg files in
`--data`). But you may need to change the preprocessing steps in `get_data()`.
""" """
SHAPE = 64 SHAPE = 64
...@@ -94,9 +89,7 @@ class Model(GANModelDesc): ...@@ -94,9 +89,7 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3) return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def get_data(): def get_data(datadir):
global args
datadir = args.data
imgs = glob.glob(datadir + '/*.jpg') imgs = glob.glob(datadir + '/*.jpg')
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
augs = [imgaug.CenterCrop(140), imgaug.Resize(64)] augs = [imgaug.CenterCrop(140), imgaug.Resize(64)]
...@@ -108,10 +101,10 @@ def get_data(): ...@@ -108,10 +101,10 @@ def get_data():
def get_config(): def get_config():
return TrainConfig( return TrainConfig(
dataflow=get_data(), model=Model(),
dataflow=get_data(args.data),
callbacks=[ModelSaver()], callbacks=[ModelSaver()],
session_config=get_default_sess_config(0.5), session_config=get_default_sess_config(0.5),
model=Model(),
steps_per_epoch=300, steps_per_epoch=300,
max_epoch=200, max_epoch=200,
) )
...@@ -135,8 +128,8 @@ if __name__ == '__main__': ...@@ -135,8 +128,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='view generated examples')
parser.add_argument('--data', help='`image_align_celeba` directory of the celebA dataset') parser.add_argument('--data', help='a jpeg directory')
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
ignore = F403,F401,F405,F841 ignore = F403,F401,F405,F841,E401
exclude = private exclude = private
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from tensorpack.libinfo import __version__
from tensorpack.train import * from tensorpack.train import *
from tensorpack.models import * from tensorpack.models import *
from tensorpack.utils import * from tensorpack.utils import *
...@@ -10,4 +11,3 @@ from tensorpack.tfutils import * ...@@ -10,4 +11,3 @@ from tensorpack.tfutils import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.predict import * from tensorpack.predict import *
from tensorpack.libinfo import __version__
# issue#523 may happen on old systems
import cv2 # noqa
# issue#7378 may happen with custom opencv. It doesn't hurt to disable opencl
import os
os.environ['OPENCV_OPENCL_RUNTIME'] = ''
__version__ = '0.1.5' __version__ = '0.1.5'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment