Commit 53549d52 authored by Yuxin Wu's avatar Yuxin Wu

[WIP] start moving optimizer to model

parent 5b29bda9
...@@ -49,6 +49,7 @@ InputVar = InputDesc ...@@ -49,6 +49,7 @@ InputVar = InputDesc
class ModelDesc(object): class ModelDesc(object):
""" Base class for a model description """ """ Base class for a model description """
# inputs:
def get_reused_placehdrs(self): def get_reused_placehdrs(self):
""" """
Create or return (if already created) raw input TF placeholders in the graph. Create or return (if already created) raw input TF placeholders in the graph.
...@@ -97,13 +98,14 @@ class ModelDesc(object): ...@@ -97,13 +98,14 @@ class ModelDesc(object):
""" """
:returns: a list of InputDesc :returns: a list of InputDesc
""" """
# TODO deprecate @ Mar 11 # TODO deprecate @ Apr 11
logger.warn("[Deprecated] _get_input_vars() is renamed to _get_inputs()") logger.warn("[Deprecated] _get_input_vars() is renamed to _get_inputs()")
return self._get_input_vars() return self._get_input_vars()
def _get_input_vars(self): # keep backward compatibility def _get_input_vars(self): # keep backward compatibility
raise NotImplementedError() raise NotImplementedError()
# graph, cost, optimizer:
def build_graph(self, model_inputs): def build_graph(self, model_inputs):
""" """
Build the whole symbolic graph. Build the whole symbolic graph.
...@@ -153,6 +155,16 @@ class ModelDesc(object): ...@@ -153,6 +155,16 @@ class ModelDesc(object):
def _get_cost(self, *args): def _get_cost(self, *args):
return self.cost return self.cost
def get_optimizer(self):
    """
    Return the optimizer used to train this model.

    Delegates to :meth:`_get_optimizer`, which subclasses override
    to construct the actual optimizer.

    Returns:
        a :class:`tf.train.Optimizer` instance.
    """
    return self._get_optimizer()
def _get_optimizer(self):
    """
    Build and return the optimizer. Subclasses must override this;
    the base implementation always raises.

    Raises:
        NotImplementedError: always, in the base class.
    """
    raise NotImplementedError()
def get_gradient_processor(self): def get_gradient_processor(self):
""" Return a list of :class:`tensorpack.tfutils.GradientProcessor`. """ Return a list of :class:`tensorpack.tfutils.GradientProcessor`.
They will be executed by the trainer in the given order. They will be executed by the trainer in the given order.
......
...@@ -24,7 +24,7 @@ class TrainConfig(object): ...@@ -24,7 +24,7 @@ class TrainConfig(object):
""" """
def __init__(self, dataflow=None, data=None, def __init__(self, dataflow=None, data=None,
model=None, optimizer=None, model=None,
callbacks=None, extra_callbacks=None, callbacks=None, extra_callbacks=None,
session_config=get_default_sess_config(), session_config=get_default_sess_config(),
session_init=None, session_init=None,
...@@ -37,7 +37,6 @@ class TrainConfig(object): ...@@ -37,7 +37,6 @@ class TrainConfig(object):
data (InputData): an `InputData` instance. Only one of ``dataflow`` data (InputData): an `InputData` instance. Only one of ``dataflow``
or ``data`` has to be present. or ``data`` has to be present.
model (ModelDesc): the model to train. model (ModelDesc): the model to train.
optimizer (tf.train.Optimizer): the optimizer for trainig.
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are is only used to provide the defaults. The defaults are
...@@ -74,9 +73,6 @@ class TrainConfig(object): ...@@ -74,9 +73,6 @@ class TrainConfig(object):
assert_type(self.data, InputData) assert_type(self.data, InputData)
self.dataflow = None self.dataflow = None
self.optimizer = optimizer
assert_type(self.optimizer, tf.train.Optimizer)
if isinstance(callbacks, Callbacks): if isinstance(callbacks, Callbacks):
# keep quiet now because I haven't determined the final API yet. # keep quiet now because I haven't determined the final API yet.
logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!") logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!")
...@@ -133,10 +129,17 @@ class TrainConfig(object): ...@@ -133,10 +129,17 @@ class TrainConfig(object):
if isinstance(self.predict_tower, int): if isinstance(self.predict_tower, int):
self.predict_tower = [self.predict_tower] self.predict_tower = [self.predict_tower]
if 'optimizer' in kwargs:
self.optimizer = kwargs.pop('optimizer')
assert_type(self.optimizer, tf.train.Optimizer)
else:
self.optimizer = None
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None): def set_tower(self, nr_tower=None, tower=None):
# this is a deprecated function # this is a deprecated function
# TODO Deprecate @ Mar 15
logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly") logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!" assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
if nr_tower: if nr_tower:
......
...@@ -21,7 +21,6 @@ class FeedfreeTrainerBase(Trainer): ...@@ -21,7 +21,6 @@ class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster) """ A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`. Expect ``self.data`` to be a :class:`FeedfreeInput`.
""" """
def _trigger_epoch(self): def _trigger_epoch(self):
# run summary_op every epoch # run summary_op every epoch
# TODO FIXME summary_op will take a data! This is not good for TensorInput. # TODO FIXME summary_op will take a data! This is not good for TensorInput.
...@@ -45,7 +44,11 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -45,7 +44,11 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
self.model.build_graph(actual_inputs) self.model.build_graph(actual_inputs)
cost_var = self.model.get_cost() cost_var = self.model.get_cost()
# GATE_NONE faster? # GATE_NONE faster?
grads = self.config.optimizer.compute_gradients( opt = self.config.optimizer
if opt is None:
opt = self.model.get_optimizer() # XXX TODO not gonna work if optimizer modifies grad
self.config.optimizer = opt
grads = opt.compute_gradients(
cost_var, cost_var,
gate_gradients=tf.train.Optimizer.GATE_NONE, gate_gradients=tf.train.Optimizer.GATE_NONE,
colocate_gradients_with_ops=False) colocate_gradients_with_ops=False)
......
...@@ -84,11 +84,14 @@ class SimpleTrainer(Trainer): ...@@ -84,11 +84,14 @@ class SimpleTrainer(Trainer):
model.build_graph(self.input_vars) model.build_graph(self.input_vars)
cost_var = model.get_cost() cost_var = model.get_cost()
grads = self.config.optimizer.compute_gradients(cost_var) opt = self.config.optimizer
if not opt:
opt = model.get_optimizer()
grads = opt.compute_gradients(cost_var)
grads = apply_grad_processors(grads, grads = apply_grad_processors(grads,
self.model.get_gradient_processor()) self.model.get_gradient_processor())
self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op') self.train_op = opt.apply_gradients(grads, name='min_op')
def _trigger_epoch(self): def _trigger_epoch(self):
if self.summary_op is not None: if self.summary_op is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment