Commit b9e2bd1b authored by Yuxin Wu's avatar Yuxin Wu

Merge branch 'better-a3c'

parents 13d4171f 313723df
......@@ -14,12 +14,8 @@ import multiprocessing, threading
from collections import deque
from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.concurrency import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.callbacks import *
from tensorpack.RL import *
import common
from common import play_model, Evaluator, eval_model_multithread
......
......@@ -196,10 +196,9 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with tf.Graph().as_default():
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train()
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train()
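
This and the two ResNet examples below make the same switch: QueueInputTrainer is replaced by SyncMultiGPUTrainer, with the tower count taken from the --gpu flag. A minimal sketch of that argument handling, assuming the same flag names as the examples (the argparse setup here is illustrative, not copied from any one script):

```python
import argparse, os

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma-separated list of GPUs, e.g. "0,1"')
args = parser.parse_args()

if args.gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    nr_tower = len(args.gpu.split(','))   # one training tower per visible GPU
else:
    nr_tower = 1                          # single-device fallback
```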
......@@ -8,14 +8,9 @@ import tensorflow as tf
import argparse
import os
from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
"""
CIFAR10-resnet example.
......@@ -146,6 +141,10 @@ def get_data(train_or_test):
return ds
def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset
dataset_train = get_data('train')
step_per_epoch = dataset_train.size()
......@@ -179,18 +178,12 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
args = parser.parse_args()
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with tf.Graph().as_default():
with tf.device('/cpu:0'):
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train()
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train()
......@@ -8,16 +8,9 @@ import argparse
import numpy as np
import os
from tensorpack.train import TrainConfig, QueueInputTrainer
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
ResNet-110 for SVHN Digit Classification.
......@@ -151,6 +144,10 @@ def get_data(train_or_test):
return ds
def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset
dataset_train = get_data('train')
step_per_epoch = dataset_train.size()
......@@ -184,18 +181,12 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
args = parser.parse_args()
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with tf.Graph().as_default():
with tf.device('/cpu:0'):
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train()
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config.nr_tower = len(args.gpu.split(','))
SyncMultiGPUTrainer(config).train()
......@@ -9,6 +9,7 @@ import threading
import weakref
from abc import abstractmethod, ABCMeta
from collections import defaultdict, namedtuple
import numpy as np
from six.moves import queue
from ..utils.timer import *
......@@ -42,7 +43,7 @@ class SimulatorProcess(multiprocessing.Process):
context = zmq.Context()
c2s_socket = context.socket(zmq.DEALER)
c2s_socket.identity = 'simulator-{}'.format(self.idx)
#c2s_socket.set_hwm(2)
c2s_socket.set_hwm(2)
c2s_socket.connect(self.c2s)
s2c_socket = context.socket(zmq.DEALER)
......@@ -59,7 +60,8 @@ class SimulatorProcess(multiprocessing.Process):
action = loads(data)
reward, isOver = player.action(action)
c2s_socket.send(dumps((reward, isOver)), copy=False)
noop = s2c_socket.recv(copy=False)
#with total_timer('client recv_ack'):
ACK = s2c_socket.recv(copy=False)
#cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
......@@ -102,6 +104,14 @@ class SimulatorMaster(threading.Thread):
self.socket_lock = threading.Lock()
self.daemon = True
# queueing messages to client
self.send_queue = queue.Queue(maxsize=100)
self.send_thread = LoopThread(lambda:
self.s2c_socket.send_multipart(self.send_queue.get()))
self.send_thread.daemon = True
self.send_thread.start()
# make sure sockets get closed at the end
def clean_context(soks, context):
for s in soks:
s.close()
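
Replies to clients no longer block the receive loop: they are pushed onto a bounded queue and a daemon LoopThread drains it onto the ZMQ socket. A self-contained sketch of the pattern, with a plain threading.Thread standing in for tensorpack's LoopThread:

```python
# Bounded send queue drained by one background thread; maxsize applies
# backpressure if clients stop reading. ZMQ sockets are not thread-safe,
# so funneling all sends through a single thread also avoids locking.
import threading
from six.moves import queue

send_queue = queue.Queue(maxsize=100)

def drain(sock):
    while True:
        sock.send_multipart(send_queue.get())   # [identity, payload]

# th = threading.Thread(target=drain, args=(s2c_socket,)); th.daemon = True
# th.start()
```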
......@@ -113,7 +123,6 @@ class SimulatorMaster(threading.Thread):
self.clients = defaultdict(SimulatorMaster.ClientState)
while True:
ident, msg = self.c2s_socket.recv_multipart()
#assert _ == ""
client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action
......@@ -126,6 +135,7 @@ class SimulatorMaster(threading.Thread):
self._on_episode_over(ident)
else:
self._on_datapoint(ident)
self.send_queue.put([ident, 'Thanks']) # just an ACK
@abstractmethod
def _on_state(self, state, ident):
......
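
For context on the ACK added above: each client alternates strictly between sending a game state (answered with an action) and a (reward, isOver) pair (now answered with 'Thanks'), and the single-bit protocol_state flip is the entire state machine. Roughly, with placeholder payloads:

```python
# Two-phase protocol per client, as tracked in run() above: the bit is
# flipped on every message, so odd messages are states and even ones are
# (reward, isOver) pairs.
protocol_state = 0

def on_message(msg):
    global protocol_state
    protocol_state = 1 - protocol_state   # flip first, as in run()
    if protocol_state != 0:               # a game state arrived
        return b'action'                  # master replies with an action
    else:                                 # a (reward, isOver) pair arrived
        return b'Thanks'                  # just an ACK
```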
......@@ -18,3 +18,4 @@ from .utils import *
from .tfutils import *
from .callbacks import *
from .dataflow import *
from .predict import *
......@@ -36,7 +36,7 @@ class ModelSaver(Callback):
var_dict = {}
for v in vars:
name = v.op.name
if re.match('tower[1-9]', name):
if re.match('tower[p1-9]', name):
#logger.info("Skip {} when saving model.".format(name))
continue
if 'tower0/' in name:
......
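
The widened character class also skips the new prediction tower: 'tower[p1-9]' matches names starting with towerp... or tower1 through tower9, while tower0 (whose variables are the ones worth saving) falls through. A quick illustration:

```python
# The regex above, in isolation: prediction-tower and replica-tower
# variables are skipped; tower0 variables are kept in the checkpoint.
import re
for name in ['tower0/conv1/W', 'tower3/conv1/W', 'towerp0/conv1/W']:
    print(name, bool(re.match('tower[p1-9]', name)))
# tower0/conv1/W  False  -> saved
# tower3/conv1/W  True   -> skipped (replica of tower0)
# towerp0/conv1/W True   -> skipped (prediction tower)
```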
......@@ -10,7 +10,7 @@ import uuid
import os
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_proc_terminate
from ..utils.concurrency import *
from ..utils.serialize import *
from ..utils import logger
......@@ -107,8 +107,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
for _ in range(self.nr_proc)]
for x in self.procs:
x.start()
start_proc_mask_signal(self.procs)
# __del__ not guaranteed to get called at exit
import atexit
......
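
start_proc_mask_signal replaces the bare start() calls so that SIGINT is ignored while the workers fork; children inherit that disposition, leaving Ctrl-C handling to the parent. A hedged sketch of what such a helper does (the actual tensorpack implementation may differ):

```python
# Illustrative only: mask SIGINT in the parent around start() so forked
# children inherit SIG_IGN, then restore the old handler in the parent.
import signal

def start_proc_mask_signal_sketch(procs):
    old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        for p in procs:
            p.start()    # children inherit the ignored SIGINT
    finally:
        signal.signal(signal.SIGINT, old_handler)
```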
......@@ -5,7 +5,9 @@
import tensorflow as tf
from copy import copy
import re
from ..utils import logger
from ._common import layer_register
__all__ = ['BatchNorm']
......@@ -48,9 +50,35 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
else:
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
ema = tf.train.ExponentialMovingAverage(decay=decay)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
emaname = 'EMA'
if not batch_mean.name.startswith('towerp'):
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
else:
# use training-statistics in prediction
assert not use_local_stat
# have to do this again to get actual name. see issue:
# https://github.com/tensorflow/tensorflow/issues/2740
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
G = tf.get_default_graph()
try:
mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
#var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
ema_mean = G.get_tensor_by_name(mean_name)
ema_var = G.get_tensor_by_name(var_name)
except KeyError:
mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
ema_mean = G.get_tensor_by_name(mean_name)
ema_var = G.get_tensor_by_name(var_name)
#logger.info("In prediction, using {} instead of {} for {}".format(
#mean_name, ema_mean.name, batch_mean.name))
if use_local_stat:
with tf.control_dependencies([ema_apply_op]):
......@@ -58,6 +86,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
else:
batch = tf.cast(tf.shape(x)[0], tf.float32)
# XXX TODO batch==1? (the unbiased correction below divides by batch-1)
mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator
return tf.nn.batch_normalization(
x, mean, var, beta, gamma, epsilon, 'bn')
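
The towerp branch above must not create fresh EMA statistics for the prediction tower; it rebuilds the EMA nodes only to learn their actual names, then strips the towerp prefix to fetch the training tower's tensors, falling back to tower0/ in multi-GPU graphs. The name surgery in isolation, with a hypothetical tensor name:

```python
# Map a prediction-tower EMA tensor name to the training-tower tensor it
# should share. In the real code each candidate name is looked up with
# G.get_tensor_by_name() inside try/except KeyError.
import re

pred_name = 'towerp0/conv1/bn/EMA/mean:0'   # hypothetical tensor name
root_name = re.sub('towerp[0-9]+/', '', pred_name)
tower0_name = re.sub('towerp[0-9]+/', 'tower0/', pred_name)
print(root_name)     # conv1/bn/EMA/mean:0         (single-GPU graphs)
print(tower0_name)   # tower0/conv1/bn/EMA/mean:0  (multi-GPU fallback)
```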
......@@ -81,43 +81,38 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self.outqueue.put((tid, self.func(dp)))
class PredictorWorkerThread(threading.Thread):
def __init__(self, queue, pred_func, id):
def __init__(self, queue, pred_func, id, batch_size=5):
super(PredictorWorkerThread, self).__init__()
self.queue = queue
self.func = pred_func
self.daemon = True
self.batch_size = batch_size
self.id = id
def run(self):
#self.xxx = None
def fetch():
batched = []
futures = []
batched, futures = [], []
inp, f = self.queue.get()
batched.append(inp)
futures.append(f)
#print "func queue:", self.queue.qsize()
#return batched, futures
if self.batch_size == 1:
return batched, futures
while True:
try:
inp, f = self.queue.get_nowait()
batched.append(inp)
futures.append(f)
if len(batched) == 5:
if len(batched) == self.batch_size:
break
except queue.Empty:
break
return batched, futures
#self.xxx = None
while True:
# normal input
#inputs, f = self.queue.get()
#outputs = self.func(inputs)
#f.set_result(outputs)
batched, futures = fetch()
#print "batched size: ", len(batched)
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
outputs = self.func([batched])
# debug, for speed testing
#if self.xxx is None:
#outputs = self.func([batched])
#self.xxx = outputs
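
The reworked fetch() implements opportunistic batching: it blocks for the first request, then drains whatever is already queued up to batch_size, so a lone request is never held back waiting for peers. Stand-alone:

```python
# Opportunistic batching as in fetch() above: pay queueing latency only
# for the first element; everything else is taken without waiting.
from six.moves import queue

def fetch_batch(q, batch_size=5):
    batched = [q.get()]            # block for at least one request
    if batch_size == 1:
        return batched
    while len(batched) < batch_size:
        try:
            batched.append(q.get_nowait())
        except queue.Empty:
            break
    return batched
```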
......@@ -134,13 +129,13 @@ class MultiThreadAsyncPredictor(object):
An online predictor (using the current active session) that works with
QueueInputTrainer. Uses an async interface; supports multi-thread and multi-GPU.
"""
def __init__(self, trainer, input_names, output_names, nr_thread):
def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self.input_queue = queue.Queue(maxsize=nr_thread*10)
self.threads = [
PredictorWorkerThread(self.input_queue, f, id)
PredictorWorkerThread(self.input_queue, f, id, batch_size)
for id, f in enumerate(
trainer.get_predict_funcs(
input_names, output_names, nr_thread))]
......
......@@ -5,13 +5,19 @@
from ..utils.naming import *
import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager
__all__ = ['get_default_sess_config',
'get_global_step',
'get_global_step_var',
'get_op_var_name',
'get_vars_by_names'
]
'get_vars_by_names',
'backup_collection',
'restore_collection',
'clear_collection',
'freeze_collection']
def get_default_sess_config(mem_fraction=0.9):
"""
......@@ -66,3 +72,24 @@ def get_vars_by_names(names):
opn, varn = get_op_var_name(n)
ret.append(G.get_tensor_by_name(varn))
return ret
def backup_collection(keys):
ret = {}
for k in keys:
ret[k] = copy(tf.get_collection(k))
return ret
def restore_collection(backup):
for k, v in six.iteritems(backup):
del tf.get_collection_ref(k)[:]
tf.get_collection_ref(k).extend(v)
def clear_collection(keys):
for k in keys:
del tf.get_collection_ref(k)[:]
@contextmanager
def freeze_collection(keys):
backup = backup_collection(keys)
yield
restore_collection(backup)
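
A minimal usage sketch of freeze_collection, which the multi-GPU trainers below use to keep the prediction tower from adding duplicate entries to the summary collections:

```python
# Whatever the with-block adds to the listed collections is rolled back on
# exit; graph nodes themselves still exist, only the bookkeeping is undone.
import tensorflow as tf

with freeze_collection([tf.GraphKeys.SUMMARIES]):
    tf.scalar_summary('scratch', tf.constant(1.0))   # recorded...
# ...and discarded here: SUMMARIES is restored to its previous contents.
```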
......@@ -106,8 +106,10 @@ class SaverRestore(SessionInit):
var_dict = defaultdict(list)
for v in vars_to_restore:
name = v.op.name
if 'towerp' in name:
logger.warn("Anything from prediction tower shouldn't be saved.")
if 'tower' in name:
new_name = re.sub('tower[0-9]+/', '', name)
new_name = re.sub('tower[p0-9]+/', '', name)
name = new_name
if name in vars_available:
var_dict[name].append(v)
......
......@@ -90,7 +90,7 @@ def summary_moving_average():
vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
avg_maintain_op = averager.apply(vars_to_summary)
for idx, c in enumerate(vars_to_summary):
name = re.sub('tower[0-9]+/', '', c.op.name)
name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.scalar_summary(name, averager.average(c))
return avg_maintain_op
......@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder
from ..tfutils import *
from ..tfutils.summary import create_summary
from ..tfutils.modelutils import describe_model
__all__ = ['Trainer']
......@@ -89,7 +88,7 @@ class Trainer(object):
summary = tf.Summary.FromString(summary_str)
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]+/', '', val.tag) # TODO move to subclasses
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
self.stat_holder.add_stat(val.tag, val.simple_value)
self.summary_writer.add_summary(summary, self.global_step)
......@@ -141,7 +140,6 @@ class Trainer(object):
self.sess.close()
def init_session_and_coord(self):
describe_model()
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from six.moves import zip, range
from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *
from .trainer import QueueInputTrainer
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class MultiGPUTrainer(QueueInputTrainer):
""" Base class for multi-gpu training"""
def __init__(self, config, input_queue=None, predict_tower=None):
super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
self.dequed_inputs = []
@staticmethod
def _average_grads(tower_grads):
ret = []
for grad_and_vars in zip(*tower_grads):
v = grad_and_vars[0][1]
try:
grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
except:
logger.error("Error while processing gradients of {}".format(v.name))
raise
ret.append((grad, v))
return ret
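
What _average_grads computes, on plain numbers: tower_grads holds one (grad, var) list per tower, aligned by variable, and each variable ends up paired with the mean of its per-tower gradients:

```python
# Toy version of the averaging above, with floats standing in for tensors.
tower_grads = [
    [(2.0, 'W'), (4.0, 'b')],   # tower 0
    [(6.0, 'W'), (0.0, 'b')],   # tower 1
]
averaged = [(sum(g for g, _ in gv) / float(len(tower_grads)), gv[0][1])
            for gv in zip(*tower_grads)]
print(averaged)   # [(4.0, 'W'), (2.0, 'b')]
```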
def _multi_tower_grads(self):
logger.info("Training a model of {} tower".format(self.config.nr_tower))
grad_list = []
for i in range(self.config.nr_tower):
with tf.device('/gpu:{}'.format(i)), \
tf.name_scope('tower{}'.format(i)) as scope:
logger.info("Building graph for training tower {}...".format(i))
model_inputs = self._get_model_inputs() # each tower dequeues from the input queue
self.dequed_inputs.append(model_inputs)
self.model.build_graph(model_inputs, True)
cost_var = self.model.get_cost() # build tower
# TODO gate_gradients=0 seems to be faster?
grad_list.append(
self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
if i == 0:
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
tf.get_variable_scope().reuse_variables()
# avoid repeated summary from each device
backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
restore_collection(backup)
return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
self.init_session_and_coord()
self._build_enque_thread()
grad_list = self._multi_tower_grads()
grads = MultiGPUTrainer._average_grads(grad_list)
grads = self.process_grads(grads)
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average())
describe_model()
with freeze_collection(self.SUMMARY_BACKUP_KEYS):
self._build_predict_tower()
# [debug]: do nothing in training
#self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
self.main_loop()
class AsyncMultiGPUTrainer(MultiGPUTrainer):
def train(self):
self.init_session_and_coord()
self._build_enque_thread()
grad_list = self._multi_tower_grads()
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
def scale(grads):
return [(grad / self.config.nr_tower, var) for grad, var in grads]
grad_list = map(scale, grad_list)
grad_list = [self.process_grads(g) for g in grad_list]
# use grad from the first tower for iteration in main thread
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
summary_moving_average())
describe_model()
# prepare train_op for the rest of the towers
self.training_threads = []
for k in range(1, self.config.nr_tower):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
f = lambda op=train_op: self.sess.run([op]) # avoid late-binding
th = LoopThread(f)
th.pause()
th.start()
self.training_threads.append(th)
self.async_running = False
with freeze_collection(self.SUMMARY_BACKUP_KEYS):
self._build_predict_tower()
self.main_loop()
def run_step(self):
if not self.async_running:
self.async_running = True
for th in self.training_threads: # resume all threads
th.resume()
super(AsyncMultiGPUTrainer, self).run_step()
def _trigger_epoch(self):
self.async_running = False
for th in self.training_threads:
th.pause()
super(AsyncMultiGPUTrainer, self)._trigger_epoch()
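
run_step and _trigger_epoch rely on LoopThread's pause/resume: the extra towers train in background threads that are parked at epoch boundaries so callbacks observe a quiescent session. A self-contained sketch of such a pausable loop, assuming an Event-based implementation (tensorpack's LoopThread may differ):

```python
# Pausable worker loop: a cleared Event parks the thread before each
# iteration of its body, which is how the async trainer quiesces the
# per-tower training threads around epoch callbacks.
import threading

class PausableLoop(threading.Thread):
    def __init__(self, body):
        super(PausableLoop, self).__init__()
        self.body = body
        self._running = threading.Event()   # cleared == paused
        self.daemon = True

    def run(self):
        while True:
            self._running.wait()   # block here while paused
            self.body()

    def pause(self):
        self._running.clear()

    def resume(self):
        self._running.set()
```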
......@@ -3,17 +3,17 @@
# File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#import msgpack
#import msgpack_numpy
#msgpack_numpy.patch()
import dill
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
#import dill
__all__ = ['loads', 'dumps']
def dumps(obj):
return dill.dumps(obj)
#return msgpack.dumps(obj, use_bin_type=True)
#return dill.dumps(obj)
return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
return dill.loads(buf)
#return msgpack.loads(buf)
#return dill.loads(buf)
return msgpack.loads(buf)
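
dill is swapped for msgpack with the msgpack_numpy patch, which encodes numpy arrays natively instead of pickling arbitrary Python objects, presumably for speed on the simulator traffic; the trade-off is that arbitrary types no longer round-trip. A round-trip sketch:

```python
# msgpack_numpy.patch() makes msgpack handle numpy arrays transparently;
# use_bin_type=True keeps bytes and unicode distinct on the wire.
import numpy as np
import msgpack
import msgpack_numpy
msgpack_numpy.patch()

buf = msgpack.dumps((np.arange(3, dtype='float32'), True), use_bin_type=True)
print(msgpack.loads(buf))
# [array([0., 1., 2.], dtype=float32), True]  -- note tuples decode as lists
```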
......@@ -37,7 +37,7 @@ def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec, {} times".format(
k, v.sum, v.count))
logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
k, v.sum, v.count, v.average))
atexit.register(print_total_timer)