Commit 9fd9f1ed authored by Yuxin Wu's avatar Yuxin Wu

Fix missing `restore_collection` in tower. (fix #462)

parent a812979a
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: predictor_factory.py # File: predict.py
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
......
...@@ -87,7 +87,8 @@ class FeedInput(InputSource): ...@@ -87,7 +87,8 @@ class FeedInput(InputSource):
return self.ds.size() return self.ds.size()
def _setup(self, inputs): def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder(prefix='') for v in inputs] # placeholders as input are always safe to reuse.
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs) self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)
def _get_input_tensors(self): def _get_input_tensors(self):
......
...@@ -150,6 +150,7 @@ class CollectionGuard(object): ...@@ -150,6 +150,7 @@ class CollectionGuard(object):
self._name, ', '.join( self._name, ', '.join(
map(lambda t: "({}: {}->{})".format(*t), map(lambda t: "({}: {}->{})".format(*t),
size_change)))) size_change))))
restore_collection(self._freeze_backup)
def get_collection_in_tower(self, key): def get_collection_in_tower(self, key):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment