Commit 5deebdcb authored by Yuxin Wu

fix imports

parent 0b2d375d
@@ -33,7 +33,6 @@ class ExpReplay(DataFlow, Callback):
             player,
             batch_size=32,
             memory_size=1e6,
-            populate_size=None,  # deprecated
             init_memory_size=50000,
             exploration=1,
             end_exploration=0.1,
@@ -50,10 +49,6 @@ class ExpReplay(DataFlow, Callback):
         :param update_frequency: number of new transitions to add to memory
             after sampling a batch of transitions for training
         """
-        # XXX back-compat
-        if populate_size is not None:
-            logger.warn("populate_size in ExpReplay is deprecated in favor of init_memory_size")
-            init_memory_size = populate_size
         init_memory_size = int(init_memory_size)
         for k, v in locals().items():
...
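The deleted lines were a keyword-deprecation shim: warn when the old populate_size argument is passed and forward its value to init_memory_size. For reference, a minimal standalone sketch of that pattern, using the standard warnings module instead of the project's logger (make_replay and its arguments are hypothetical stand-ins, not ExpReplay itself):

import warnings

def make_replay(init_memory_size=50000, populate_size=None, **kwargs):
    # Back-compat shim (the pattern removed by this commit): accept the
    # old keyword, warn, and forward its value to the new keyword.
    if populate_size is not None:
        warnings.warn("populate_size is deprecated; use init_memory_size",
                      DeprecationWarning)
        init_memory_size = populate_size
    init_memory_size = int(init_memory_size)
    return dict(init_memory_size=init_memory_size, **kwargs)

print(make_replay(populate_size=1000))   # {'init_memory_size': 1000}

With the shim removed, callers must pass init_memory_size directly.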
@@ -28,7 +28,7 @@ class Callback(object):
         Called before finalizing the graph.
         Use this callback to setup some ops used in the callback.
-        :param trainer: a :class:`train.Trainer` instance
+        :param trainer: :class:`train.Trainer` instance
         """
         self.trainer = trainer
         self.graph = tf.get_default_graph()
...
@@ -13,7 +13,7 @@ from six.moves import zip, map
 from ..dataflow import DataFlow
 from ..utils import get_tqdm_kwargs, logger
 from ..utils.stat import RatioCounter, BinaryStatistics
-from ..tfutils import get_op_tensor_name
+from ..tfutils import get_op_tensor_name, get_op_var_name
 from .base import Callback
 __all__ = ['InferenceRunner', 'ClassificationError',
...
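The newly imported get_op_var_name sits next to get_op_tensor_name; both deal with TensorFlow's convention of naming a tensor "<op_name>:<output_index>". Below is a hypothetical re-implementation of that mapping for illustration only (assumed behavior, not the library's actual code):

import re

def get_op_tensor_name(name):
    # Hypothetical helper: TensorFlow names a tensor "<op_name>:<output_index>",
    # so map either form of the name to the pair (op_name, tensor_name).
    if re.search(r':\d+$', name):
        return name[:name.rfind(':')], name
    return name, name + ':0'

# get_op_var_name is assumed to apply the same convention to variable names.
print(get_op_tensor_name('fc0/W'))     # ('fc0/W', 'fc0/W:0')
print(get_op_tensor_name('fc0/W:0'))   # ('fc0/W', 'fc0/W:0')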
@@ -5,6 +5,7 @@
 from abc import ABCMeta, abstractmethod
 import signal
 import re
+import weakref
 from six.moves import range
 import tqdm
@@ -108,7 +109,7 @@ class Trainer(object):
         get_global_step_var()   # ensure there is such var, before finalizing the graph
         logger.info("Setup callbacks ...")
         callbacks = self.config.callbacks
-        callbacks.setup_graph(self)  # TODO use weakref instead?
+        callbacks.setup_graph(weakref.proxy(self))
         self._init_summary()
         logger.info("Initializing graph variables ...")
...
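Switching from setup_graph(self) to setup_graph(weakref.proxy(self)) means the callbacks no longer hold a strong reference to the trainer, so they cannot keep it alive on their own. A minimal standalone sketch of how such a proxy behaves (plain Python with hypothetical Trainer/Callback stand-ins, not tensorpack's actual classes):

import weakref

class Trainer(object):
    def __init__(self):
        self.global_step = 0        # hypothetical attribute, for illustration only

class Callback(object):
    def setup_graph(self, trainer):
        # Store the proxy: attribute access is forwarded to the real
        # trainer, but the callback holds no strong reference to it.
        self.trainer = trainer

trainer = Trainer()
cb = Callback()
cb.setup_graph(weakref.proxy(trainer))

print(cb.trainer.global_step)       # 0 -- the proxy behaves like the trainer
del trainer                         # on CPython the trainer is collected here
try:
    cb.trainer.global_step
except ReferenceError:
    print("trainer is gone; the proxy is dead")

While the real trainer is alive, attribute access passes through the proxy; once it is gone, the proxy raises ReferenceError instead of silently extending its lifetime.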