Commit 1c3d8741 authored by Yuxin Wu's avatar Yuxin Wu

minsaver , jsonstat

parent 7207816d
...@@ -35,7 +35,7 @@ class Model(ModelDesc): ...@@ -35,7 +35,7 @@ class Model(ModelDesc):
def _get_cost(self, input_vars, is_training): def _get_cost(self, input_vars, is_training):
image, label = input_vars image, label = input_vars
image = image / 128.0 - 1 image = image / 128.0
def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype): def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype):
stride = 2 if nr1x1 == 0 else 1 stride = 2 if nr1x1 == 0 else 1
...@@ -46,7 +46,6 @@ class Model(ModelDesc): ...@@ -46,7 +46,6 @@ class Model(ModelDesc):
x2 = Conv2D('conv3x3r', x, nr3x3r, 1) x2 = Conv2D('conv3x3r', x, nr3x3r, 1)
outs.append(Conv2D('conv3x3', x2, nr3x3, 3, stride=stride)) outs.append(Conv2D('conv3x3', x2, nr3x3, 3, stride=stride))
x3 = Conv2D('conv233r', x, nr233r, 1) x3 = Conv2D('conv233r', x, nr233r, 1)
x3 = Conv2D('conv233a', x3, nr233, 3) x3 = Conv2D('conv233a', x3, nr233, 3)
outs.append(Conv2D('conv233b', x3, nr233, 3, stride=stride)) outs.append(Conv2D('conv233b', x3, nr233, 3, stride=stride))
...@@ -133,6 +132,8 @@ def get_data(train_or_test): ...@@ -133,6 +132,8 @@ def get_data(train_or_test):
if isTrain: if isTrain:
augmentors = [ augmentors = [
imgaug.Resize((256, 256)), imgaug.Resize((256, 256)),
imgaug.Brightness(30, False),
imgaug.Contrast((0.8,1.2), True),
imgaug.MapImage(lambda x: x - pp_mean), imgaug.MapImage(lambda x: x - pp_mean),
imgaug.RandomCrop((224, 224)), imgaug.RandomCrop((224, 224)),
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
...@@ -172,9 +173,9 @@ def get_config(): ...@@ -172,9 +173,9 @@ def get_config():
ClassificationError('wrong-top5', 'val-top5-error')]), ClassificationError('wrong-top5', 'val-top5-error')]),
#HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt') #HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt')
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(8, 0.03), (13, 0.02), (21, 5e-3), [(8, 0.03), (13, 0.02), (16, 5e-3),
(28, 3e-3), (33, 1e-3), (44, 5e-4), (18, 3e-3), (24, 1e-3), (26, 2e-4),
(49, 1e-4), (59, 2e-5)]) (28, 5e-5) ])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(), model=Model(),
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import os import os, shutil
import re import re
from .base import Callback from .base import Callback
...@@ -56,6 +56,24 @@ class ModelSaver(Callback): ...@@ -56,6 +56,24 @@ class ModelSaver(Callback):
class MinSaver(Callback): class MinSaver(Callback):
def __init__(self, monitor_stat): def __init__(self, monitor_stat):
self.monitor_stat = monitor_stat self.monitor_stat = monitor_stat
self.min = None
def _get_stat(self):
return self.trainer.stat_holder.get_stat_now(self.monitor_stat)
def _trigger_epoch(self): def _trigger_epoch(self):
pass if self.min is None or self._get_stat() < self.min:
self.min = self._get_stat()
self._save()
def _save(self):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver before MinSaver?")
path = chpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR, 'min_' + self.monitor_stat)
shutil.copy(path, newname)
logger.info("Model with minimum {} saved.".format(self.monitor_stat))
...@@ -6,7 +6,7 @@ import tensorflow as tf ...@@ -6,7 +6,7 @@ import tensorflow as tf
import re import re
import os import os
import operator import operator
import pickle import json
from .base import Callback from .base import Callback
from ..utils import * from ..utils import *
...@@ -25,11 +25,11 @@ class StatHolder(object): ...@@ -25,11 +25,11 @@ class StatHolder(object):
self.stat_now = {} self.stat_now = {}
self.log_dir = log_dir self.log_dir = log_dir
self.filename = os.path.join(log_dir, 'stat.pkl') self.filename = os.path.join(log_dir, 'stat.json')
if os.path.isfile(self.filename): if os.path.isfile(self.filename):
logger.info("Loading stats from {}...".format(self.filename)) logger.info("Loading stats from {}...".format(self.filename))
with open(self.filename) as f: with open(self.filename) as f:
self.stat_history = pickle.load(f) self.stat_history = json.load(f)
else: else:
self.stat_history = [] self.stat_history = []
...@@ -47,6 +47,12 @@ class StatHolder(object): ...@@ -47,6 +47,12 @@ class StatHolder(object):
""" """
self.print_tag = None if print_tag is None else set(print_tag) self.print_tag = None if print_tag is None else set(print_tag)
def get_stat_now(self, k):
"""
Return the value of a stat in the current epoch.
"""
return self.stat_now[k]
def finalize(self): def finalize(self):
""" """
Called after finishing adding stats. Will print and write stats to disk. Called after finishing adding stats. Will print and write stats to disk.
...@@ -64,7 +70,7 @@ class StatHolder(object): ...@@ -64,7 +70,7 @@ class StatHolder(object):
def _write_stat(self): def _write_stat(self):
tmp_filename = self.filename + '.tmp' tmp_filename = self.filename + '.tmp'
with open(tmp_filename, 'wb') as f: with open(tmp_filename, 'wb') as f:
pickle.dump(self.stat_history, f) json.dump(self.stat_history, f)
os.rename(tmp_filename, self.filename) os.rename(tmp_filename, self.filename)
class StatPrinter(Callback): class StatPrinter(Callback):
......
...@@ -19,6 +19,13 @@ __all__ = ['Trainer'] ...@@ -19,6 +19,13 @@ __all__ = ['Trainer']
class Trainer(object): class Trainer(object):
""" """
Base class for a trainer. Base class for a trainer.
Available Attritbutes:
stat_holder: a `StatHolder` instance
summary_writer: a `tf.SummaryWriter`
config: a `TrainConfig`
model: a `ModelDesc`
global_step: a `int`
""" """
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment