Commit 118c2a26 authored by Yuxin Wu

clean-up some deprecations

parent 01486c39
@@ -18,7 +18,7 @@ Reproduce the following GAN-related methods:
 + BEGAN ([BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717))
-Please see the __docstring__ in each script for detailed usage and pretrained models.
+Please see the __docstring__ in each script for detailed usage and pretrained models. MultiGPU training is supported.
 ## DCGAN.py
...
@@ -10,14 +10,11 @@ import six
 from ..utils import logger
 from ..utils.naming import INPUTS_KEY
-from ..utils.develop import deprecated, log_deprecated
 from ..utils.argtools import memoized
 from ..tfutils.model_utils import apply_slim_collections
 __all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph']
-# TODO "variable" is not the right name to use for input here.
 class InputDesc(object):
     """ Store metadata about input placeholders. """
@@ -50,7 +47,8 @@ class InputVar(InputDesc):
 @six.add_metaclass(ABCMeta)
 class ModelDesc(object):
-    """ Base class for a model description """
+    """ Base class for a model description.
+    """
     # inputs:
     @memoized
@@ -63,11 +61,6 @@ class ModelDesc(object):
         """
         return self.build_placeholders()
-    @deprecated("Use get_reused_placehdrs() instead.", "2017-04-11")
-    def get_input_vars(self):
-        # this wasn't a public API anyway
-        return self.get_reused_placehdrs()
     def build_placeholders(self, prefix=''):
         """
         For each InputDesc, create new placeholders with optional prefix and
@@ -76,12 +69,12 @@ class ModelDesc(object):
         Returns:
             list[tf.Tensor]: the list of built placeholders.
         """
-        input_vars = self._get_inputs()
-        for v in input_vars:
+        inputs = self._get_inputs()
+        for v in inputs:
             tf.add_to_collection(INPUTS_KEY, v.dumps())
         ret = []
         with tf.name_scope(None):   # clear any name scope it might get called in
-            for v in input_vars:
+            for v in inputs:
                 placehdr_f = tf.placeholder if not v.sparse else tf.sparse_placeholder
                 ret.append(placehdr_f(
                     v.type, shape=v.shape,
@@ -95,15 +88,11 @@ class ModelDesc(object):
         """
         return self._get_inputs()
-    def _get_inputs(self):  # this is a better name than _get_input_vars
+    @abstractmethod
+    def _get_inputs(self):
         """
         :returns: a list of InputDesc
         """
-        log_deprecated("", "_get_input_vars() was renamed to _get_inputs().", "2017-04-11")
-        return self._get_input_vars()
-    def _get_input_vars(self):  # keep backward compatibility
-        raise NotImplementedError()
     def build_graph(self, model_inputs):
         """
@@ -142,8 +131,8 @@ class ModelDesc(object):
     def get_optimizer(self):
         """
         Return the optimizer used in the task.
-        Used by some of the tensorpack :class:`Trainer` which only uses a single optimizer.
-        You can ignore this method if you use your own trainer with more than one optimizers.
+        Used by some of the tensorpack :class:`Trainer` which assume single optimizer.
+        You can (and should) ignore this method if you use a custom trainer with more than one optimizers.
         Users of :class:`ModelDesc` will need to implement `_get_optimizer()`,
         which will only be called once per each model.
@@ -157,6 +146,9 @@ class ModelDesc(object):
         raise NotImplementedError()
     def get_gradient_processor(self):
+        return self._get_gradient_processor()
+
+    def _get_gradient_processor(self):
         return []
...
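With `_get_inputs()` now an abstract method and the old `*_vars` aliases removed, a user's `ModelDesc` subclass only overrides the underscore methods. Below is a minimal illustrative sketch, not part of this commit; the class name `MyModel`, the input shapes, and the toy network are made up:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc


class MyModel(ModelDesc):
    """Hypothetical example model; shapes and layers are illustrative only."""

    def _get_inputs(self):
        # required override (now an @abstractmethod): return a list of InputDesc
        return [InputDesc(tf.float32, (None, 28, 28), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        logits = tf.layers.dense(tf.reshape(image, [-1, 28 * 28]), 10)
        # trainers read the cost back through ModelDesc.get_cost()
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
            name='cost')

    def _get_optimizer(self):
        # called once per model by trainers that assume a single optimizer
        return tf.train.AdamOptimizer(1e-3)
```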
@@ -2,8 +2,6 @@
 # File: config.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import tensorflow as tf
 from ..callbacks import (
     Callbacks, MovingAverageSummary,
     ProgressBar, MergeAllSummaries,
@@ -15,7 +13,6 @@ from ..utils.develop import log_deprecated
 from ..tfutils import (JustCurrentSession,
                        get_default_sess_config, SessionInit)
 from ..tfutils.sesscreate import NewSessionCreator
-from ..tfutils.optimizer import apply_grad_processors
 from .input_source import InputSource
 __all__ = ['TrainConfig']
@@ -154,15 +151,9 @@ class TrainConfig(object):
         assert len(set(self.predict_tower)) == len(self.predict_tower), \
             "Cannot have duplicated predict_tower!"
-        if 'optimizer' in kwargs:
-            log_deprecated("TrainConfig(optimizer=...)",
-                           "Use ModelDesc._get_optimizer() instead.",
-                           "2017-04-12")
-            self._optimizer = kwargs.pop('optimizer')
-            assert_type(self._optimizer, tf.train.Optimizer)
-        else:
-            self._optimizer = None
+        assert 'optimizer' not in kwargs, \
+            "TrainConfig(optimizer=...) was already deprecated! " \
+            "Use ModelDesc._get_optimizer() instead."
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
     @property
@@ -176,19 +167,3 @@ class TrainConfig(object):
     @property
     def callbacks(self):        # disable setter
         return self._callbacks
-    @property
-    def optimizer(self):
-        """ for back-compatibilty only. will remove in the future"""
-        if self._optimizer:
-            opt = self._optimizer
-        else:
-            opt = self.model.get_optimizer()
-        gradproc = self.model.get_gradient_processor()
-        if gradproc:
-            log_deprecated("ModelDesc.get_gradient_processor()",
-                           "Use gradient processor to build an optimizer instead.", "2017-04-12")
-            opt = apply_grad_processors(opt, gradproc)
-        if not self._optimizer:
-            self._optimizer = opt
-        return opt
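With the `TrainConfig.optimizer` back-compatibility property removed, trainers obtain the optimizer only through `model.get_optimizer()`, so any gradient processing has to be folded into the optimizer inside the model, e.g. via `apply_grad_processors`. A hedged migration sketch (not from this commit; the `GlobalNormClip(5)` processor is only an example choice):

```python
import tensorflow as tf
from tensorpack import ModelDesc
from tensorpack.tfutils.optimizer import apply_grad_processors
from tensorpack.tfutils.gradproc import GlobalNormClip

# Before (now rejected with an assertion):
#   TrainConfig(model=MyModel(), optimizer=tf.train.AdamOptimizer(1e-3), ...)
#   together with ModelDesc.get_gradient_processor() for gradient clipping.


class MyModel(ModelDesc):
    # _get_inputs() / _build_graph() omitted; see the sketch above.

    def _get_optimizer(self):
        opt = tf.train.AdamOptimizer(1e-3)
        # attach gradient processing to the optimizer itself instead of
        # relying on the removed optimizer/get_gradient_processor() fallback
        return apply_grad_processors(opt, [GlobalNormClip(5)])
```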
@@ -64,7 +64,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         """ get the cost and gradient"""
         self.build_train_tower()
         cost = self.model.get_cost()    # assume single cost
-        opt = self.config.optimizer  # TODO XXX
+        opt = self.model.get_optimizer()
         # GATE_NONE faster?
         grads = opt.compute_gradients(
             cost,
@@ -96,7 +96,8 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
         super(SimpleFeedfreeTrainer, self)._setup()
         with TowerContext('', is_training=True):
             cost, grads = self._get_cost_and_grad()
-        self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
+        opt = self.model.get_optimizer()
+        self.train_op = opt.apply_gradients(grads, name='min_op')
         # skip training
         # self.train_op = tf.group(*self._input_tensors)
...
@@ -364,8 +364,8 @@ class StagingInputWrapper(FeedfreeInput):
             devices: list of devices to be used for each training tower
             nr_stage: number of elements to prefetch
         """
+        assert isinstance(input, FeedfreeInput), input
         self._input = input
-        assert isinstance(input, FeedfreeInput)
         self._devices = devices
         self._nr_stage = nr_stage
         self._areas = []
...
...@@ -167,7 +167,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree ...@@ -167,7 +167,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase, SingleCostFeedfree
grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list) grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
# grads = grad_list[0] # grads = grad_list[0]
self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op') self.train_op = self.model.get_optimizer().apply_gradients(grads, name='min_op')
def SyncMultiGPUTrainer(config): def SyncMultiGPUTrainer(config):
...@@ -217,7 +217,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, ...@@ -217,7 +217,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
grad_list = [gradproc.process(gv) for gv in grad_list] grad_list = [gradproc.process(gv) for gv in grad_list]
# use grad from the first tower for iteration in main thread # use grad from the first tower for iteration in main thread
self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op') self._opt = self.model.get_optimizer()
self.train_op = self._opt.apply_gradients(grad_list[0], name='min_op')
self._start_async_threads(grad_list) self._start_async_threads(grad_list)
...@@ -227,7 +228,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase, ...@@ -227,7 +228,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
self.async_step_counter = itertools.count() self.async_step_counter = itertools.count()
self.training_threads = [] self.training_threads = []
for k in range(1, self.config.nr_tower): for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k]) train_op = self._opt.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding def f(op=train_op): # avoid late-binding
self.sess.run([op]) # TODO this won't work with StageInput self.sess.run([op]) # TODO this won't work with StageInput
......
@@ -42,5 +42,5 @@ class SimpleTrainer(Trainer):
             model.build_graph(self.inputs)
             cost_var = model.get_cost()
-        opt = self.config.optimizer
+        opt = self.model.get_optimizer()
         self.train_op = opt.minimize(cost_var, name='min_op')
@@ -4,7 +4,9 @@ flake8 .
 cd examples
 GIT_ARG="--git-dir ../.git --work-tree .."
-# find out modified python files
+# find out modified python files, so that we ignored unstaged files
 MOD=$(git $GIT_ARG status -s | grep -E '\.py$' | grep -E '^ *M|^ *A ' | cut -c 4-)
 # git $GIT_ARG status -s | grep -E '\.py$'
-flake8 $MOD
+if [[ -n $MOD ]]; then
+    flake8 $MOD
+fi