Commit 9fd9f1ed authored by Yuxin Wu's avatar Yuxin Wu

Fix missing `restore_collection` in tower. (fix #462)

parent a812979a
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: predictor_factory.py
# File: predict.py
import tensorflow as tf
from contextlib import contextmanager
......
......@@ -87,7 +87,8 @@ class FeedInput(InputSource):
return self.ds.size()
def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder(prefix='') for v in inputs]
# placeholders as input are always safe to reuse.
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)
def _get_input_tensors(self):
......
......@@ -150,6 +150,7 @@ class CollectionGuard(object):
self._name, ', '.join(
map(lambda t: "({}: {}->{})".format(*t),
size_change))))
restore_collection(self._freeze_backup)
def get_collection_in_tower(self, key):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment