Commit 0cfc0889 authored by Yuxin Wu's avatar Yuxin Wu

Fix some old usage with v1 trainer (https://github.com/YixuanLi/densenet-tensorflow/pull/16)

parent 8b7b3f3c
...@@ -12,7 +12,7 @@ from ..callbacks import ( ...@@ -12,7 +12,7 @@ from ..callbacks import (
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger from ..utils import logger
from ..tfutils.sessinit import SessionInit, SaverRestore from ..tfutils.sessinit import SessionInit, SaverRestore, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource from ..input_source import InputSource
...@@ -60,7 +60,8 @@ class TrainConfig(object): ...@@ -60,7 +60,8 @@ class TrainConfig(object):
model=None, model=None,
callbacks=None, extra_callbacks=None, monitors=None, callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None, session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999): starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
**kwargs):
""" """
Args: Args:
dataflow (DataFlow): dataflow (DataFlow):
...@@ -146,6 +147,31 @@ class TrainConfig(object): ...@@ -146,6 +147,31 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch) self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch) self.max_epoch = int(max_epoch)
if 'nr_tower' in kwargs:
self.nr_tower = kwargs.pop('nr_tower')
if 'tower' in kwargs:
self.tower = kwargs.pop('tower')
assert len(kwargs) == 0, "Unknown arguments: {}".format(kwargs.keys())
@property
def nr_tower(self):
logger.warn("TrainConfig.nr_tower was deprecated! Set the number of GPUs on the trainer instead!")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
return len(self.tower)
@nr_tower.setter
def nr_tower(self, value):
logger.warn("TrainConfig.nr_tower was deprecated! Set the number of GPUs on the trainer instead!")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
self.tower = list(range(value))
def _deprecated_parsing(self):
self.callbacks = self.callbacks or []
self.extra_callbacks = DEFAULT_CALLBACKS() if self.extra_callbacks is None else self.extra_callbacks
self.callbacks.extend(self.extra_callbacks)
self.monitors = DEFAULT_MONITORS() if self.monitors is None else self.monitors
self.session_init = self.session_init or JustCurrentSession()
class AutoResumeTrainConfig(TrainConfig): class AutoResumeTrainConfig(TrainConfig):
""" """
......
...@@ -89,7 +89,11 @@ class TowerTrainer(Trainer): ...@@ -89,7 +89,11 @@ class TowerTrainer(Trainer):
try: try:
tower = self.tower_func.towers[tower_name] tower = self.tower_func.towers[tower_name]
assert tower is not None, "This is a bug!"
except KeyError: except KeyError:
tower = None
if tower is None:
input = PlaceholderInput() input = PlaceholderInput()
input.setup(self.inputs_desc) input.setup(self.inputs_desc)
......
...@@ -52,6 +52,7 @@ class Trainer(object): ...@@ -52,6 +52,7 @@ class Trainer(object):
config (TrainConfig): the train config. config (TrainConfig): the train config.
""" """
assert isinstance(config, TrainConfig), type(config) assert isinstance(config, TrainConfig), type(config)
config._deprecated_parsing()
self._config = config self._config = config
self.model = config.model self.model = config.model
if self.model is not None: if self.model is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment