Commit d646972d authored by Yuxin Wu

better prefetch & periodic callback as wrapper

parent 9f1af4c8
 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
-# File: argscope_test.py
+# File: cifar10_convnet.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf
@@ -116,8 +116,7 @@ def get_config():
     step_per_epoch = dataset_train.size()
     dataset_test = get_data('test')
 
-    sess_config = get_default_sess_config()
-    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
+    sess_config = get_default_sess_config(0.5)
 
     nr_gpu = get_nr_gpu()
     lr = tf.train.exponential_decay(
@@ -132,7 +131,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ClassificationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
...
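The two-line session setup above collapses into a single call; presumably get_default_sess_config now takes the GPU memory fraction as an argument. A minimal sketch of what such a helper could look like (the parameter name mem_fraction, its default, and the allow_soft_placement setting are assumptions, not the library's confirmed API):

    import tensorflow as tf

    def get_default_sess_config(mem_fraction=0.9):
        # Hypothetical sketch: fold the old two-line setup into one helper.
        conf = tf.ConfigProto()
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        conf.allow_soft_placement = True
        return conf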
@@ -166,7 +166,7 @@ def get_config():
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ClassificationError(dataset_test, prefix='test'),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
...
@@ -106,7 +106,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             #ValidationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
...
@@ -105,7 +105,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ValidationStatPrinter(dataset_test, ['cost:0']),
             ClassificationError(dataset_test, prefix='validation'),
         ]),
...
@@ -109,7 +109,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ClassificationError(test, prefix='test'),
         ]),
         session_config=sess_config,
...
@@ -81,18 +81,24 @@ class Callback(object):
 class PeriodicCallback(Callback):
     """
     A callback to be triggered after every `period` epochs.
-    Doesn't work for trigger_step
     """
-    def __init__(self, period):
+    def __init__(self, cb, period):
         """
+        :param cb: a `Callback`
         :param period: int
         """
+        self.cb = cb
         self.period = int(period)
 
+    def _before_train(self):
+        self.cb.before_train(self.trainer)
+
+    def _after_train(self):
+        self.cb.after_train()
+
     def _trigger_epoch(self):
+        self.cb.epoch_num = self.epoch_num - 1
         if self.epoch_num % self.period == 0:
-            self._trigger_periodic()
-
-    @abstractmethod
-    def _trigger_periodic(self):
-        pass
+            self.cb.trigger_epoch()
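This is the "periodic callback as wrapper" part of the commit: PeriodicCallback stops being an abstract base class with a _trigger_periodic hook and instead forwards before_train, after_train, and every period-th trigger_epoch to a wrapped Callback. A usage sketch under that reading, composing it with the ModelSaver introduced below (the surrounding Callbacks container is taken from the examples above):

    # Save the model only every 3 epochs by wrapping the callback,
    # instead of subclassing PeriodicCallback as before.
    callbacks = Callbacks([
        StatPrinter(),
        PeriodicCallback(ModelSaver(), period=3),
    ])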
@@ -6,22 +6,20 @@ import tensorflow as tf
 import os
 import re
 
-from .base import Callback, PeriodicCallback
+from .base import Callback
 from ..utils import *
 
-__all__ = ['PeriodicSaver']
+__all__ = ['ModelSaver']
 
-class PeriodicSaver(PeriodicCallback):
+class ModelSaver(Callback):
     """
     Save the model to logger directory.
     """
-    def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
+    def __init__(self, keep_recent=10, keep_freq=0.5):
         """
-        :param period: number of epochs to save models.
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
         """
-        super(PeriodicSaver, self).__init__(period)
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq
@@ -48,7 +46,7 @@ class PeriodicSaver(PeriodicCallback):
                 var_dict[name] = v
         return var_dict
 
-    def _trigger_periodic(self):
+    def _trigger_epoch(self):
         self.saver.save(
             tf.get_default_session(),
             self.path,
...
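With its period parameter gone, the renamed ModelSaver saves after every epoch through the plain _trigger_epoch hook; less frequent saving is now expressed by composition rather than inheritance. A short sketch of the two forms:

    ModelSaver(keep_recent=10, keep_freq=0.5)   # save every epoch (new default)
    PeriodicCallback(ModelSaver(), period=2)    # roughly the old PeriodicSaver(period=2)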
@@ -80,7 +80,7 @@ class TestCallbackContext(object):
             ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
             if ckpt is None:
                 raise RuntimeError(
-                    "Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?")
+                    "Cannot find a checkpoint state. Did you forget to use ModelSaver before any TestCallback?")
             logger.info(
                 "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
             self.saver.restore(self.sess, ckpt.model_checkpoint_path)
...
@@ -8,7 +8,7 @@ import os
 import operator
 import pickle
 
-from .base import Callback, PeriodicCallback
+from .base import Callback
 from ..utils import *
 
 __all__ = ['StatHolder', 'StatPrinter']
...
@@ -4,7 +4,7 @@
 import multiprocessing
 
-from .base import DataFlow
+from .base import ProxyDataFlow
 from ..utils.concurrency import ensure_procs_terminate
 
 __all__ = ['PrefetchData']
@@ -30,7 +30,7 @@ class PrefetchProcess(multiprocessing.Process):
         finally:
             self.queue.put(Sentinel())
 
-class PrefetchData(DataFlow):
+class PrefetchData(ProxyDataFlow):
     """
     Prefetch data from a `DataFlow` using multiprocessing
     """
@@ -40,35 +40,34 @@ class PrefetchData(ProxyDataFlow):
         :param nr_prefetch: size of the queue to hold prefetched datapoints.
         :param nr_proc: number of processes to use.
         """
-        self.ds = ds
-        self._size = self.ds.size()
+        super(PrefetchData, self).__init__(ds)
+        self._size = self.size()
         self.nr_proc = nr_proc
         self.nr_prefetch = nr_prefetch
-
-    def size(self):
-        return self._size
+        self.queue = multiprocessing.Queue(self.nr_prefetch)
+        self.procs = [PrefetchProcess(self.ds, self.queue)
+                      for _ in range(self.nr_proc)]
+        ensure_procs_terminate(self.procs)
+        for x in self.procs:
+            x.start()
 
     def get_data(self):
-        queue = multiprocessing.Queue(self.nr_prefetch)
-        procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
-        ensure_procs_terminate(procs)
-        [x.start() for x in procs]
-
         end_cnt = 0
         tot_cnt = 0
-        try:
-            while True:
-                dp = queue.get()
-                if isinstance(dp, Sentinel):
-                    end_cnt += 1
-                    if end_cnt == self.nr_proc:
-                        break
-                    continue
-                tot_cnt += 1
-                yield dp
-                if tot_cnt == self._size:
-                    break
-        finally:
-            queue.close()
-            [x.terminate() for x in procs]
+        while True:
+            dp = self.queue.get()
+            if isinstance(dp, Sentinel):
+                end_cnt += 1
+                if end_cnt == self.nr_proc:
+                    break
+                continue
+            tot_cnt += 1
+            yield dp
+            if tot_cnt == self._size:
+                break
+
+    def __del__(self):
+        self.queue.close()
+        for x in self.procs:
+            x.terminate()
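The queue and worker processes now live on the PrefetchData instance for its whole lifetime instead of being created and torn down inside every get_data() call, with __del__ doing the cleanup; dropping the local size() method also means size() is presumably inherited from ProxyDataFlow. A hedged usage sketch (ds0 and process() are stand-in names, not part of the library):

    # Two worker processes keep a queue of up to 256 datapoints filled;
    # iterating get_data() drains the shared queue epoch after epoch.
    ds = PrefetchData(ds0, nr_prefetch=256, nr_proc=2)
    for dp in ds.get_data():
        process(dp)   # placeholder for whatever consumes a datapoint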
@@ -61,7 +61,7 @@ class ScaleGradient(GradientProcessor):
         self.multipliers = multipliers
 
     def _process(self, grads):
-        # TODO use None for zero to speed up?
+        # TODO: would using None for zero gradients speed things up (or not)?
         ret = []
         for grad, var in grads:
             varname = var.op.name
...
@@ -76,7 +76,8 @@ class memoized(object):
         return functools.partial(self.__call__, obj)
 
 def get_rng(self):
-    seed = (id(self) + os.getpid()) % 4294967295
+    seed = (id(self) + os.getpid() +
+            int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
     return np.random.RandomState(seed)
 
 def get_nr_gpu():
...
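Mixing a microsecond timestamp into the seed presumably guards against distinct objects or runs colliding on the same id(self) + os.getpid() value, for example identically laid-out forked workers. A standalone sketch of the new scheme (make_rng is a hypothetical free-function form, not the library's name):

    import os
    from datetime import datetime
    import numpy as np

    def make_rng(obj):
        # object id + pid + current time, folded into the 32-bit seed range
        seed = (id(obj) + os.getpid() +
                int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
        return np.random.RandomState(seed)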