Commit d646972d authored by Yuxin Wu's avatar Yuxin Wu

better prefetch & periodic callback as wrapper

parent 9f1af4c8
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: argscope_test.py
# File: cifar10_convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
......@@ -116,8 +116,7 @@ def get_config():
step_per_epoch = dataset_train.size()
dataset_test = get_data('test')
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
sess_config = get_default_sess_config(0.5)
nr_gpu = get_nr_gpu()
lr = tf.train.exponential_decay(
......@@ -132,7 +131,7 @@ def get_config():
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ModelSaver(),
ClassificationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
......
......@@ -166,7 +166,7 @@ def get_config():
optimizer=tf.train.MomentumOptimizer(lr, 0.9),
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ModelSaver(),
ClassificationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
......
......@@ -106,7 +106,7 @@ def get_config():
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ModelSaver(),
#ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
......
......@@ -105,7 +105,7 @@ def get_config():
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ModelSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
ClassificationError(dataset_test, prefix='validation'),
]),
......
......@@ -109,7 +109,7 @@ def get_config():
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ModelSaver(),
ClassificationError(test, prefix='test'),
]),
session_config=sess_config,
......
......@@ -81,18 +81,24 @@ class Callback(object):
class PeriodicCallback(Callback):
    """
    A wrapper callback that triggers the wrapped `Callback`
    after every `period` epochs.
    Doesn't work for trigger_step.
    """
    def __init__(self, cb, period):
        """
        :param cb: a `Callback` instance to trigger periodically.
        :param period: int, trigger `cb` every `period` epochs.
        """
        self.cb = cb
        self.period = int(period)

    def _before_train(self):
        # Forward setup to the wrapped callback so it sees the same trainer.
        self.cb.before_train(self.trainer)

    def _after_train(self):
        # Forward teardown to the wrapped callback.
        self.cb.after_train()

    def _trigger_epoch(self):
        # NOTE(review): the wrapped callback's epoch counter is set one
        # behind this wrapper's -- presumably to compensate for the base
        # class incrementing before triggering; confirm against `Callback`.
        self.cb.epoch_num = self.epoch_num - 1
        if self.epoch_num % self.period == 0:
            self.cb.trigger_epoch()
......@@ -6,22 +6,20 @@ import tensorflow as tf
import os
import re
from .base import Callback, PeriodicCallback
from .base import Callback
from ..utils import *
__all__ = ['PeriodicSaver']
__all__ = ['ModelSaver']
class PeriodicSaver(PeriodicCallback):
class ModelSaver(Callback):
"""
Save the model to logger directory.
"""
def __init__(self, keep_recent=10, keep_freq=0.5):
    """
    Save the model to the logger directory.

    :param keep_recent: see `tf.train.Saver` documentation.
    :param keep_freq: see `tf.train.Saver` documentation.
    """
    # Only store the Saver options here; the Saver itself is built later
    # (in a part of the class not shown in this hunk).
    self.keep_recent = keep_recent
    self.keep_freq = keep_freq
......@@ -48,7 +46,7 @@ class PeriodicSaver(PeriodicCallback):
var_dict[name] = v
return var_dict
def _trigger_periodic(self):
def _trigger_epoch(self):
self.saver.save(
tf.get_default_session(),
self.path,
......
......@@ -80,7 +80,7 @@ class TestCallbackContext(object):
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?")
"Cannot find a checkpoint state. Do you forget to use ModelSaver before all TestCallback?")
logger.info(
"Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
......
......@@ -8,7 +8,7 @@ import os
import operator
import pickle
from .base import Callback, PeriodicCallback
from .base import Callback
from ..utils import *
__all__ = ['StatHolder', 'StatPrinter']
......
......@@ -4,7 +4,7 @@
import multiprocessing
from .base import DataFlow
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_procs_terminate
__all__ = ['PrefetchData']
......@@ -30,7 +30,7 @@ class PrefetchProcess(multiprocessing.Process):
finally:
self.queue.put(Sentinel())
class PrefetchData(DataFlow):
class PrefetchData(ProxyDataFlow):
"""
Prefetch data from a `DataFlow` using multiprocessing
"""
......@@ -40,25 +40,22 @@ class PrefetchData(DataFlow):
:param nr_prefetch: size of the queue to hold prefetched datapoints.
:param nr_proc: number of processes to use.
"""
self.ds = ds
self._size = self.ds.size()
super(PrefetchData, self).__init__(ds)
self._size = self.size()
self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch
def size(self):
return self._size
self.queue = multiprocessing.Queue(self.nr_prefetch)
self.procs = [PrefetchProcess(self.ds, self.queue)
for _ in range(self.nr_proc)]
ensure_procs_terminate(self.procs)
for x in self.procs:
x.start()
def get_data(self):
queue = multiprocessing.Queue(self.nr_prefetch)
procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
ensure_procs_terminate(procs)
[x.start() for x in procs]
end_cnt = 0
tot_cnt = 0
try:
while True:
dp = queue.get()
dp = self.queue.get()
if isinstance(dp, Sentinel):
end_cnt += 1
if end_cnt == self.nr_proc:
......@@ -68,7 +65,9 @@ class PrefetchData(DataFlow):
yield dp
if tot_cnt == self._size:
break
finally:
queue.close()
[x.terminate() for x in procs]
def __del__(self):
    # Best-effort cleanup when this DataFlow is garbage-collected:
    # close the shared multiprocessing queue and forcibly terminate
    # the prefetch worker processes started in __init__.
    self.queue.close()
    for x in self.procs:
        x.terminate()
......@@ -61,7 +61,7 @@ class ScaleGradient(GradientProcessor):
self.multipliers = multipliers
def _process(self, grads):
# TODO use None for zero to speed up?
# TODO use None for zero can speed up (or not)?
ret = []
for grad, var in grads:
varname = var.op.name
......
......@@ -76,7 +76,8 @@ class memoized(object):
return functools.partial(self.__call__, obj)
def get_rng(self):
    """
    Return a `np.random.RandomState` whose seed mixes the object's id,
    the current process id, and the current time, so that different
    objects and worker processes get different random streams.
    """
    # The time component keeps seeds distinct across successive calls
    # even within the same object/process.
    seed = (id(self) + os.getpid() +
            int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
    return np.random.RandomState(seed)
def get_nr_gpu():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment