Commit d53b5c4c authored by Yuxin Wu's avatar Yuxin Wu

Create the old-style trainer when requested

parent 4e644290
......@@ -110,6 +110,7 @@ class ModelDescBase(object):
class ModelDesc(ModelDescBase):
"""
A ModelDesc with single cost and single optimizer.
It contains information about InputDesc, how to get cost, and how to get optimizer.
"""
def get_cost(self):
......
......@@ -21,6 +21,7 @@ from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
import tensorpack.train as old_train # noqa
from ..train.base import StopTraining, TrainLoop
__all__ = ['Trainer', 'SingleCostTrainer']
......@@ -185,6 +186,15 @@ class Trainer(object):
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
# create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs):
if isinstance(args[0], old_train.TrainConfig) or 'config' in kwargs:
name = cls.__name__
old_trainer = getattr(old_train, name)
return old_trainer(*args, **kwargs)
else:
return super(Trainer, cls).__new__(cls)
def _get_property(name):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment