Commit 1c3d8741 authored by Yuxin Wu's avatar Yuxin Wu

minsaver , jsonstat

parent 7207816d
......@@ -35,7 +35,7 @@ class Model(ModelDesc):
def _get_cost(self, input_vars, is_training):
image, label = input_vars
image = image / 128.0 - 1
image = image / 128.0
def inception(name, x, nr1x1, nr3x3r, nr3x3, nr233r, nr233, nrpool, pooltype):
stride = 2 if nr1x1 == 0 else 1
......@@ -46,7 +46,6 @@ class Model(ModelDesc):
x2 = Conv2D('conv3x3r', x, nr3x3r, 1)
outs.append(Conv2D('conv3x3', x2, nr3x3, 3, stride=stride))
x3 = Conv2D('conv233r', x, nr233r, 1)
x3 = Conv2D('conv233a', x3, nr233, 3)
outs.append(Conv2D('conv233b', x3, nr233, 3, stride=stride))
......@@ -133,6 +132,8 @@ def get_data(train_or_test):
if isTrain:
augmentors = [
imgaug.Resize((256, 256)),
imgaug.Brightness(30, False),
imgaug.Contrast((0.8,1.2), True),
imgaug.MapImage(lambda x: x - pp_mean),
imgaug.RandomCrop((224, 224)),
imgaug.Flip(horiz=True),
......@@ -172,9 +173,9 @@ def get_config():
ClassificationError('wrong-top5', 'val-top5-error')]),
#HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt')
ScheduledHyperParamSetter('learning_rate',
[(8, 0.03), (13, 0.02), (21, 5e-3),
(28, 3e-3), (33, 1e-3), (44, 5e-4),
(49, 1e-4), (59, 2e-5)])
[(8, 0.03), (13, 0.02), (16, 5e-3),
(18, 3e-3), (24, 1e-3), (26, 2e-4),
(28, 5e-5) ])
]),
session_config=sess_config,
model=Model(),
......
......@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import os
import os, shutil
import re
from .base import Callback
......@@ -56,6 +56,24 @@ class ModelSaver(Callback):
class MinSaver(Callback):
def __init__(self, monitor_stat):
self.monitor_stat = monitor_stat
self.min = None
def _get_stat(self):
return self.trainer.stat_holder.get_stat_now(self.monitor_stat)
def _trigger_epoch(self):
pass
if self.min is None or self._get_stat() < self.min:
self.min = self._get_stat()
self._save()
def _save(self):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver before MinSaver?")
path = chpt.model_checkpoint_path
newname = os.path.join(logger.LOG_DIR, 'min_' + self.monitor_stat)
shutil.copy(path, newname)
logger.info("Model with minimum {} saved.".format(self.monitor_stat))
......@@ -6,7 +6,7 @@ import tensorflow as tf
import re
import os
import operator
import pickle
import json
from .base import Callback
from ..utils import *
......@@ -25,11 +25,11 @@ class StatHolder(object):
self.stat_now = {}
self.log_dir = log_dir
self.filename = os.path.join(log_dir, 'stat.pkl')
self.filename = os.path.join(log_dir, 'stat.json')
if os.path.isfile(self.filename):
logger.info("Loading stats from {}...".format(self.filename))
with open(self.filename) as f:
self.stat_history = pickle.load(f)
self.stat_history = json.load(f)
else:
self.stat_history = []
......@@ -47,6 +47,12 @@ class StatHolder(object):
"""
self.print_tag = None if print_tag is None else set(print_tag)
def get_stat_now(self, k):
"""
Return the value of a stat in the current epoch.
"""
return self.stat_now[k]
def finalize(self):
"""
Called after finishing adding stats. Will print and write stats to disk.
......@@ -64,7 +70,7 @@ class StatHolder(object):
def _write_stat(self):
tmp_filename = self.filename + '.tmp'
with open(tmp_filename, 'wb') as f:
pickle.dump(self.stat_history, f)
json.dump(self.stat_history, f)
os.rename(tmp_filename, self.filename)
class StatPrinter(Callback):
......
......@@ -19,6 +19,13 @@ __all__ = ['Trainer']
class Trainer(object):
"""
Base class for a trainer.
Available Attritbutes:
stat_holder: a `StatHolder` instance
summary_writer: a `tf.SummaryWriter`
config: a `TrainConfig`
model: a `ModelDesc`
global_step: a `int`
"""
__metaclass__ = ABCMeta
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment