Commit 0bd1e92f authored by Yuxin Wu

better log for grad & model preparation

parent d04661e3
# tensorpack # tensorpack
Neural Network Toolbox on TensorFlow Neural Network Toolbox on TensorFlow
In development. No document. In development. No document. See [examples](https://github.com/ppwwyyxx/tensorpack/tree/master/examples).
## Features: ## Features:
+ Scoped abstraction of common models. + Scoped abstraction of common models.
+ Callbacks systems to control different aspects of training. + Use `Dataflow` to define data preprocessing in pure Python.
+ Use `Dataflow` to gain fine-grained control on data preprocessing. + Callbacks systems to control training.
+ Training and testing graph are modeled together. Just need to follow the conventions to setup stuffs. + Training and testing are described together. Just need to follow the conventions to setup stuffs.
+ Write summary easier for tensorboard. + Write summary easier for tensorboard.
...@@ -36,7 +36,7 @@ class ModelSaver(Callback): ...@@ -36,7 +36,7 @@ class ModelSaver(Callback):
for v in vars: for v in vars:
name = v.op.name name = v.op.name
if re.match('tower[1-9]', name): if re.match('tower[1-9]', name):
logger.info("Skip {} when saving model.".format(name)) #logger.info("Skip {} when saving model.".format(name))
continue continue
if 'tower0/' in name: if 'tower0/' in name:
new_name = name.replace('tower0/', '') new_name = name.replace('tower0/', '')
......
...@@ -50,7 +50,7 @@ class HyperParamSetter(Callback): ...@@ -50,7 +50,7 @@ class HyperParamSetter(Callback):
ret = self._get_current_value() ret = self._get_current_value()
if ret is not None and ret != self.last_value: if ret is not None and ret != self.last_value:
logger.info("{} at epoch {} will change to {}".format( logger.info("{} at epoch {} will change to {}".format(
self.op_name, self.epoch_num, ret)) self.op_name, self.epoch_num + 1, ret))
self.last_value = ret self.last_value = ret
return ret return ret
......
...@@ -36,8 +36,8 @@ class CenterCrop(ImageAugmentor): ...@@ -36,8 +36,8 @@ class CenterCrop(ImageAugmentor):
def _augment(self, img): def _augment(self, img):
orig_shape = img.arr.shape orig_shape = img.arr.shape
h0 = (orig_shape[0] - self.crop_shape[0]) * 0.5 h0 = int((orig_shape[0] - self.crop_shape[0]) * 0.5)
w0 = (orig_shape[1] - self.crop_shape[1]) * 0.5 w0 = int((orig_shape[1] - self.crop_shape[1]) * 0.5)
img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]] img.arr = img.arr[h0:h0+self.crop_shape[0],w0:w0+self.crop_shape[1]]
if img.coords: if img.coords:
raise NotImplementedError() raise NotImplementedError()
......
...@@ -88,8 +88,12 @@ class QueueInputTrainer(Trainer): ...@@ -88,8 +88,12 @@ class QueueInputTrainer(Trainer):
ret = [] ret = []
with tf.device('/gpu:0'): with tf.device('/gpu:0'):
for grad_and_vars in zip(*tower_grads): for grad_and_vars in zip(*tower_grads):
grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
v = grad_and_vars[0][1] v = grad_and_vars[0][1]
try:
grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
except AssertionError:
logger.error("Error while processing gradients of {}".format(v.name))
raise
ret.append((grad, v)) ret.append((grad, v))
return ret return ret
...@@ -129,6 +133,7 @@ class QueueInputTrainer(Trainer): ...@@ -129,6 +133,7 @@ class QueueInputTrainer(Trainer):
tf.get_variable_scope().reuse_variables() tf.get_variable_scope().reuse_variables()
for k in coll_keys: for k in coll_keys:
kept_summaries[k] = copy.copy(tf.get_collection(k)) kept_summaries[k] = copy.copy(tf.get_collection(k))
logger.info("Graph built for tower {}.".format(i))
for k in coll_keys: for k in coll_keys:
del tf.get_collection(k)[:] del tf.get_collection(k)[:]
tf.get_collection(k).extend(kept_summaries[k]) tf.get_collection(k).extend(kept_summaries[k])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment