Commit 6b6947fc authored by Yuxin Wu's avatar Yuxin Wu

gradproc support replicated

parent 99a7d749
...@@ -24,9 +24,9 @@ It's Yet Another TF wrapper, but different in: ...@@ -24,9 +24,9 @@ It's Yet Another TF wrapper, but different in:
Tensorpack helps you load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization. Tensorpack helps you load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization.
3. It's not a model wrapper. 3. It's not a model wrapper.
+ There are too many symbolic function wrappers. + There are too many symbolic function wrappers in the world.
Tensorpack includes only a few common models. Tensorpack includes only a few common models.
You can use any symbolic function library inside tensorpack, including tflayers/Keras/slim/tflearn/tensorlayer/.... But you can use any symbolic function library inside tensorpack, including tf.layers/Keras/slim/tflearn/tensorlayer/....
See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html) to know more about these features. See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html) to know more about these features.
......
...@@ -95,6 +95,8 @@ class OnlinePredictor(PredictorBase): ...@@ -95,6 +95,8 @@ class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session and given tensors. """ A predictor which directly use an existing session and given tensors.
""" """
ACCEPT_OPTIONS = False
def __init__(self, input_tensors, output_tensors, def __init__(self, input_tensors, output_tensors,
return_input=False, sess=None): return_input=False, sess=None):
""" """
...@@ -115,7 +117,8 @@ class OnlinePredictor(PredictorBase): ...@@ -115,7 +117,8 @@ class OnlinePredictor(PredictorBase):
if sess is not None: if sess is not None:
self._callable = sess.make_callable( self._callable = sess.make_callable(
fetches=output_tensors, fetches=output_tensors,
feed_list=input_tensors) feed_list=input_tensors,
accept_options=self.ACCEPT_OPTIONS)
else: else:
self._callable = None self._callable = None
else: else:
...@@ -131,8 +134,12 @@ class OnlinePredictor(PredictorBase): ...@@ -131,8 +134,12 @@ class OnlinePredictor(PredictorBase):
if self._callable is None: if self._callable is None:
self._callable = self.sess.make_callable( self._callable = self.sess.make_callable(
fetches=self.output_tensors, fetches=self.output_tensors,
feed_list=self.input_tensors) feed_list=self.input_tensors,
return self._callable(*dp) accept_options=self.ACCEPT_OPTIONS)
# run_metadata = tf.RunMetadata()
# options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
ret = self._callable(*dp)
return ret
def _do_call(self, dp): def _do_call(self, dp):
assert len(dp) == len(self.input_tensors), \ assert len(dp) == len(self.input_tensors), \
......
...@@ -152,18 +152,23 @@ class SummaryGradient(MapGradient): ...@@ -152,18 +152,23 @@ class SummaryGradient(MapGradient):
# TODO this is global. not good. # TODO this is global. not good.
_summaried_gradient = set() _summaried_gradient = set()
def __init__(self, regex='.*'): def __init__(self, regex='.*', collections=None):
""" """
Args: Args:
regex(str): same as in :class:`MapGradient`. regex(str): same as in :class:`MapGradient`.
collections (list[str]): list of collection names
""" """
super(SummaryGradient, self).__init__(self._mapper, regex) super(SummaryGradient, self).__init__(self._mapper, regex)
self._coll = collections
def _mapper(self, grad, var): def _mapper(self, grad, var):
name = var.op.name name = var.op.name
if re.match('tower[0-9]+/', name):
# replicated training, var may come from different towers
return grad
if name not in SummaryGradient._summaried_gradient: if name not in SummaryGradient._summaried_gradient:
SummaryGradient._summaried_gradient.add(name) SummaryGradient._summaried_gradient.add(name)
tf.summary.histogram(name + '-grad', grad) tf.summary.histogram(name + '-grad', grad, collections=self._coll)
add_moving_summary(rms(grad, name=name + '/rms')) add_moving_summary(rms(grad, name=name + '/rms'))
return grad return grad
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment