Commit 4cb898a4 authored by Yuxin Wu's avatar Yuxin Wu

fix linting

parent 95ab1563
...@@ -9,7 +9,7 @@ import six ...@@ -9,7 +9,7 @@ import six
from six.moves import zip, range from six.moves import zip, range
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import logger, get_tqdm, PREDICT_TOWER, SUMMARY_BACKUP_KEYS from ..utils import logger, get_tqdm, SUMMARY_BACKUP_KEYS
from ..tfutils.common import get_op_tensor_name, freeze_collection from ..tfutils.common import get_op_tensor_name, freeze_collection
from ..tfutils import TowerContext from ..tfutils import TowerContext
from ..train.input_data import FeedfreeInput from ..train.input_data import FeedfreeInput
...@@ -189,7 +189,7 @@ class FeedfreeInferenceRunner(Callback): ...@@ -189,7 +189,7 @@ class FeedfreeInferenceRunner(Callback):
freeze_collection(SUMMARY_BACKUP_KEYS): freeze_collection(SUMMARY_BACKUP_KEYS):
def fn(_): def fn(_):
self.trainer.model.build_graph(self._input_tensors) self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0], prefix=self._prefix) # TODO use towerp1 to support multiple FeedfreeInferenceRunner build_prediction_graph(fn, [0], prefix=self._prefix)
self._tower_prefix = TowerContext.get_predict_tower_name(self._prefix, 0) self._tower_prefix = TowerContext.get_predict_tower_name(self._prefix, 0)
self._find_output_tensors() self._find_output_tensors()
......
...@@ -7,7 +7,6 @@ from abc import abstractmethod, ABCMeta ...@@ -7,7 +7,6 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf import tensorflow as tf
import six import six
from ..utils.naming import PREDICT_TOWER
from ..utils import logger from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment