Commit e4d6992d authored by Yuxin Wu's avatar Yuxin Wu

use metaclass with six

parent fbc13fb4
...@@ -55,3 +55,4 @@ pip install --user -r opt-requirements.txt (some optional dependencies, you can ...@@ -55,3 +55,4 @@ pip install --user -r opt-requirements.txt (some optional dependencies, you can
``` ```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack` export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
``` ```
+ Use tcmalloc if running with large data
...@@ -14,7 +14,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po ...@@ -14,7 +14,7 @@ It's provided in the format of numpy dictionary, so it should be very easy to po
To use the script. You'll need: To use the script. You'll need:
+ TensorFlow >= 0.10 + TensorFlow >= 0.11
+ OpenCV bindings for Python + OpenCV bindings for Python
......
...@@ -20,6 +20,7 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses ...@@ -20,6 +20,7 @@ from GAN import GANTrainer, RandomZData, build_GAN_losses
To train: To train:
./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA} ./Image2Image.py --data /path/to/datadir --mode {AtoB,BtoA}
# datadir should contain images of shpae 2s x s, formed by A and B # datadir should contain images of shpae 2s x s, formed by A and B
# you can download some data from the original pix2pix repo: https://github.com/phillipi/pix2pix#datasets
# training visualization will appear be in tensorboard # training visualization will appear be in tensorboard
To visualize on test set: To visualize on test set:
...@@ -125,7 +126,7 @@ class Model(ModelDesc): ...@@ -125,7 +126,7 @@ class Model(ModelDesc):
def split_input(img): def split_input(img):
""" """
img: an 512x256x3 image img: an image with shape (s, 2s, 3)
:return: [input, output] :return: [input, output]
""" """
s = img.shape[0] s = img.shape[0]
...@@ -187,7 +188,7 @@ if __name__ == '__main__': ...@@ -187,7 +188,7 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--sample', action='store_true', help='run sampling') parser.add_argument('--sample', action='store_true', help='run sampling')
parser.add_argument('--data', help='A directory of 512x256 images') parser.add_argument('--data', help='Image directory')
parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB') parser.add_argument('--mode', choices=['AtoB', 'BtoA'], default='AtoB')
global args global args
args = parser.parse_args() args = parser.parse_args()
......
...@@ -6,15 +6,15 @@ ...@@ -6,15 +6,15 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict from collections import defaultdict
import six
import random import random
from ..utils import get_rng from ..utils import get_rng
__all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer', __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace'] 'DiscreteActionSpace']
@six.add_metaclass(ABCMeta)
class RLEnvironment(object): class RLEnvironment(object):
__meta__ = ABCMeta
def __init__(self): def __init__(self):
self.reset_stat() self.reset_stat()
......
...@@ -11,6 +11,7 @@ import weakref ...@@ -11,6 +11,7 @@ import weakref
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
import numpy as np import numpy as np
import six import six
from six.moves import queue from six.moves import queue
...@@ -43,8 +44,8 @@ class TransitionExperience(object): ...@@ -43,8 +44,8 @@ class TransitionExperience(object):
for k, v in six.iteritems(kwargs): for k, v in six.iteritems(kwargs):
setattr(self, k, v) setattr(self, k, v)
@six.add_metaclass(ABCMeta)
class SimulatorProcessBase(mp.Process): class SimulatorProcessBase(mp.Process):
__metaclass__ = ABCMeta
def __init__(self, idx): def __init__(self, idx):
super(SimulatorProcessBase, self).__init__() super(SimulatorProcessBase, self).__init__()
...@@ -62,8 +63,6 @@ class SimulatorProcessStateExchange(SimulatorProcessBase): ...@@ -62,8 +63,6 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
A process that simulates a player and communicates to master to A process that simulates a player and communicates to master to
send states and receive the next action send states and receive the next action
""" """
__metaclass__ = ABCMeta
def __init__(self, idx, pipe_c2s, pipe_s2c): def __init__(self, idx, pipe_c2s, pipe_s2c):
""" """
:param idx: idx of this process :param idx: idx of this process
...@@ -103,8 +102,6 @@ class SimulatorMaster(threading.Thread): ...@@ -103,8 +102,6 @@ class SimulatorMaster(threading.Thread):
It should produce action for each simulator, as well as It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished. defining callbacks when a transition or an episode is finished.
""" """
__metaclass__ = ABCMeta
class ClientState(object): class ClientState(object):
def __init__(self): def __init__(self):
self.memory = [] # list of Experience self.memory = [] # list of Experience
......
...@@ -7,12 +7,13 @@ import sys ...@@ -7,12 +7,13 @@ import sys
import os import os
import time import time
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import six
__all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback'] __all__ = ['Callback', 'PeriodicCallback', 'ProxyCallback']
@six.add_metaclass(ABCMeta)
class Callback(object): class Callback(object):
""" Base class for all callbacks """ """ Base class for all callbacks """
__metaclass__ = ABCMeta
def before_train(self): def before_train(self):
""" """
......
...@@ -20,8 +20,8 @@ from .dispatcher import OutputTensorDispatcer ...@@ -20,8 +20,8 @@ from .dispatcher import OutputTensorDispatcer
__all__ = ['InferenceRunner', 'ClassificationError', __all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer', 'BinaryClassificationStats'] 'ScalarStats', 'Inferencer', 'BinaryClassificationStats']
@six.add_metaclass(ABCMeta)
class Inferencer(object): class Inferencer(object):
__metaclass__ = ABCMeta
def before_inference(self): def before_inference(self):
""" """
......
...@@ -17,10 +17,9 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter', ...@@ -17,10 +17,9 @@ __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter', 'ScheduledHyperParamSetter',
'StatMonitorParamSetter', 'HyperParamSetterWithFunc', 'StatMonitorParamSetter', 'HyperParamSetterWithFunc',
'HyperParam', 'GraphVarParam', 'ObjAttrParam'] 'HyperParam', 'GraphVarParam', 'ObjAttrParam']
@six.add_metaclass(ABCMeta)
class HyperParam(object): class HyperParam(object):
""" Base class for a hyper param""" """ Base class for a hyper param"""
__metaclass__ = ABCMeta
def setup_graph(self): def setup_graph(self):
""" setup the graph in `setup_graph` callback stage, if necessary""" """ setup the graph in `setup_graph` callback stage, if necessary"""
...@@ -88,7 +87,6 @@ class HyperParamSetter(Callback): ...@@ -88,7 +87,6 @@ class HyperParamSetter(Callback):
""" """
Base class to set hyperparameters after every epoch. Base class to set hyperparameters after every epoch.
""" """
__metaclass__ = ABCMeta
def __init__(self, param): def __init__(self, param):
""" """
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import six
from ..utils import get_rng from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
@six.add_metaclass(ABCMeta)
class DataFlow(object): class DataFlow(object):
""" Base class for all DataFlow """ """ Base class for all DataFlow """
__metaclass__ = ABCMeta
class Infinity: class Infinity:
pass pass
......
...@@ -44,6 +44,10 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0): ...@@ -44,6 +44,10 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img) cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dump_dataflow_to_lmdb(ds, lmdb_path): def dump_dataflow_to_lmdb(ds, lmdb_path):
""" Dump a `Dataflow` ds to a lmdb database, where the key is the index
and the data is the serialized datapoint.
The output database can be read directly by `LMDBDataPoint`
"""
assert isinstance(ds, DataFlow), type(ds) assert isinstance(ds, DataFlow), type(ds)
isdir = os.path.isdir(lmdb_path) isdir = os.path.isdir(lmdb_path)
if isdir: if isdir:
......
...@@ -4,13 +4,14 @@ ...@@ -4,13 +4,14 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ...utils import get_rng from ...utils import get_rng
import six
from six.moves import zip from six.moves import zip
__all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList'] __all__ = ['Augmentor', 'ImageAugmentor', 'AugmentorList']
@six.add_metaclass(ABCMeta)
class Augmentor(object): class Augmentor(object):
""" Base class for an augmentor""" """ Base class for an augmentor"""
__metaclass__ = ABCMeta
def __init__(self): def __init__(self):
self.reset_state() self.reset_state()
......
...@@ -9,6 +9,7 @@ import tensorflow as tf ...@@ -9,6 +9,7 @@ import tensorflow as tf
from collections import namedtuple from collections import namedtuple
import inspect import inspect
import pickle import pickle
import six
from ..utils import logger, INPUT_VARS_KEY from ..utils import logger, INPUT_VARS_KEY
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
...@@ -30,9 +31,9 @@ class InputVar(object): ...@@ -30,9 +31,9 @@ class InputVar(object):
def loads(buf): def loads(buf):
return pickle.loads(buf) return pickle.loads(buf)
@six.add_metaclass(ABCMeta)
class ModelDesc(object): class ModelDesc(object):
""" Base class for a model description """ """ Base class for a model description """
__metaclass__ = ABCMeta
def get_input_vars(self): def get_input_vars(self):
""" """
......
...@@ -15,9 +15,8 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor', ...@@ -15,9 +15,8 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph', 'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor'] 'DataParallelOfflinePredictor']
@six.add_metaclass(ABCMeta)
class PredictorBase(object): class PredictorBase(object):
__metaclass__ = ABCMeta
""" """
Available attributes: Available attributes:
session session
......
...@@ -7,6 +7,7 @@ from six.moves import range, zip ...@@ -7,6 +7,7 @@ from six.moves import range, zip
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import multiprocessing import multiprocessing
import os import os
import six
from ..dataflow import DataFlow, BatchData from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue from ..dataflow.dftools import dataflow_to_process_queue
...@@ -21,9 +22,8 @@ from .base import OfflinePredictor ...@@ -21,9 +22,8 @@ from .base import OfflinePredictor
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor', __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor'] 'MultiProcessDatasetPredictor']
@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object): class DatasetPredictorBase(object):
__metaclass__ = ABCMeta
def __init__(self, config, dataset): def __init__(self, config, dataset):
""" """
:param config: a `PredictConfig` instance. :param config: a `PredictConfig` instance.
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import re import re
import six
import inspect import inspect
from ..utils import logger from ..utils import logger
from .symbolic_functions import rms from .symbolic_functions import rms
...@@ -31,8 +32,8 @@ def apply_grad_processors(grads, gradprocs): ...@@ -31,8 +32,8 @@ def apply_grad_processors(grads, gradprocs):
g = proc.process(g) g = proc.process(g)
return g return g
@six.add_metaclass(ABCMeta)
class GradientProcessor(object): class GradientProcessor(object):
__metaclass__ = ABCMeta
def process(self, grads): def process(self, grads):
""" """
......
...@@ -20,9 +20,9 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore', ...@@ -20,9 +20,9 @@ __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
# TODO they initialize_all at the beginning by default. # TODO they initialize_all at the beginning by default.
@six.add_metaclass(ABCMeta)
class SessionInit(object): class SessionInit(object):
""" Base class for utilities to initialize a session""" """ Base class for utilities to initialize a session"""
__metaclass__ = ABCMeta
def init(self, sess): def init(self, sess):
""" Initialize a session """ Initialize a session
......
...@@ -6,6 +6,7 @@ from abc import ABCMeta, abstractmethod ...@@ -6,6 +6,7 @@ from abc import ABCMeta, abstractmethod
import signal import signal
import re import re
import weakref import weakref
import six
from six.moves import range from six.moves import range
import tqdm import tqdm
...@@ -22,10 +23,9 @@ __all__ = ['Trainer', 'StopTraining'] ...@@ -22,10 +23,9 @@ __all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException): class StopTraining(BaseException):
pass pass
@six.add_metaclass(ABCMeta)
class Trainer(object): class Trainer(object):
""" Base class for a trainer.""" """ Base class for a trainer."""
__metaclass__ = ABCMeta
"""a `StatHolder` instance""" """a `StatHolder` instance"""
stat_holder = None stat_holder = None
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
import threading import threading
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import six
from ..dataflow.common import RepeatedData from ..dataflow.common import RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
...@@ -14,8 +15,8 @@ from ..callbacks.concurrency import StartProcOrThread ...@@ -14,8 +15,8 @@ from ..callbacks.concurrency import StartProcOrThread
__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput'] __all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']
@six.add_metaclass(ABCMeta)
class InputData(object): class InputData(object):
__metaclass__ = ABCMeta
pass pass
class FeedInput(InputData): class FeedInput(InputData):
......
...@@ -7,6 +7,7 @@ from . import logger ...@@ -7,6 +7,7 @@ from . import logger
from .argtools import memoized from .argtools import memoized
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
import six
from six.moves import range from six.moves import range
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND'] __all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
...@@ -16,8 +17,8 @@ def log_once(s): ...@@ -16,8 +17,8 @@ def log_once(s):
logger.warn(s) logger.warn(s)
# just a placeholder # just a placeholder
@six.add_metaclass(ABCMeta)
class Discretizer(object): class Discretizer(object):
__metaclass__ = ABCMeta
@abstractmethod @abstractmethod
def get_nr_bin(self): def get_nr_bin(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment