Commit a313662f authored by Yuxin Wu's avatar Yuxin Wu

add post process optimizer

parent 384d90a9
......@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
* 2017/02/12. `TrainConfig(optimizer=)` was deprecated. Now optimizer is set in `ModelDesc`. And
gradient processors become part of an optimizer. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/d1041a77a9c59d8c9abf64f389f3b605d65b483e).
* 2017/02/11. `_get_input_vars()` in `ModelDesc` was renamed to `_get_inputs`. `InputVar` was
renamed to `InputDesc`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/5b29bda9f17d7b587259e13963c4c8093e8387f8).
* 2017/01/27. `TrainConfig(step_per_epoch)` was renamed to `steps_per_epoch`. See [commit](https://github.com/ppwwyyxx/tensorpack/commit/a9dd0b8ec34209ab86a92875589dbbc4716e73ef).
......
......@@ -21,10 +21,9 @@ To use trainers, pass a `TrainConfig` to configure them:
```python
config = TrainConfig(
    dataflow=my_dataflow,
    callbacks=[...],
    model=MyModel()
)
# start training:
......
......@@ -76,18 +76,11 @@ class GANTrainer(FeedfreeTrainerBase):
with TowerContext(''):
actual_inputs = self._get_input_tensors()
self.model.build_graph(actual_inputs)
opt = self.model.get_optimizer()
# optimize G
grads = self.config.optimizer.compute_gradients(
self.model.g_loss, var_list=self.model.g_vars)
self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
# optimize D
self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
with tf.control_dependencies([self.g_min]):
grads = self.config.optimizer.compute_gradients(
self.model.d_loss, var_list=self.model.d_vars)
self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
self.train_op = self.d_min
def run_step(self):
......
......@@ -6,7 +6,8 @@
import tensorflow as tf
from .gradproc import apply_grad_processors as apply_gradproc
__all__ = ['apply_grad_processors', 'ProxyOptimizer']
__all__ = ['apply_grad_processors', 'ProxyOptimizer',
'PostProcessVariablesOptimizer']
class ProxyOptimizer(tf.train.Optimizer):
......@@ -49,3 +50,43 @@ def apply_grad_processors(opt, gradprocs):
g = apply_gradproc(grads_and_vars, self._gradprocs)
return self._opt.apply_gradients(g, global_step, name)
return _ApplyGradientProcessor(opt, gradprocs)
class PostProcessVariablesOptimizer(ProxyOptimizer):
    """
    An optimizer which applies an operation to variables (e.g. clipping,
    quantization) after the gradient update.
    """
    def __init__(self, opt, func, colocate=True):
        """
        Args:
            opt (tf.train.Optimizer): the optimizer to wrap.
            func (tf.Variable -> tf.Operation or None): the operation to
                perform on this variable after the gradient update.
                Return None to skip post-processing for that variable.
            colocate (bool): colocate the post-processing op with the variable.
        """
        super(PostProcessVariablesOptimizer, self).__init__(opt)
        self._func = func
        self._colocate = colocate

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients, then run ``func`` on each updated variable.

        Returns:
            tf.Operation: a grouped op that performs the gradient update
            followed by all post-processing ops.
        """
        update_op = super(PostProcessVariablesOptimizer, self).apply_gradients(
            grads_and_vars, global_step)
        ops = []
        # Post-processing must observe the updated variable values, hence
        # the control dependency on the gradient-update op.
        with tf.control_dependencies([update_op]):
            for _, var in grads_and_vars:
                op = self._run_func(var)
                # BUGFIX: the None check must come before the type assert --
                # func is documented to return None for skipped variables,
                # and the original code asserted isinstance() first.
                if op is not None:
                    assert isinstance(op, tf.Operation), op
                    ops.append(op)
        # ``name`` labels the final grouped op, matching the usual
        # apply_gradients() contract.
        update_op = tf.group(update_op, *ops, name=name)
        return update_op

    def _run_func(self, var):
        # Run the post-processing function, optionally colocated with the
        # variable. (BUGFIX: the original `_maybe_colocate` was a plain
        # generator used as a context manager without @contextmanager,
        # which raises at graph-construction time; this helper uses
        # colocate_with() -- itself a context manager -- directly.)
        if self._colocate:
            with tf.get_default_graph().colocate_with(var):
                return self._func(var)
        return self._func(var)
......@@ -109,8 +109,8 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue(tf.QueueBase): an input queue. Defaults to the
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue (tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
"""
config.data = QueueInput(config.dataflow, input_queue)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment