Commit a313662f authored by Yuxin Wu

add post process optimizer

parent 384d90a9
@@ -8,6 +8,8 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changes API and those are not listed here.
+* 2017/02/12. `TrainConfig(optimizer=)` was deprecated. Now the optimizer is set in `ModelDesc`, and
+  gradient processors become part of an optimizer. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/d1041a77a9c59d8c9abf64f389f3b605d65b483e). A usage sketch follows below.
 * 2017/02/11. `_get_input_vars()` in `ModelDesc` was renamed to `_get_inputs`. `InputVar` was
   renamed to `InputDesc`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/5b29bda9f17d7b587259e13963c4c8093e8387f8).
 * 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/a9dd0b8ec34209ab86a92875589dbbc4716e73ef).
...
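As a usage sketch for the 2017/02/12 entry above: after this change the optimizer is defined on the `ModelDesc` rather than passed to `TrainConfig`, and gradient processors wrap the optimizer. This is a minimal sketch only; the override hook name `_get_optimizer` and the import paths for `apply_grad_processors` and `GlobalNormClip` are assumptions based on this commit, not text from the changelog itself.

```python
# Minimal sketch of the post-change style (hook name and import paths assumed).
import tensorflow as tf
from tensorpack import ModelDesc
from tensorpack.tfutils.optimizer import apply_grad_processors
from tensorpack.tfutils.gradproc import GlobalNormClip

class MyModel(ModelDesc):
    # other ModelDesc hooks (_get_inputs, _build_graph) omitted for brevity
    def _get_optimizer(self):
        opt = tf.train.AdamOptimizer(1e-3)
        # gradient processors are now attached to the optimizer itself
        return apply_grad_processors(opt, [GlobalNormClip(5)])
```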
@@ -21,10 +21,9 @@ To use trainers, pass a `TrainConfig` to configure them:
 ```python
 config = TrainConfig(
+    model=MyModel(),
     dataflow=my_dataflow,
-    optimizer=tf.train.AdamOptimizer(0.01),
     callbacks=[...]
-    model=MyModel()
 )
 # start training:
...
@@ -76,18 +76,11 @@ class GANTrainer(FeedfreeTrainerBase):
         with TowerContext(''):
             actual_inputs = self._get_input_tensors()
             self.model.build_graph(actual_inputs)
-            # optimize G
-            grads = self.config.optimizer.compute_gradients(
-                self.model.g_loss, var_list=self.model.g_vars)
-            self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
-            # optimize D
+            opt = self.model.get_optimizer()
+            self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
             with tf.control_dependencies([self.g_min]):
-                grads = self.config.optimizer.compute_gradients(
-                    self.model.d_loss, var_list=self.model.d_vars)
-                self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
+                self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
             self.train_op = self.d_min

     def run_step(self):
...
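The rewrite above relies on a standard fact about the TF1 `tf.train.Optimizer` API: `minimize` is shorthand for `compute_gradients` followed by `apply_gradients`, so collapsing the two calls does not change behavior when nothing happens between them. A self-contained illustration (plain TensorFlow, not tensorpack code):

```python
import tensorflow as tf

x = tf.Variable(1.0)
loss = tf.square(x)
opt = tf.train.GradientDescentOptimizer(0.1)

# One call ...
min_op = opt.minimize(loss, var_list=[x], name='min_op')

# ... is equivalent to the two-step form the old GANTrainer spelled out:
gvs = opt.compute_gradients(loss, var_list=[x])
two_step_op = opt.apply_gradients(gvs, name='two_step_op')
```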
@@ -6,7 +6,8 @@
 import tensorflow as tf
 from .gradproc import apply_grad_processors as apply_gradproc

-__all__ = ['apply_grad_processors', 'ProxyOptimizer']
+__all__ = ['apply_grad_processors', 'ProxyOptimizer',
+           'PostProcessVariablesOptimizer']


 class ProxyOptimizer(tf.train.Optimizer):

@@ -49,3 +50,43 @@ def apply_grad_processors(opt, gradprocs):
             g = apply_gradproc(grads_and_vars, self._gradprocs)
             return self._opt.apply_gradients(g, global_step, name)
     return _ApplyGradientProcessor(opt, gradprocs)
+
+
+class PostProcessVariablesOptimizer(ProxyOptimizer):
+    """
+    An optimizer which applies an operation to variables (e.g. clipping,
+    quantization) after each gradient update.
+    """
+    def __init__(self, opt, func, colocate=True):
+        """
+        Args:
+            opt (tf.train.Optimizer): the optimizer to wrap.
+            func (tf.Variable -> tf.Operation or None): the operation to
+                perform on this variable after the gradient update.
+            colocate (boolean): colocate the function with the variable.
+        """
+        super(PostProcessVariablesOptimizer, self).__init__(opt)
+        self._func = func
+        self._colocate = colocate
+
+    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+        update_op = super(PostProcessVariablesOptimizer, self).apply_gradients(
+            grads_and_vars, global_step)
+        ops = []
+        with tf.control_dependencies([update_op]):
+            for _, var in grads_and_vars:
+                with self._maybe_colocate(var):
+                    op = self._func(var)
+                    # `func` may return None to skip a variable, so check
+                    # for None before asserting the type.
+                    if op is not None:
+                        assert isinstance(op, tf.Operation), op
+                        ops.append(op)
+        update_op = tf.group(update_op, *ops, name=name)
+        return update_op
+
+    @contextmanager  # assumes `from contextlib import contextmanager` at module top
+    def _maybe_colocate(self, var):
+        G = tf.get_default_graph()
+        if self._colocate:
+            with G.colocate_with(var):
+                yield
+        else:
+            yield
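A hypothetical usage sketch for the new class: clip selected variables back into [-1, 1] after every optimizer step, a common trick for quantized networks or weight-clipped GAN discriminators. The function name, the variable filter, and the clipping range are illustrative choices, not part of this commit:

```python
import tensorflow as tf

W = tf.get_variable('W', shape=[3, 3])
loss = tf.reduce_sum(tf.square(W))

def clip_weights(v):
    # Returning None skips post-processing for this variable.
    if 'W' not in v.op.name:
        return None
    # tf.assign returns a Tensor; its .op is the tf.Operation that
    # apply_gradients above asserts on.
    return tf.assign(v, tf.clip_by_value(v, -1.0, 1.0)).op

opt = PostProcessVariablesOptimizer(tf.train.AdamOptimizer(1e-3), clip_weights)
train_op = opt.minimize(loss)  # each run clips W after the Adam update
```

Because `minimize` is inherited and routes through the overridden `apply_gradients`, the clipping ops run under a control dependency on the weight update, i.e. strictly after it.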
@@ -109,8 +109,8 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
     It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.

     Args:
-        config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
-        input_queue(tf.QueueBase): an input queue. Defaults to the
+        config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
+        input_queue (tf.QueueBase): an input queue. Defaults to the
             :class:`QueueInput` default.
     """
     config.data = QueueInput(config.dataflow, input_queue)
...
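Per the docstring above, `QueueInputTrainer(config)` behaves like `SimpleFeedfreeTrainer` once `config.data` is set to a `QueueInput`. A brief sketch, assuming trainers of this era expose `.train()` and that both names are importable from the tensorpack top-level package; `MyModel` and `my_dataflow` are placeholders:

```python
# Placeholder names; imports assumed from the tensorpack top-level package.
from tensorpack import TrainConfig, QueueInputTrainer

config = TrainConfig(model=MyModel(), dataflow=my_dataflow, callbacks=[])
QueueInputTrainer(config).train()  # equivalent to SimpleFeedfreeTrainer(config)
```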