Commit 25e9853a authored by Yuxin Wu's avatar Yuxin Wu

misc fix

parent 2e238998
......@@ -4,7 +4,6 @@
Examples with __reproducible__ and meaningful performancce.
+ [An illustrative mnist example](mnist-convnet.py)
+ [A small Cifar10 ConvNet with 91% accuracy](cifar-convnet.py)
+ [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py)
+ [Reproduce some reinforcement learning papers](Atari2600)
+ [char-rnn for fun](char-rnn)
......
......@@ -81,9 +81,10 @@ because {} will be saved".format(v.name, var_dict[name].name))
logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback):
def __init__(self, monitor_stat, reverse=True):
def __init__(self, monitor_stat, reverse=True, filename=None):
self.monitor_stat = monitor_stat
self.reverse = reverse
self.filename = filename
self.min = None
def _get_stat(self):
......@@ -107,7 +108,8 @@ class MinSaver(Callback):
"Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = chpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR,
'max-' if self.reverse else 'min-' + self.monitor_stat)
self.filename or
('max-' if self.reverse else 'min-' + self.monitor_stat + '.tfmodel'))
shutil.copy(path, newname)
logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat))
......
......@@ -63,6 +63,8 @@ class TrainConfig(object):
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, nr_tower=None, tower=None):
logger.warn("config.set_tower is deprecated. set config.tower or config.nr_tower directly")
# this is a deprecated function
assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
if nr_tower:
tower = list(range(nr_tower))
......
......@@ -22,6 +22,7 @@ class MultiGPUTrainer(QueueInputTrainer):
""" Base class for multi-gpu training"""
def __init__(self, config, input_queue=None, predict_tower=None):
super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
self.dequed_inputs = []
@staticmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment