Commit 3ed43ab4 authored by Yuxin Wu

Fix use of InputDesc across graphs (fix #398)

parent 8591e253
@@ -15,10 +15,12 @@ There are several places where you might want to do something else:
 * After the training (e.g. send the model somewhere, send a message to your phone)
 We found people traditionally tend to write the training loop together with these extra features.
-This makes the loop lengthy, and the code for the same feature probably get separated.
+This makes the loop lengthy, and the code for the same feature probably get separated (imagine a
+feature which needs initialization in the beginning and then some actual work between iterations).
 By writing callbacks to implement what to do at each place, tensorpack trainers
 will call the callbacks at the proper time.
-Therefore the code can be reused with one single line, as long as you are using tensorpack trainers.
+Therefore these features can be reused with one single line, as long as you are using tensorpack trainers.
 For example, these are the callbacks I used when training a ResNet:
@@ -30,7 +32,7 @@ TrainConfig(
     ModelSaver(),
     # backup the model with best validation error
     MinSaver('val-error-top1'),
-    # run inference on another Dataflow every epoch, compute top1/top5 classification error and save them in log
+    # run inference on another Dataflow every epoch, compute classification error and log to monitors
     InferenceRunner(dataset_val, [
       ClassificationError('wrong-top1', 'val-error-top1'),
       ClassificationError('wrong-top5', 'val-error-top5')]),
@@ -50,11 +52,11 @@ TrainConfig(
     InjectShell(shell='ipython')
   ],
   extra_callbacks=[  # these callbacks are enabled by default already
-    # maintain and summarize moving average of some tensors defined in the model (e.g. training loss, training error)
+    # maintain those moving average summaries already defined in the model (e.g. training loss, training error)
     MovingAverageSummary(),
     # draw a nice progress bar
     ProgressBar(),
-    # run `tf.summary.merge_all` every epoch and send results to monitors
+    # run `tf.summary.merge_all` every epoch and log to monitors
     MergeAllSummaries(),
     # run ops in GraphKeys.UPDATE_OPS collection along with training, if any
     RunUpdateOps(),
...
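The doc change above describes reusing per-feature logic through callbacks instead of editing the training loop. Below is a minimal sketch of what such a reusable feature can look like, assuming tensorpack's Callback base class and its documented _before_train / _trigger_epoch / _after_train hooks; the class name and log messages are made up for illustration and are not part of this commit.

    from tensorpack.callbacks import Callback
    from tensorpack.utils import logger

    class NotifyAfterTraining(Callback):
        """Hypothetical callback: one-time setup, per-epoch work, and a final
        action after training, all kept together in one reusable class."""

        def _before_train(self):
            # one-time initialization before the first iteration
            logger.info("Training is about to start.")

        def _trigger_epoch(self):
            # the actual work done between epochs
            logger.info("Another epoch has finished.")

        def _after_train(self):
            # e.g. send the model somewhere, or send a message to your phone
            logger.info("Training is done, sending notification.")

Reusing the feature is then a single line: add NotifyAfterTraining() to the callbacks=[...] list in TrainConfig, as in the example above.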
@@ -25,8 +25,6 @@ class InputDesc(
        input source.
     """
-    _cached_placeholder = None
     def __new__(cls, type, shape, name):
         """
         Args:
@@ -36,6 +34,7 @@ class InputDesc(
         """
         shape = tuple(shape)  # has to be tuple for self to be hashable
         self = super(InputDesc, cls).__new__(cls, type, shape, name)
+        self._cached_placeholder = None
         return self
     # TODO in serialization, skip _cached_placeholder
@@ -72,10 +71,10 @@ class InputDesc(
             self.type, shape=self.shape,
             name=prefix + self.name)
         if prefix == '' and self._cached_placeholder is None:
-            self._cached_placeholder = ret
+            self._cached_placeholder = ret  # cached_placeholder only caches the prefix='' case
         return ret
-    @memoized
+    # cannot memoize here, because InputDesc is hashed by its fields.
     def build_placeholder_reuse(self):
         """
         Build a tf.placeholder from the metadata, or return an old one.
...
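The change above drops the @memoized decorator and the class-level cache in favor of a per-instance attribute set in __new__. A simplified sketch of the resulting pattern, not the actual tensorpack source, assuming a namedtuple-based class and TF1's tf.placeholder; it illustrates why field-based hashing makes memoization unsafe when the same InputDesc fields are used in more than one graph.

    import tensorflow as tf
    from collections import namedtuple

    class InputDescSketch(namedtuple('InputDescSketch', ['type', 'shape', 'name'])):
        """Simplified stand-in for InputDesc: compared and hashed by its fields."""

        def __new__(cls, type, shape, name):
            shape = tuple(shape)  # has to be a tuple for the instance to be hashable
            self = super(InputDescSketch, cls).__new__(cls, type, shape, name)
            # per-instance cache, initialized in __new__ rather than declared as
            # a class attribute (mirrors the change in the diff above)
            self._cached_placeholder = None
            return self

        def build_placeholder(self, prefix=''):
            ret = tf.placeholder(self.type, shape=self.shape, name=prefix + self.name)
            if prefix == '' and self._cached_placeholder is None:
                self._cached_placeholder = ret  # only the prefix='' placeholder is cached
            return ret

        def build_placeholder_reuse(self):
            # A memoized method would be keyed on the field values, because
            # namedtuples hash by content; an equal desc built for a second
            # tf.Graph would then receive a placeholder belonging to the first
            # graph. The per-instance cache avoids that sharing.
            if self._cached_placeholder is not None:
                return self._cached_placeholder
            return self.build_placeholder()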