Commit 3498c0b6 authored by Yuxin Wu

clean-up ModelDesc

parent 490142d7
@@ -9,8 +9,6 @@ import tensorflow as tf
 import six
 from ..utils.argtools import memoized
-from ..tfutils.tower import get_current_tower_context
-from ..tfutils.gradproc import FilterNoneGrad
 from .input_source_base import InputSource
 from ..models.regularize import regularize_cost_from_collection
@@ -115,9 +113,6 @@ class ModelDesc(ModelDescBase):
     def get_cost(self):
         """
         Return the cost tensor in the graph.
-        It will be called by :func:`get_cost_and_grad` by default.
-        You can ignore this method (or just use :class:`ModelDescBase`)
-        if you use your own trainer with more than one cost.
         It calls :meth:`ModelDesc._get_cost()` which by default returns
         ``self.cost``. You can override :meth:`_get_cost()` if needed.
@@ -138,9 +133,7 @@ class ModelDesc(ModelDescBase):
     @memoized
     def get_optimizer(self):
         """
-        Return the optimizer used in the task.
-        Used by some of the tensorpack :class:`Trainer` which assume single optimizer.
-        You should use :class:`ModelDescBase` if you use a custom trainer with more than one optimizers.
+        Return the memoized optimizer returned by `_get_optimizer`.

         Users of :class:`ModelDesc` will need to implement `_get_optimizer()`,
         which will only be called once per model.
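Note on the hunk above: because `get_optimizer` carries the `@memoized` decorator (imported from `..utils.argtools` in the first hunk), `_get_optimizer()` runs at most once per model instance and later calls return the cached optimizer. A minimal sketch of that contract, assuming the top-level `tensorpack` export of `ModelDesc` and TF 1.x; the Adam learning rate is illustrative:

    import tensorflow as tf
    from tensorpack import ModelDesc

    class MyModel(ModelDesc):
        def _get_optimizer(self):
            # Runs at most once; @memoized on get_optimizer() caches the result.
            return tf.train.AdamOptimizer(1e-3)

    model = MyModel()
    opt_a = model.get_optimizer()
    opt_b = model.get_optimizer()
    assert opt_a is opt_b   # same object: _get_optimizer was called only once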
@@ -153,28 +146,11 @@ class ModelDesc(ModelDescBase):
     def _get_optimizer(self):
         raise NotImplementedError()

-    def get_cost_and_grad(self):
+    def build_graph_get_cost(self, *inputs):
         """
-        Compute gradients with ``self.get_optimizer()`` on ``self.get_cost()``.
-        This method will be used by all the existing tensorpack trainers.
-
-        Returns:
-            cost (tf.Tensor): the cost tensor returned by ``self.get_cost()``.
-            grads (list[tuple]): list of (grad, variable) tuple.
+        Build the graph from inputs and return the cost tensor.
+        This is useful for most :class:`GraphBuilder` implementations, which
+        expect such a function.
         """
-        return self._get_cost_and_grad()
-
-    def _get_cost_and_grad(self):
-        ctx = get_current_tower_context()
-        assert ctx is not None and ctx.is_training, ctx
-        cost = self.get_cost()    # assume single cost
-        # produce gradients
-        varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
-        opt = self.get_optimizer()
-        grads = opt.compute_gradients(
-            cost, var_list=varlist,
-            gate_gradients=False, colocate_gradients_with_ops=True)
-        grads = FilterNoneGrad().process(grads)
-        return cost, grads
+        self.build_graph(inputs)
+        return self.get_cost()
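For context, `build_graph_get_cost` becomes the single per-tower entry point that builders call: it feeds a tower's input tensors to `build_graph` and returns the model's cost, while gradient computation moves out of `ModelDesc` entirely. A hedged sketch of a subclass wired through it, assuming the top-level `tensorpack` exports (`ModelDesc`, `InputDesc`) and TF 1.x; the linear model itself is illustrative, not tensorpack code:

    import tensorflow as tf
    from tensorpack import ModelDesc, InputDesc

    class LinearModel(ModelDesc):
        def _get_inputs(self):
            return [InputDesc(tf.float32, (None, 10), 'x'),
                    InputDesc(tf.float32, (None, 1), 'y')]

        def _build_graph(self, inputs):
            x, y = inputs
            pred = tf.layers.dense(x, 1)
            # get_cost() returns self.cost by default, so _build_graph
            # only needs to set this attribute.
            self.cost = tf.reduce_mean(tf.squared_difference(pred, y))

        def _get_optimizer(self):
            return tf.train.GradientDescentOptimizer(0.1)

    # Roughly what a GraphBuilder does per tower: call the bound method
    # with the tower's input tensors and get back the cost tensor.
    model = LinearModel()
    x = tf.placeholder(tf.float32, (None, 10))
    y = tf.placeholder(tf.float32, (None, 1))
    cost = model.build_graph_get_cost(x, y)   # builds the graph, returns self.cost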
@@ -277,6 +277,12 @@ class Trainer(object):
                 self.model, self.vs_name_for_predictor)
         return self._predictor_factory

+    @property
+    def vs_name_for_predictor(self):
+        # The variable scope (vs) name a predictor should be built under.
+        # For internal use only; should let the GraphBuilder return it.
+        return ""
+

 def launch_train(
         run_step, model=None, callbacks=None, extra_callbacks=None, monitors=None,
......
@@ -94,12 +94,8 @@ class DistributedTrainerReplicated(Trainer):
         cbs = self._input_source.setup(self.model.get_inputs_desc())
         self.config.callbacks.extend(cbs)

-        def get_cost(*inputs):
-            self.model.build_graph(inputs)
-            return self.model.get_cost()
-
         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
-            self._input_source, get_cost, self.model.get_optimizer)
+            self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         # initial local_vars syncing
         cb = RunOp(lambda: initial_sync_op,
@@ -148,3 +144,7 @@ class DistributedTrainerReplicated(Trainer):
                 return _create_session()

         self.config.session_creator = _Creator()
+
+    @property
+    def vs_name_for_predictor(self):
+        return "tower0"
@@ -67,13 +67,9 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
     def _setup(self):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())

-        def get_cost(*inputs):
-            self.model.build_graph(inputs)
-            return self.model.get_cost()
-
         self.train_op = SyncMultiGPUParameterServerBuilder(
             self.config.tower, self._ps_device).build(
-                self._input_source, get_cost, self.model.get_optimizer)
+                self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         self.config.callbacks.extend(callbacks)
@@ -103,12 +99,8 @@ class SyncMultiGPUTrainerReplicated(Trainer):
     def _setup(self):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())

-        def get_cost(*inputs):
-            self.model.build_graph(inputs)
-            return self.model.get_cost()
-
         self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self.config.tower).build(
-            self._input_source, get_cost, self.model.get_optimizer)
+            self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         cb = RunOp(
             lambda: post_init_op,
@@ -134,12 +126,8 @@ class AsyncMultiGPUTrainer(Trainer):
     def _setup(self):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())

-        def get_cost(*inputs):
-            self.model.build_graph(inputs)
-            return self.model.get_cost()
-
         self.train_op = AsyncMultiGPUBuilder(
             self.config.tower, self._scale_gradient).build(
-                self._input_source, get_cost, self.model.get_optimizer)
+                self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         self.config.callbacks.extend(callbacks)
@@ -42,11 +42,8 @@ class SimpleTrainer(Trainer):
     def _setup(self):
         cbs = self._input_source.setup(self.model.get_inputs_desc())

-        def get_cost(*inputs):
-            self.model.build_graph(inputs)
-            return self.model.get_cost()
-
-        self.train_op = SimpleBuilder().build(self._input_source, get_cost, self.model.get_optimizer)
+        self.train_op = SimpleBuilder().build(
+            self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         self.config.callbacks.extend(cbs)
......
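Taken together, the trainer hunks all make the same substitution: each per-trainer `get_cost` closure is deleted and the bound method `self.model.build_graph_get_cost` is passed to the builder instead. This works because a bound method is itself a callable taking `*inputs`. A self-contained toy sketch of that equivalence; the `Model` and `build` stand-ins below are illustrative, not tensorpack code:

    class Model(object):
        def build_graph(self, inputs):
            self.cost = sum(inputs)          # stand-in for graph construction

        def get_cost(self):
            return self.cost

        def build_graph_get_cost(self, *inputs):
            # Same body as the new ModelDesc method in this commit.
            self.build_graph(inputs)
            return self.get_cost()

    def build(get_cost_fn, *inputs):         # stand-in for GraphBuilder.build
        return get_cost_fn(*inputs)

    model = Model()

    # Old pattern: a closure duplicated in every trainer.
    def get_cost(*inputs):
        model.build_graph(inputs)
        return model.get_cost()

    # The bound method is a drop-in replacement for the closure.
    assert build(get_cost, 1, 2) == build(model.build_graph_get_cost, 1, 2)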