Commit 6b6947fc authored by Yuxin Wu's avatar Yuxin Wu

gradproc support replicated

parent 99a7d749
......@@ -24,9 +24,9 @@ It's Yet Another TF wrapper, but different in:
Tensorpack helps you load large datasets (e.g. ImageNet) in __pure Python__ with autoparallelization.
3. It's not a model wrapper.
+ There are too many symbolic function wrappers.
+ There are too many symbolic function wrappers in the world.
Tensorpack includes only a few common models.
You can use any symbolic function library inside tensorpack, including tflayers/Keras/slim/tflearn/tensorlayer/....
But you can use any symbolic function library inside tensorpack, including tf.layers/Keras/slim/tflearn/tensorlayer/....
See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html) to know more about these features.
......
......@@ -95,6 +95,8 @@ class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session and given tensors.
"""
ACCEPT_OPTIONS = False
def __init__(self, input_tensors, output_tensors,
return_input=False, sess=None):
"""
......@@ -115,7 +117,8 @@ class OnlinePredictor(PredictorBase):
if sess is not None:
self._callable = sess.make_callable(
fetches=output_tensors,
feed_list=input_tensors)
feed_list=input_tensors,
accept_options=self.ACCEPT_OPTIONS)
else:
self._callable = None
else:
......@@ -131,8 +134,12 @@ class OnlinePredictor(PredictorBase):
if self._callable is None:
self._callable = self.sess.make_callable(
fetches=self.output_tensors,
feed_list=self.input_tensors)
return self._callable(*dp)
feed_list=self.input_tensors,
accept_options=self.ACCEPT_OPTIONS)
# run_metadata = tf.RunMetadata()
# options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
ret = self._callable(*dp)
return ret
def _do_call(self, dp):
assert len(dp) == len(self.input_tensors), \
......
......@@ -152,18 +152,23 @@ class SummaryGradient(MapGradient):
# TODO this is global. not good.
_summaried_gradient = set()
def __init__(self, regex='.*', collections=None):
    """
    Summarize (as histograms) the gradients of variables whose name
    matches ``regex``.

    Args:
        regex(str): same as in :class:`MapGradient`.
        collections (list[str]): list of collection names the histogram
            summaries are added to. ``None`` means the default summary
            collection.
    """
    # Defect fixed: the diff rendering kept both the old signature
    # (without ``collections``) and the new one, and dropped all
    # indentation. This is the post-commit version, properly formatted.
    super(SummaryGradient, self).__init__(self._mapper, regex)
    # Forwarded to tf.summary.histogram inside _mapper.
    self._coll = collections
def _mapper(self, grad, var):
    """
    Add a histogram summary and an RMS moving summary for ``grad``,
    at most once per variable name, then return ``grad`` unchanged.
    """
    # Defect fixed: the diff rendering kept both the old histogram call
    # and the new one (with ``collections``) and dropped all indentation.
    # This is the post-commit version, properly formatted.
    name = var.op.name
    if re.match('tower[0-9]+/', name):
        # replicated training, var may come from different towers;
        # skip to avoid duplicate summaries for the same variable.
        return grad
    if name not in SummaryGradient._summaried_gradient:
        # Class-level set guards against summarizing a variable twice.
        # NOTE(review): this set is global across instances (see the
        # TODO above the attribute) — confirm that is acceptable.
        SummaryGradient._summaried_gradient.add(name)
        tf.summary.histogram(name + '-grad', grad, collections=self._coll)
        add_moving_summary(rms(grad, name=name + '/rms'))
    return grad
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment