Commit e4d6992d authored by Yuxin Wu

use metaclass with six

parent fbc13fb4
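For background on the change applied throughout this commit: assigning `__metaclass__ = ABCMeta` inside the class body only takes effect under Python 2; Python 3 ignores the attribute, so `@abstractmethod` is not enforced there. The `@six.add_metaclass(ABCMeta)` decorator applies the metaclass under both interpreters. A minimal sketch of the pattern (`Base` and `Impl` are illustrative names, not tensorpack classes):

```python
from abc import ABCMeta, abstractmethod

import six


# Under Python 2, `__metaclass__ = ABCMeta` in the class body sets the metaclass;
# Python 3 silently ignores that attribute. six.add_metaclass rewrites the class
# so the metaclass is applied on both interpreters.
@six.add_metaclass(ABCMeta)
class Base(object):
    @abstractmethod
    def run(self):
        """Subclasses must provide an implementation."""


class Impl(Base):
    def run(self):
        return 42


if __name__ == '__main__':
    try:
        Base()   # abstract method makes instantiation fail on both Python 2 and 3
    except TypeError as e:
        print(e)
    print(Impl().run())   # 42
```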
......@@ -55,3 +55,4 @@ pip install --user -r opt-requirements.txt (some optional dependencies, you can
```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
```
+ Use tcmalloc if running with large data
......@@ -14,7 +14,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po
To use the script, you'll need:
+ TensorFlow >= 0.10
+ TensorFlow >= 0.11
+ OpenCV bindings for Python
......
......@@ -20,6 +20,7 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses
To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain images of shape 2s x s, formed by A and B
# you can download some data from the original pix2pix repo: https://github.com/phillipi/pix2pix#datasets
# training visualization will appear in tensorboard
To visualize on test set:
......@@ -125,7 +126,7 @@ class Model(ModelDesc):
def split_input(img):
"""
img: a 512x256x3 image
img: an image with shape (s, 2s, 3)
:return: [input, output]
"""
s = img.shape[0]
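Per the updated docstring, the input image is the horizontal concatenation of the A and B halves; below is a minimal sketch of such a split, assuming that layout (the helper name `split_halves` is illustrative, not the repo's function):

```python
import numpy as np


def split_halves(img):
    # img: (s, 2s, 3) array, with A in the left half and B in the right half
    s = img.shape[0]
    assert img.shape[1] == 2 * s, img.shape
    return [img[:, :s, :], img[:, s:, :]]


# e.g. a 256x512 RGB image splits into two 256x256 halves
a, b = split_halves(np.zeros((256, 512, 3), dtype='uint8'))
```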
......@@ -187,7 +188,7 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='A directory of 512x256 images')
parser.add_argument('--data', help='Image directory')
parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
global args
args = parser.parse_args()
......
......@@ -6,15 +6,15 @@
from abc import abstractmethod, ABCMeta
from collections import defaultdict
import six
import random
from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace']
@six.add_metaclass(ABCMeta)
class RLEnvironment(object):
__meta__ = ABCMeta
def __init__(self):
self.reset_stat()
......
......@@ -11,6 +11,7 @@ import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import numpy as np
import six
from six.moves import queue
......@@ -43,8 +44,8 @@ class TransitionExperience(object):
for k, v in six.iteritems(kwargs):
setattr(self, k, v)
@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta
def __init__(self, idx):
super(SimulatorProcessBase, self).__init__()
......@@ -62,8 +63,6 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
A process that simulates a player and communicates with the master to
send states and receive the next action
"""
__metaclass__ = ABCMeta
def __init__(self, idx, pipe_c2s, pipe_s2c):
"""
:param idx: idx of this process
......@@ -103,8 +102,6 @@ class SimulatorMaster(threading.Thread):
It should produce an action for each simulator, as well as
define callbacks for when a transition or an episode is finished.
"""
__metaclass__ = ABCMeta
class ClientState(object):
def __init__(self):
self.memory = [] # list of Experience
......
......@@ -7,12 +7,13 @@ import sys
import os
import time
from abc import abstractmethod, ABCMeta
import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
@six.add_metaclass(ABCMeta)
class Callback(object):
""" Base class for all callbacks """
__metaclass__ = ABCMeta
def before_train(self):
"""
......
......@@ -20,8 +20,8 @@ from .dispatcher import OutputTensorDispatcer
__all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
@six.add_metaclass(ABCMeta)
class Inferencer(object):
__metaclass__ = ABCMeta
def before_inference(self):
"""
......
......@@ -17,10 +17,9 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter',
'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
'HyperParam', 'GraphVarParam', 'ObjAttrParam']
@six.add_metaclass(ABCMeta)
class HyperParam(object):
""" Base class for a hyper param"""
__metaclass__ = ABCMeta
def setup_graph(self):
""" setup the graph in `setup_graph` callback stage, if necessary"""
......@@ -88,7 +87,6 @@ class HyperParamSetter(Callback):
"""
Base class to set hyperparameters after every epoch.
"""
__metaclass__ = ABCMeta
def __init__(self, param):
"""
......
......@@ -5,14 +5,14 @@
from abc import abstractmethod, ABCMeta
import six
from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
@six.add_metaclass(ABCMeta)
class DataFlow(object):
""" Base class for all DataFlow """
__metaclass__ = ABCMeta
class Infinity:
pass
......
......@@ -44,6 +44,10 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dump_dataflow_to_lmdb(ds, lmdb_path):
""" Dump a `Dataflow` ds to a lmdb database, where the key is the index
and the data is the serialized datapoint.
The output database can be read directly by `LMDBDataPoint`
"""
assert isinstance(ds, DataFlow), type(ds)
isdir = os.path.isdir(lmdb_path)
if isdir:
......
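A hypothetical usage sketch of this new helper; the exact import paths and the `LMDBDataPoint` read-back shown in the comments are assumptions about this version of tensorpack:

```python
from tensorpack.dataflow import dataset                   # assumed import path
from tensorpack.dataflow.dftools import dump_dataflow_to_lmdb

ds = dataset.Mnist('train')                               # any DataFlow should work here
dump_dataflow_to_lmdb(ds, '/tmp/mnist-train.lmdb')

# The dump can then be read back as a DataFlow, e.g.:
#   from tensorpack.dataflow import LMDBDataPoint
#   lmdb_ds = LMDBDataPoint('/tmp/mnist-train.lmdb')
```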
......@@ -4,13 +4,14 @@
from abc import abstractmethod, ABCMeta
from ...utils import get_rng
import six
from six.moves import zip
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
@six.add_metaclass(ABCMeta)
class Augmentor(object):
""" Base class for an augmentor"""
__metaclass__ = ABCMeta
def __init__(self):
self.reset_state()
......
......@@ -9,6 +9,7 @@ import tensorflow as tf
from collections import namedtuple
import inspect
import pickle
import six
from ..utils import logger, INPUT_VARS_KEY
from ..tfutils.common import get_tensors_by_names
......@@ -30,9 +31,9 @@ class InputVar(object):
def loads(buf):
return pickle.loads(buf)
@six.add_metaclass(ABCMeta)
class ModelDesc(object):
""" Base class for a model description """
__metaclass__ = ABCMeta
def get_input_vars(self):
"""
......
......@@ -15,9 +15,8 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor']
@six.add_metaclass(ABCMeta)
class PredictorBase(object):
__metaclass__ = ABCMeta
"""
Available attributes:
session
......
......@@ -7,6 +7,7 @@ from six.moves import range, zip
from abc import ABCMeta, abstractmethod
import multiprocessing
import os
import six
from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue
......@@ -21,9 +22,8 @@ from .base import OfflinePredictor
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor']
@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object):
__metaclass__ = ABCMeta
def __init__(self, config, dataset):
"""
:param config: a `PredictConfig` instance.
......
......@@ -6,6 +6,7 @@
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import re
import six
import inspect
from ..utils import logger
from .symbolic_functions import rms
......@@ -31,8 +32,8 @@ def apply_grad_processors(grads, gradprocs):
g = proc.process(g)
return g
@six.add_metaclass(ABCMeta)
class GradientProcessor(object):
__metaclass__ = ABCMeta
def process(self, grads):
"""
......
......@@ -20,9 +20,9 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
# TODO they initialize_all at the beginning by default.
@six.add_metaclass(ABCMeta)
class SessionInit(object):
""" Base class for utilities to initialize a session"""
__metaclass__ = ABCMeta
def init(self, sess):
""" Initialize a session
......
......@@ -6,6 +6,7 @@ from abc import ABCMeta, abstractmethod
import signal
import re
import weakref
import six
from six.moves import range
import tqdm
......@@ -22,10 +23,9 @@ __all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
pass
@six.add_metaclass(ABCMeta)
class Trainer(object):
""" Base class for a trainer."""
__metaclass__ = ABCMeta
"""a `StatHolder` instance"""
stat_holder = None
......
......@@ -6,6 +6,7 @@
import tensorflow as tf
import threading
from abc import ABCMeta, abstractmethod
import six
from ..dataflow.common import RepeatedData
from ..tfutils.summary import add_moving_summary
......@@ -14,8 +15,8 @@ from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']
@six.add_metaclass(ABCMeta)
class InputData(object):
__metaclass__ = ABCMeta
pass
class FeedInput(InputData):
......
......@@ -7,6 +7,7 @@ from . import logger
from .argtools import memoized
from abc import abstractmethod, ABCMeta
import numpy as np
import six
from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
......@@ -16,8 +17,8 @@ def log_once(s):
logger.warn(s)
# just a placeholder
@six.add_metaclass(ABCMeta)
class Discretizer(object):
__metaclass__ = ABCMeta
@abstractmethod
def get_nr_bin(self):
......