Commit 0cfc0889 authored by Yuxin Wu's avatar Yuxin Wu

Fix some old usage with v1 trainer (https://github.com/YixuanLi/densenet-tensorflow/pull/16)

parent 8b7b3f3c
......@@ -12,7 +12,7 @@ from ..callbacks import (
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils.sessinit import SessionInit, SaverRestore
from ..tfutils.sessinit import SessionInit, SaverRestore, JustCurrentSession
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
......@@ -60,7 +60,8 @@ class TrainConfig(object):
model=None,
callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999):
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
**kwargs):
"""
Args:
dataflow (DataFlow):
......@@ -146,6 +147,31 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
if 'nr_tower' in kwargs:
self.nr_tower = kwargs.pop('nr_tower')
if 'tower' in kwargs:
self.tower = kwargs.pop('tower')
assert len(kwargs) == 0, "Unknown arguments: {}".format(kwargs.keys())
@property
def nr_tower(self):
logger.warn("TrainConfig.nr_tower was deprecated! Set the number of GPUs on the trainer instead!")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
return len(self.tower)
@nr_tower.setter
def nr_tower(self, value):
logger.warn("TrainConfig.nr_tower was deprecated! Set the number of GPUs on the trainer instead!")
logger.warn("See https://github.com/ppwwyyxx/tensorpack/issues/458 for more information.")
self.tower = list(range(value))
def _deprecated_parsing(self):
self.callbacks = self.callbacks or []
self.extra_callbacks = DEFAULT_CALLBACKS() if self.extra_callbacks is None else self.extra_callbacks
self.callbacks.extend(self.extra_callbacks)
self.monitors = DEFAULT_MONITORS() if self.monitors is None else self.monitors
self.session_init = self.session_init or JustCurrentSession()
class AutoResumeTrainConfig(TrainConfig):
"""
......
......@@ -89,7 +89,11 @@ class TowerTrainer(Trainer):
try:
tower = self.tower_func.towers[tower_name]
assert tower is not None, "This is a bug!"
except KeyError:
tower = None
if tower is None:
input = PlaceholderInput()
input.setup(self.inputs_desc)
......
......@@ -52,6 +52,7 @@ class Trainer(object):
config (TrainConfig): the train config.
"""
assert isinstance(config, TrainConfig), type(config)
config._deprecated_parsing()
self._config = config
self.model = config.model
if self.model is not None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment