Commit 25e9853a authored by Yuxin Wu's avatar Yuxin Wu

misc fix

parent 2e238998
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
Examples with __reproducible__ and meaningful performancce. Examples with __reproducible__ and meaningful performancce.
+ [An illustrative mnist example](mnist-convnet.py) + [An illustrative mnist example](mnist-convnet.py)
+ [A small Cifar10 ConvNet with 91% accuracy](cifar-convnet.py)
+ [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py) + [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py)
+ [Reproduce some reinforcement learning papers](Atari2600) + [Reproduce some reinforcement learning papers](Atari2600)
+ [char-rnn for fun](char-rnn) + [char-rnn for fun](char-rnn)
......
...@@ -81,9 +81,10 @@ because {} will be saved".format(v.name, var_dict[name].name)) ...@@ -81,9 +81,10 @@ because {} will be saved".format(v.name, var_dict[name].name))
logger.exception("Exception in ModelSaver.trigger_epoch!") logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback): class MinSaver(Callback):
def __init__(self, monitor_stat, reverse=True): def __init__(self, monitor_stat, reverse=True, filename=None):
self.monitor_stat = monitor_stat self.monitor_stat = monitor_stat
self.reverse = reverse self.reverse = reverse
self.filename = filename
self.min = None self.min = None
def _get_stat(self): def _get_stat(self):
...@@ -107,7 +108,8 @@ class MinSaver(Callback): ...@@ -107,7 +108,8 @@ class MinSaver(Callback):
"Cannot find a checkpoint state. Do you forget to use ModelSaver?") "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = chpt.model_checkpoint_path path = chpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR, newname = os.path.join(logger.LOG_DIR,
'max-' if self.reverse else 'min-' + self.monitor_stat) self.filename or
('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
shutil.copy(path, newname) shutil.copy(path, newname)
logger.info("Model with {} '{}' saved.".format( logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat)) 'maximum' if self.reverse else 'minimum', self.monitor_stat))
......
...@@ -63,6 +63,8 @@ class TrainConfig(object): ...@@ -63,6 +63,8 @@ class TrainConfig(object):
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None): def set_tower(self, nr_tower=None, tower=None):
logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
# this is a deprecated function
assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!" assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
if nr_tower: if nr_tower:
tower = list(range(nr_tower)) tower = list(range(nr_tower))
......
...@@ -22,6 +22,7 @@ class MultiGPUTrainer(QueueInputTrainer): ...@@ -22,6 +22,7 @@ class MultiGPUTrainer(QueueInputTrainer):
""" Base class for multi-gpu training""" """ Base class for multi-gpu training"""
def __init__(self, config, input_queue=None, predict_tower=None): def __init__(self, config, input_queue=None, predict_tower=None):
super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower) super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
self.dequed_inputs = [] self.dequed_inputs = []
@staticmethod @staticmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment