Commit b9e2bd1b authored by Yuxin Wu

Merge branch 'better-a3c'

parents 13d4171f 313723df
@@ -14,12 +14,8 @@ import multiprocessing, threading
 from collections import deque
 from tensorpack import *
-from tensorpack.models import *
-from tensorpack.utils import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.callbacks import *
 from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
...
@@ -196,10 +196,9 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
-        config = get_config()
-        if args.load:
-            config.session_init = SaverRestore(args.load)
-        if args.gpu:
-            config.nr_tower = len(args.gpu.split(','))
-        QueueInputTrainer(config).train()
+    config = get_config()
+    if args.load:
+        config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    SyncMultiGPUTrainer(config).train()
@@ -8,14 +8,9 @@ import tensorflow as tf
 import argparse
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *

 """
 CIFAR10-resnet example.
@@ -146,6 +141,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()
@@ -179,18 +178,12 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
-            config = get_config()
-            if args.load:
-                config.session_init = SaverRestore(args.load)
-            if args.gpu:
-                config.nr_tower = len(args.gpu.split(','))
-            QueueInputTrainer(config).train()
+    config = get_config()
+    if args.load:
+        config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    SyncMultiGPUTrainer(config).train()
@@ -8,16 +8,9 @@ import argparse
 import numpy as np
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *
-from tensorpack.dataflow import imgaug

 """
 ResNet-110 for SVHN Digit Classification.
@@ -151,6 +144,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()
@@ -184,18 +181,12 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
-            config = get_config()
-            if args.load:
-                config.session_init = SaverRestore(args.load)
-            if args.gpu:
-                config.nr_tower = len(args.gpu.split(','))
-            QueueInputTrainer(config).train()
+    config = get_config()
+    if args.load:
+        config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    SyncMultiGPUTrainer(config).train()
@@ -2,4 +2,5 @@ termcolor
 pillow
 scipy
 tqdm
-dill
+msgpack
+msgpack-numpy
@@ -9,6 +9,7 @@ import threading
 import weakref
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict, namedtuple
+import numpy as np
 from six.moves import queue

 from ..utils.timer import *
@@ -42,7 +43,7 @@ class SimulatorProcess(multiprocessing.Process):
         context = zmq.Context()
         c2s_socket = context.socket(zmq.DEALER)
         c2s_socket.identity = 'simulator-{}'.format(self.idx)
-        #c2s_socket.set_hwm(2)
+        c2s_socket.set_hwm(2)
         c2s_socket.connect(self.c2s)
         s2c_socket = context.socket(zmq.DEALER)
@@ -59,7 +60,8 @@ class SimulatorProcess(multiprocessing.Process):
             action = loads(data)
             reward, isOver = player.action(action)
             c2s_socket.send(dumps((reward, isOver)), copy=False)
-            noop = s2c_socket.recv(copy=False)
+            #with total_timer('client recv_ack'):
+            ACK = s2c_socket.recv(copy=False)
             #cnt += 1
             #if cnt % 100 == 0:
                 #print_total_timer()
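Two changes here bound the client's memory: the high-water mark is now actually set, and the client blocks on an explicit ACK from the master after every transition, so a slow master back-pressures fast simulators instead of letting ZMQ buffer messages without limit. A minimal standalone sketch of this pattern (the endpoints and identity below are illustrative, not the ones the trainer uses):

import zmq

ctx = zmq.Context()
c2s = ctx.socket(zmq.DEALER)
c2s.identity = b'simulator-0'
c2s.set_hwm(2)                     # keep at most ~2 outgoing messages queued
c2s.connect('ipc:///tmp/sim-c2s')

s2c = ctx.socket(zmq.DEALER)
s2c.identity = b'simulator-0'
s2c.connect('ipc:///tmp/sim-s2c')

c2s.send(b'experience', copy=False)
ack = s2c.recv(copy=False)         # block until the master acknowledges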
@@ -102,6 +104,14 @@ class SimulatorMaster(threading.Thread):
         self.socket_lock = threading.Lock()
         self.daemon = True

+        # queueing messages to client
+        self.send_queue = queue.Queue(maxsize=100)
+        self.send_thread = LoopThread(lambda:
+                self.s2c_socket.send_multipart(self.send_queue.get()))
+        self.send_thread.daemon = True
+        self.send_thread.start()
+
+        # make sure sockets get closed at the end
         def clean_context(soks, context):
             for s in soks:
                 s.close()
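The reply path now goes through a bounded queue drained by a single LoopThread, so only one thread ever touches the s2c socket (ZMQ sockets are not thread-safe) and the master's receive loop never blocks on a send. A rough standard-library equivalent (names illustrative; s2c_socket is the master's reply socket from the hunk above):

import threading
from six.moves import queue

send_queue = queue.Queue(maxsize=100)   # producers block once 100 replies are pending

def drain(sock):
    while True:
        msg = send_queue.get()          # [ident, payload], as put by the recv loop
        sock.send_multipart(msg)

sender = threading.Thread(target=drain, args=(s2c_socket,))
sender.daemon = True
sender.start()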
@@ -113,7 +123,6 @@ class SimulatorMaster(threading.Thread):
         self.clients = defaultdict(SimulatorMaster.ClientState)
         while True:
             ident, msg = self.c2s_socket.recv_multipart()
-            #assert _ == ""
             client = self.clients[ident]
             client.protocol_state = 1 - client.protocol_state   # first flip the state
             if not client.protocol_state == 0:   # state-action
@@ -126,6 +135,7 @@ class SimulatorMaster(threading.Thread):
                 self._on_episode_over(ident)
             else:
                 self._on_datapoint(ident)
+            self.send_queue.put([ident, 'Thanks'])    # just an ACK

     @abstractmethod
     def _on_state(self, state, ident):
...
@@ -18,3 +18,4 @@ from .utils import *
 from .tfutils import *
 from .callbacks import *
 from .dataflow import *
+from .predict import *
@@ -36,7 +36,7 @@ class ModelSaver(Callback):
         var_dict = {}
         for v in vars:
             name = v.op.name
-            if re.match('tower[1-9]', name):
+            if re.match('tower[p1-9]', name):
                 #logger.info("Skip {} when saving model.".format(name))
                 continue
             if 'tower0/' in name:
...
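The widened pattern now also skips variables from the new prediction towers, so only tower0 and un-prefixed variables get saved. For instance (illustrative variable names):

import re

for name in ['tower0/conv1/W', 'tower1/conv1/W', 'towerp0/conv1/W']:
    print(name, bool(re.match('tower[p1-9]', name)))
# tower0/conv1/W  -> False (kept)
# tower1/conv1/W  -> True  (skipped: replica of tower0)
# towerp0/conv1/W -> True  (skipped: prediction tower)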
@@ -10,7 +10,7 @@ import uuid
 import os

 from .base import ProxyDataFlow
-from ..utils.concurrency import ensure_proc_terminate
+from ..utils.concurrency import *
 from ..utils.serialize import *
 from ..utils import logger
@@ -107,8 +107,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
         self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
                       for _ in range(self.nr_proc)]
-        for x in self.procs:
-            x.start()
+        start_proc_mask_signal(self.procs)

         # __del__ not guaranteed to get called at exit
         import atexit
...
@@ -5,7 +5,9 @@
 import tensorflow as tf
 from copy import copy
+import re

+from ..utils import logger
 from ._common import layer_register

 __all__ = ['BatchNorm']
@@ -48,9 +50,35 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     else:
         batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)

-    ema = tf.train.ExponentialMovingAverage(decay=decay)
-    ema_apply_op = ema.apply([batch_mean, batch_var])
-    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    emaname = 'EMA'
+    if not batch_mean.name.startswith('towerp'):
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    else:
+        # use training-statistics in prediction
+        assert not use_local_stat
+        # have to do this again to get actual name. see issue:
+        # https://github.com/tensorflow/tensorflow/issues/2740
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+        G = tf.get_default_graph()
+        try:
+            mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+            #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        except KeyError:
+            mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        #logger.info("In prediction, using {} instead of {} for {}".format(
+            #mean_name, ema_mean.name, batch_mean.name))

     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
@@ -58,6 +86,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
             x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
     else:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
+        # XXX TODO batch==1?
         mean, var = ema_mean, ema_var * batch / (batch - 1)   # unbiased variance estimator
         return tf.nn.batch_normalization(
             x, mean, var, beta, gamma, epsilon, 'bn')
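The prediction tower ('towerp') must reuse the EMA statistics created during training rather than making its own, so the re.sub calls rewrite the towerp-prefixed tensor name back into the training-time name: first assuming single-tower training (no prefix at all), then falling back to tower0. A toy illustration of the rewriting, with a hypothetical tensor name:

import re

name = 'towerp0/conv1/bn/mean/EMA:0'
print(re.sub('towerp[0-9]+/', '', name))         # 'conv1/bn/mean/EMA:0' (single-tower training)
print(re.sub('towerp[0-9]+/', 'tower0/', name))  # 'tower0/conv1/bn/mean/EMA:0' (multi-tower)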
@@ -81,43 +81,38 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
                 self.outqueue.put((tid, self.func(dp)))

 class PredictorWorkerThread(threading.Thread):
-    def __init__(self, queue, pred_func, id):
+    def __init__(self, queue, pred_func, id, batch_size=5):
         super(PredictorWorkerThread, self).__init__()
         self.queue = queue
         self.func = pred_func
         self.daemon = True
+        self.batch_size = batch_size
         self.id = id

     def run(self):
-        #self.xxx = None
         def fetch():
-            batched = []
-            futures = []
+            batched, futures = [], []
             inp, f = self.queue.get()
             batched.append(inp)
             futures.append(f)
-            #print "func queue:", self.queue.qsize()
-            #return batched, futures
+            if self.batch_size == 1:
+                return batched, futures
             while True:
                 try:
                     inp, f = self.queue.get_nowait()
                     batched.append(inp)
                     futures.append(f)
-                    if len(batched) == 5:
+                    if len(batched) == self.batch_size:
                         break
                 except queue.Empty:
                     break
             return batched, futures
         #self.xxx = None
         while True:
-            # normal input
-            #inputs, f = self.queue.get()
-            #outputs = self.func(inputs)
-            #f.set_result(outputs)
             batched, futures = fetch()
-            #print "batched size: ", len(batched)
+            #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
             outputs = self.func([batched])
-            # debug, for speed testing
             #if self.xxx is None:
                 #outputs = self.func([batched])
                 #self.xxx = outputs
@@ -134,13 +129,13 @@ class MultiThreadAsyncPredictor(object):
     An online predictor (use the current active session) that works with
     QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
     """
-    def __init__(self, trainer, input_names, output_names, nr_thread):
+    def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
         """
         :param trainer: a `QueueInputTrainer` instance.
         """
        self.input_queue = queue.Queue(maxsize=nr_thread*10)
         self.threads = [
-            PredictorWorkerThread(self.input_queue, f, id)
+            PredictorWorkerThread(self.input_queue, f, id, batch_size)
             for id, f in enumerate(
                 trainer.get_predict_funcs(
                     input_names, output_names, nr_thread))]
...
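batch_size now flows from MultiThreadAsyncPredictor down to each worker, which blocks for one request and then greedily drains whatever else is already queued, so a single session.run serves several callers at once. The batching logic in isolation (a sketch under those assumptions, not the tensorpack API itself):

from six.moves import queue

def fetch_batch(q, batch_size):
    """Wait for one item, then take more without blocking, up to batch_size."""
    batched = [q.get()]            # first item: block until available
    if batch_size == 1:
        return batched
    while len(batched) < batch_size:
        try:
            batched.append(q.get_nowait())
        except queue.Empty:
            break                  # serve a partial batch rather than wait
    return batched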
@@ -5,13 +5,19 @@
 from ..utils.naming import *
 import tensorflow as tf
+from copy import copy
+import six
+from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
            'get_global_step',
            'get_global_step_var',
            'get_op_var_name',
-           'get_vars_by_names'
-           ]
+           'get_vars_by_names',
+           'backup_collection',
+           'restore_collection',
+           'clear_collection',
+           'freeze_collection']

 def get_default_sess_config(mem_fraction=0.9):
     """
@@ -66,3 +72,24 @@ def get_vars_by_names(names):
         opn, varn = get_op_var_name(n)
         ret.append(G.get_tensor_by_name(varn))
     return ret
+
+def backup_collection(keys):
+    ret = {}
+    for k in keys:
+        ret[k] = copy(tf.get_collection(k))
+    return ret
+
+def restore_collection(backup):
+    for k, v in six.iteritems(backup):
+        del tf.get_collection_ref(k)[:]
+        tf.get_collection_ref(k).extend(v)
+
+def clear_collection(keys):
+    for k in keys:
+        del tf.get_collection_ref(k)[:]
+
+@contextmanager
+def freeze_collection(keys):
+    backup = backup_collection(keys)
+    yield
+    restore_collection(backup)
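Graph collections (summaries, moving-average variables) are global to the graph, so building an extra tower would append duplicate entries. freeze_collection snapshots the listed collections and restores them on exit; a usage sketch, assuming a fresh default graph:

import tensorflow as tf

tf.add_to_collection(tf.GraphKeys.SUMMARIES, tf.constant(1.0))
with freeze_collection([tf.GraphKeys.SUMMARIES]):
    # entries added here are discarded when the block exits
    tf.add_to_collection(tf.GraphKeys.SUMMARIES, tf.constant(2.0))
assert len(tf.get_collection(tf.GraphKeys.SUMMARIES)) == 1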
@@ -106,8 +106,10 @@ class SaverRestore(SessionInit):
         var_dict = defaultdict(list)
         for v in vars_to_restore:
             name = v.op.name
+            if 'towerp' in name:
+                logger.warn("Anything from prediction tower shouldn't be saved.")
             if 'tower' in name:
-                new_name = re.sub('tower[0-9]+/', '', name)
+                new_name = re.sub('tower[p0-9]+/', '', name)
                 name = new_name
             if name in vars_available:
                 var_dict[name].append(v)
...
@@ -90,7 +90,7 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
-        name = re.sub('tower[0-9]+/', '', c.op.name)
+        name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import *
 from ..tfutils.summary import create_summary
-from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']
@@ -89,7 +88,7 @@ class Trainer(object):
             summary = tf.Summary.FromString(summary_str)
             for val in summary.value:
                 if val.WhichOneof('value') == 'simple_value':
-                    val.tag = re.sub('tower[0-9]+/', '', val.tag)   # TODO move to subclasses
+                    val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
                     self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)
@@ -141,7 +140,6 @@ class Trainer(object):
             self.sess.close()

     def init_session_and_coord(self):
-        describe_model()
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from six.moves import zip, range

from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *

from .trainer import QueueInputTrainer

__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']

class MultiGPUTrainer(QueueInputTrainer):
    """ Base class for multi-gpu training"""
    def __init__(self, config, input_queue=None, predict_tower=None):
        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
        self.dequed_inputs = []

    @staticmethod
    def _average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            v = grad_and_vars[0][1]
            try:
                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
            except:
                logger.error("Error while processing gradients of {}".format(v.name))
                raise
            ret.append((grad, v))
        return ret

    def _multi_tower_grads(self):
        logger.info("Training a model of {} tower".format(self.config.nr_tower))

        grad_list = []
        for i in range(self.config.nr_tower):
            with tf.device('/gpu:{}'.format(i)), \
                    tf.name_scope('tower{}'.format(i)) as scope:
                logger.info("Building graph for training tower {}...".format(i))
                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                self.dequed_inputs.append(model_inputs)
                self.model.build_graph(model_inputs, True)
                cost_var = self.model.get_cost()    # build tower

                # TODO gate_gradients=0 seems to be faster?
                grad_list.append(
                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                if i == 0:
                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                    tf.get_variable_scope().reuse_variables()
                    # avoid repeated summary from each device
                    backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
        restore_collection(backup)
        return grad_list
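_average_grads expects one [(grad, var), ...] list per tower, aligned by variable, and returns the per-variable mean. A toy graph-mode check (expected values shown as comments):

import tensorflow as tf

v = tf.Variable(0.0)
tower_grads = [[(tf.constant(1.0), v)],    # grads from tower 0
               [(tf.constant(3.0), v)]]    # grads from tower 1
avg = MultiGPUTrainer._average_grads(tower_grads)
# avg == [(mean_grad, v)], where mean_grad evaluates to (1.0 + 3.0) / 2 = 2.0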
class SyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        grads = MultiGPUTrainer._average_grads(grad_list)
        grads = self.process_grads(grads)

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average())
        describe_model()

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
        self.main_loop()
class AsyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
        def scale(grads):
            return [(grad / self.config.nr_tower, var) for grad, var in grads]
        grad_list = map(scale, grad_list)
        grad_list = [self.process_grads(g) for g in grad_list]

        # use grad from the first tower for iteration in main thread
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
            summary_moving_average())
        describe_model()

        # prepare train_op for the rest of the towers
        self.training_threads = []
        for k in range(1, self.config.nr_tower):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])
            f = lambda op=train_op: self.sess.run([op])    # avoid late-binding
            th = LoopThread(f)
            th.pause()
            th.start()
            self.training_threads.append(th)
        self.async_running = False

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()
        self.main_loop()

    def run_step(self):
        if not self.async_running:
            self.async_running = True
            for th in self.training_threads:    # resume all threads
                th.resume()
        super(AsyncMultiGPUTrainer, self).run_step()

    def _trigger_epoch(self):
        self.async_running = False
        for th in self.training_threads:
            th.pause()
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
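The `lambda op=train_op:` default argument matters: a plain closure would capture the loop variable itself, so every LoopThread would end up running the last tower's op. A minimal demonstration of the pitfall:

fs_bad  = [lambda: k for k in range(3)]
fs_good = [lambda k=k: k for k in range(3)]   # bind the current value as a default
print([f() for f in fs_bad])     # [2, 2, 2] -- late binding bites
print([f() for f in fs_good])    # [0, 1, 2]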
@@ -5,20 +5,16 @@
 import tensorflow as tf
 import threading
 import time
-import copy
-import re
-import functools
 from six.moves import zip

 from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import *
-from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
-from ..tfutils.modelutils import describe_model
 from ..tfutils import *

-__all__ = ['SimpleTrainer', 'QueueInputTrainer',
-           'AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
+__all__ = ['SimpleTrainer', 'QueueInputTrainer']

 class SimpleTrainer(Trainer):
     def run_step(self):
@@ -42,6 +38,7 @@ class SimpleTrainer(Trainer):
             avg_maintain_op)

         self.init_session_and_coord()
+        describe_model()
         # create an infinite data producer
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()
@@ -76,6 +73,7 @@ class EnqueueThread(threading.Thread):
         self.queue = queue
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
+        self.size_op = self.queue.size()
         self.daemon = True

     def run(self):
@@ -86,6 +84,8 @@ class EnqueueThread(threading.Thread):
                 if self.coord.should_stop():
                     return
                 feed = dict(zip(self.input_vars, dp))
+                #_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
+                #print size
                 self.op.run(feed_dict=feed)
         except tf.errors.CancelledError as e:
             pass
@@ -97,16 +97,16 @@
         logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
-    """
-    Trainer which builds a FIFO queue for input.
-    Support multi GPU.
-    """
+    """ Single GPU Trainer, takes input from a queue"""
+
+    SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]

-    def __init__(self, config, input_queue=None, async=False):
+    def __init__(self, config, input_queue=None, predict_tower=None):
         """
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
             Defaults to a FIFO queue of size 100.
+        :param predict_tower: list of gpu idx to run prediction on; defaults to [0].
         """
         super(QueueInputTrainer, self).__init__(config)
         self.input_vars = self.model.get_input_vars()
@@ -115,23 +115,11 @@ class QueueInputTrainer(Trainer):
                 100, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
-        self.async = async
-        if self.async:
-            assert self.config.nr_tower > 1
-        self.dequed_inputs = []
+        if predict_tower is None:
+            # by default, use the first training gpu for prediction
+            predict_tower = [0]
+        self.predict_tower = predict_tower
+        self.dequed_inputs = None

-    @staticmethod
-    def _average_grads(tower_grads):
-        ret = []
-        for grad_and_vars in zip(*tower_grads):
-            v = grad_and_vars[0][1]
-            try:
-                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-            except AssertionError:
-                logger.error("Error while processing gradients of {}".format(v.name))
-                raise
-            ret.append((grad, v))
-        return ret

     def _get_model_inputs(self):
         """ Dequeue a datapoint from input_queue and return"""
@@ -141,104 +129,58 @@ class QueueInputTrainer(Trainer):
         assert len(ret) == len(self.input_vars)
         for qv, v in zip(ret, self.input_vars):
             qv.set_shape(v.get_shape())
-        self.dequed_inputs.append(ret)
         return ret

+    def _build_predict_tower(self):
+        inputs = self.model.get_input_vars()
+        tf.get_variable_scope().reuse_variables()
+        for k in self.predict_tower:
+            logger.info("Building graph for predict towerp{}...".format(k))
+            with tf.device('/gpu:{}'.format(k)), \
+                    tf.name_scope('towerp{}'.format(k)):
+                self.model.build_graph(inputs, False)

     def _single_tower_grad(self):
-        """ Get grad and cost for single-tower case"""
-        model_inputs = self._get_model_inputs()
+        """ Get grad and cost for single-tower"""
+        self.dequed_inputs = model_inputs = self._get_model_inputs()
         self.model.build_graph(model_inputs, True)
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(cost_var)
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
         return grads
-    def _multi_tower_grads(self):
-        logger.info("Training a model of {} tower".format(self.config.nr_tower))
-
-        # to avoid repeated summary from each device
-        collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
-        kept_summaries = {}
-        grad_list = []
-        for i in range(self.config.nr_tower):
-            with tf.device('/gpu:{}'.format(i)), \
-                    tf.name_scope('tower{}'.format(i)) as scope:
-                logger.info("Building graph for tower {}...".format(i))
-                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
-                self.model.build_graph(model_inputs, True)
-                cost_var = self.model.get_cost()    # build tower
-
-                # gate_gradients=0 seems to be faster?
-                grad_list.append(
-                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
-
-                if i == 0:
-                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
-                    tf.get_variable_scope().reuse_variables()
-                    for k in collect_dedup:
-                        kept_summaries[k] = copy.copy(tf.get_collection(k))
-        for k in collect_dedup:
-            del tf.get_collection_ref(k)[:]
-            tf.get_collection_ref(k).extend(kept_summaries[k])
-        return grad_list
+    def _build_enque_thread(self):
+        """ create a thread that keeps filling the queue """
+        enqueue_op = self.input_queue.enqueue(self.input_vars)
+        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
+        self.extra_threads_procs.append(self.input_th)
     def train(self):
-        enqueue_op = self.input_queue.enqueue(self.input_vars)
-
-        if self.config.nr_tower > 1:
-            grad_list = self._multi_tower_grads()
-            if not self.async:
-                grads = QueueInputTrainer._average_grads(grad_list)
-                grads = self.process_grads(grads)
-            else:
-                # pretend to average the grads, in order to make async and
-                # sync have consistent effective learning rate
-                def scale(grads):
-                    return [(grad / self.config.nr_tower, var) for grad, var in grads]
-                grad_list = map(scale, grad_list)
-                grad_list = [self.process_grads(g) for g in grad_list]
-                grads = grad_list[0]    # use grad from the first tower for the main iteration
-        else:
-            grads = self._single_tower_grad()
-            grads = self.process_grads(grads)
+        assert self.config.nr_tower == 1, \
+                "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
+        self.init_session_and_coord()
+        self._build_enque_thread()
+
+        grads = self._single_tower_grad()
+        grads = self.process_grads(grads)
+        describe_model()
+
+        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
+            self._build_predict_tower()

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())

-        if self.async:
-            # prepare train_op for the rest of the towers
-            self.threads = []
-            for k in range(1, self.config.nr_tower):
-                train_op = self.config.optimizer.apply_gradients(grad_list[k])
-                f = lambda op=train_op: self.sess.run([op])    # avoid late-binding
-                th = LoopThread(f)
-                th.pause()
-                th.start()
-                self.threads.append(th)
-            self.async_running = False
-
-        self.init_session_and_coord()
-        # create a thread that keeps filling the queue
-        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
-        self.extra_threads_procs.append(self.input_th)
         self.main_loop()
     def run_step(self):
-        if self.async:
-            if not self.async_running:
-                self.async_running = True
-                for th in self.threads:    # resume all threads
-                    th.resume()
-        self.sess.run([self.train_op])    # faster since train_op return None
+        """ just run self.train_op"""
+        self.sess.run([self.train_op])    # faster since train_op return None

     def _trigger_epoch(self):
-        # need to run summary_op every epoch
         # note that summary_op will take a data from the queue
-        if self.async:
-            self.async_running = False
-            for th in self.threads:
-                th.pause()
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)
@@ -246,33 +188,20 @@ class QueueInputTrainer(Trainer):
     def get_predict_func(self, input_names, output_names, tower=0):
         """
         :param tower: return the kth predict_func
+        :returns: a predictor function
         """
-        tower = tower % self.config.nr_tower
-        if self.config.nr_tower > 1:
-            logger.info("Prepare a predictor function for tower{} ...".format(tower))
+        tower = self.predict_tower[tower % len(self.predict_tower)]
         raw_input_vars = get_vars_by_names(input_names)
-        input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
-        dequed = self.dequed_inputs[tower]
-        input_vars = [dequed[k] for k in input_var_idxs]
-        if self.config.nr_tower > 1:
-            output_names = ['tower{}/'.format(tower) + n for n in output_names]
+        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)

         def func(inputs):
-            assert len(inputs) == len(input_vars)
-            feed = dict(zip(input_vars, inputs))
+            assert len(inputs) == len(raw_input_vars)
+            feed = dict(zip(raw_input_vars, inputs))
             return self.sess.run(output_vars, feed_dict=feed)
         return func

     def get_predict_funcs(self, input_names, output_names, n):
+        """ return n predict functions, spread evenly over the predict towers"""
         return [self.get_predict_func(input_names, output_names, k)
                 for k in range(n)]
-
-def AsyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config, async=True)
-
-def SyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config)
@@ -3,17 +3,17 @@
 # File: serialize.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-#import msgpack
-#import msgpack_numpy
-#msgpack_numpy.patch()
-import dill
+import msgpack
+import msgpack_numpy
+msgpack_numpy.patch()
+#import dill

 __all__ = ['loads', 'dumps']

 def dumps(obj):
-    return dill.dumps(obj)
-    #return msgpack.dumps(obj, use_bin_type=True)
+    #return dill.dumps(obj)
+    return msgpack.dumps(obj, use_bin_type=True)

 def loads(buf):
-    return dill.loads(buf)
-    #return msgpack.loads(buf)
+    #return dill.loads(buf)
+    return msgpack.loads(buf)
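dill can serialize nearly arbitrary Python objects but is slow; the simulator messages are just numpy arrays and scalars, which msgpack with the msgpack-numpy patch handles much faster. A round-trip sketch with an illustrative payload:

import numpy as np
import msgpack
import msgpack_numpy
msgpack_numpy.patch()    # teach msgpack to pack/unpack numpy arrays

frame = np.zeros((84, 84, 3), dtype=np.uint8)
buf = msgpack.dumps((frame, 1.0, False), use_bin_type=True)
frame2, reward, is_over = msgpack.loads(buf)     # tuples come back as lists
assert np.array_equal(frame, frame2)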
@@ -37,7 +37,7 @@ def print_total_timer():
     if len(_TOTAL_TIMER_DATA) == 0:
         return
     for k, v in six.iteritems(_TOTAL_TIMER_DATA):
-        logger.info("Total Time: {} -> {} sec, {} times".format(
-            k, v.sum, v.count))
+        logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
+            k, v.sum, v.count, v.average))

 atexit.register(print_total_timer)