Commit 47f76e94 authored by Yuxin Wu

fix everything (hopefully)

parent e072d909
@@ -14,12 +14,8 @@ import multiprocessing, threading
 from collections import deque
 from tensorpack import *
-from tensorpack.models import *
-from tensorpack.utils import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.callbacks import *
 from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
...
@@ -196,10 +196,9 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    QueueInputTrainer(config).train()
+    SyncMultiGPUTrainer(config).train()
@@ -141,6 +141,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
+
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()
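For reference, a minimal sketch of what the added lines compute; the script name below is made up for illustration. The log directory is derived from the running script's filename with its extension stripped:

```python
import os

# Illustrative only: for a script assumed to be named cifar10-resnet.py,
# the added lines place logs under train_log/cifar10-resnet.
basename = os.path.basename('/path/to/cifar10-resnet.py')            # 'cifar10-resnet.py'
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])  # 'train_log/cifar10-resnet'
print(log_dir)
```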
@@ -174,10 +178,6 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()

-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
-
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...
@@ -8,16 +8,9 @@ import argparse
 import numpy as np
 import os

-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *
-from tensorpack.dataflow import imgaug

 """
 ResNet-110 for SVHN Digit Classification.
@@ -151,6 +144,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
+
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()
@@ -184,18 +181,12 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()

-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
-
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    QueueInputTrainer(config).train()
+    SyncMultiGPUTrainer(config).train()
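Both example scripts now follow the same pattern: derive `nr_tower` from the `--gpu` argument and hand the config to `SyncMultiGPUTrainer`. A minimal sketch of that relationship, assuming `from tensorpack import *` as in the scripts above; the `launch` helper is made up for illustration and only shows why multi-GPU configs must go to the multi-GPU trainer (see the new assert in `QueueInputTrainer.train` below):

```python
def launch(config, gpu_arg):
    # gpu_arg is the comma-separated --gpu string, e.g. "0,1,2"
    if gpu_arg:
        config.nr_tower = len(gpu_arg.split(','))
    if config.nr_tower > 1:
        SyncMultiGPUTrainer(config).train()   # one training tower per visible GPU
    else:
        QueueInputTrainer(config).train()     # single-tower path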
@@ -36,7 +36,7 @@ class ModelSaver(Callback):
         var_dict = {}
         for v in vars:
             name = v.op.name
-            if re.match('tower[1-9]', name):
+            if re.match('tower[p1-9]', name):
                 #logger.info("Skip {} when saving model.".format(name))
                 continue
             if 'tower0/' in name:
...
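The widened regex reflects the tower naming convention used throughout this commit: training towers are scoped `tower0/`, `tower1/`, ..., and the prediction tower is scoped `towerp0/`, `towerp1/`, .... A small sketch of how the filter behaves under that assumption (the variable names are made up); the `tower0/` mapping is the behavior suggested by the `if 'tower0/' in name:` branch above:

```python
import re

for name in ['tower0/conv0/W', 'tower1/conv0/W', 'towerp0/conv0/W', 'fc0/W']:
    if re.match('tower[p1-9]', name):
        print(name, '-> skipped (replica or prediction tower)')
    elif 'tower0/' in name:
        print(name, '-> saved as', name.replace('tower0/', ''))
    else:
        print(name, '-> saved as', name)
```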
@@ -56,6 +56,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         ema_apply_op = ema.apply([batch_mean, batch_var])
         ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
     else:
+        # use training-statistics in prediction
         assert not use_local_stat
         # have to do this again to get actual name. see issue:
         # https://github.com/tensorflow/tensorflow/issues/2740

@@ -63,15 +64,21 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         ema_apply_op = ema.apply([batch_mean, batch_var])
         ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

+        G = tf.get_default_graph()
+        try:
             mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
             var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
             #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
-        #logger.info("In prediction, using {} instead of {} for {}".format(
-            #mean_name, ema_mean.name, batch_mean.name))
-        G = tf.get_default_graph()
             ema_mean = G.get_tensor_by_name(mean_name)
             ema_var = G.get_tensor_by_name(var_name)
+        except KeyError:
+            mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        #logger.info("In prediction, using {} instead of {} for {}".format(
+            #mean_name, ema_mean.name, batch_mean.name))

     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
...
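In short, the new try/except lets the prediction tower (`towerp...`) reuse the moving averages created during training: the `towerp` prefix is stripped from the EMA tensor name, and if no such tensor exists in the graph (KeyError), it falls back to the `tower0` copy. A sketch of just the name rewriting, with an invented tensor name for illustration:

```python
import re

def candidate_names(ema_name):
    # First choice: the un-prefixed name created by the training graph;
    # fallback: the explicit tower0 copy.
    return [re.sub('towerp[0-9]+/', '', ema_name),
            re.sub('towerp[0-9]+/', 'tower0/', ema_name)]

print(candidate_names('towerp0/bn/mean/EMA:0'))
# ['bn/mean/EMA:0', 'tower0/bn/mean/EMA:0']
```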
@@ -106,8 +106,10 @@ class SaverRestore(SessionInit):
         var_dict = defaultdict(list)
         for v in vars_to_restore:
             name = v.op.name
+            if 'towerp' in name:
+                logger.warn("Anything from prediction tower shouldn't be saved.")
             if 'tower' in name:
-                new_name = re.sub('tower[0-9]+/', '', name)
+                new_name = re.sub('tower[p0-9]+/', '', name)
                 name = new_name
             if name in vars_available:
                 var_dict[name].append(v)
...
@@ -90,7 +90,7 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
-        name = re.sub('tower[0-9]+/', '', c.op.name)
+        name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
@@ -88,7 +88,7 @@ class Trainer(object):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
             if val.WhichOneof('value') == 'simple_value':
-                val.tag = re.sub('tower[0-9]+/', '', val.tag)   # TODO move to subclasses
+                val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
                 self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from six.moves import zip, range
from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *
from .trainer import QueueInputTrainer
__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
class MultiGPUTrainer(QueueInputTrainer):
    """ Base class for multi-gpu training"""
    def __init__(self, config, input_queue=None, predict_tower=None):
        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
        self.dequed_inputs = []

    @staticmethod
    def _average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            v = grad_and_vars[0][1]
            try:
                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
            except:
                logger.error("Error while processing gradients of {}".format(v.name))
                raise
            ret.append((grad, v))
        return ret

    def _multi_tower_grads(self):
        logger.info("Training a model of {} towers".format(self.config.nr_tower))

        grad_list = []
        for i in range(self.config.nr_tower):
            with tf.device('/gpu:{}'.format(i)), \
                    tf.name_scope('tower{}'.format(i)) as scope:
                logger.info("Building graph for training tower {}...".format(i))
                model_inputs = self._get_model_inputs()    # each tower dequeues from the input queue
                self.dequed_inputs.append(model_inputs)

                self.model.build_graph(model_inputs, True)
                cost_var = self.model.get_cost()    # build tower

                # TODO gate_gradients=0 seems to be faster?
                grad_list.append(
                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                if i == 0:
                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                    tf.get_variable_scope().reuse_variables()
                    # avoid repeated summary from each device
                    backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
        restore_collection(backup)
        return grad_list
class SyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        grads = MultiGPUTrainer._average_grads(grad_list)
        grads = self.process_grads(grads)

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average())
        describe_model()

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
        self.main_loop()
class AsyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()

        # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
        def scale(grads):
            return [(grad / self.config.nr_tower, var) for grad, var in grads]
        grad_list = map(scale, grad_list)
        grad_list = [self.process_grads(g) for g in grad_list]

        # use grad from the first tower for iteration in main thread
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
            summary_moving_average())
        describe_model()

        # prepare train_op for the rest of the towers
        self.training_threads = []
        for k in range(1, self.config.nr_tower):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])
            f = lambda op=train_op: self.sess.run([op])    # avoid late-binding
            th = LoopThread(f)
            th.pause()
            th.start()
            self.training_threads.append(th)
        self.async_running = False

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()
        self.main_loop()

    def run_step(self):
        if not self.async_running:
            self.async_running = True
            for th in self.training_threads:    # resume all threads
                th.resume()
        super(AsyncMultiGPUTrainer, self).run_step()

    def _trigger_epoch(self):
        self.async_running = False
        for th in self.training_threads:
            th.pause()
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
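The `scale` helper above is what the "pretend to average the grads" comment refers to: each tower's gradient is divided by `nr_tower` so that the towers' independent updates sum to what `SyncMultiGPUTrainer` would apply once per step, keeping the effective learning rate comparable between the two trainers. A tiny numeric sketch with made-up values:

```python
nr_tower = 2
g_tower = [0.4, 0.6]                           # the same variable's gradient on each tower

sync_step = sum(g_tower) / nr_tower            # averaged, applied once: 0.5
async_steps = [g / nr_tower for g in g_tower]  # applied separately: 0.2, then 0.3

assert abs(sync_step - sum(async_steps)) < 1e-12
```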
@@ -5,20 +5,16 @@
 import tensorflow as tf
 import threading
 import time
-import re
-import functools
 from six.moves import zip

 from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import *
-from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
 from ..tfutils.modelutils import describe_model
 from ..tfutils import *

-__all__ = ['SimpleTrainer', 'QueueInputTrainer',
-           'AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
+__all__ = ['SimpleTrainer', 'QueueInputTrainer']

 class SimpleTrainer(Trainer):
     def run_step(self):
@@ -110,6 +106,7 @@ class QueueInputTrainer(Trainer):
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
             Defaults to a FIFO queue of size 100.
+        :param predict_tower: list of gpu idx to run prediction. default to be [0].
         """
         super(QueueInputTrainer, self).__init__(config)
         self.input_vars = self.model.get_input_vars()
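The new `predict_tower` parameter, per the docstring above, is a list of GPU indices on which the prediction graph is built (default `[0]`). A hedged usage sketch, assuming the `get_config()` from the example scripts; the GPU indices are made up:

```python
config = get_config()       # as in the example scripts above
config.nr_tower = 2         # two training towers, on GPUs 0 and 1
SyncMultiGPUTrainer(config, predict_tower=[1]).train()   # prediction tower on GPU 1
```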
@@ -119,7 +116,7 @@ class QueueInputTrainer(Trainer):
         else:
             self.input_queue = input_queue

         if predict_tower is None:
-            # by default, use first training tower for prediction
+            # by default, use the first training gpu for prediction
             predict_tower = [0]
         self.predict_tower = predict_tower
         self.dequed_inputs = None
@@ -144,7 +141,7 @@ class QueueInputTrainer(Trainer):
         self.model.build_graph(inputs, False)

     def _single_tower_grad(self):
-        """ Get grad and cost for single-tower case"""
+        """ Get grad and cost for single-tower"""
         self.dequed_inputs = model_inputs = self._get_model_inputs()
         self.model.build_graph(model_inputs, True)
         cost_var = self.model.get_cost()
@@ -153,13 +150,14 @@ class QueueInputTrainer(Trainer):
         return grads

     def _build_enque_thread(self):
-        # create a thread that keeps filling the queue
+        """ create a thread that keeps filling the queue """
         enqueue_op = self.input_queue.enqueue(self.input_vars)
         self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
         self.extra_threads_procs.append(self.input_th)

     def train(self):
-        assert self.config.nr_tower == 1, "QueueInputTrainer only supports 1 tower!"
+        assert self.config.nr_tower == 1, \
+            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
         self.init_session_and_coord()
         self._build_enque_thread()
@@ -207,119 +205,3 @@ class QueueInputTrainer(Trainer):
         return [self.get_predict_func(input_names, output_names, k)
                 for k in range(n)]
class MultiGPUTrainer(QueueInputTrainer):
    """ Base class for multi-gpu training"""
    def __init__(self, config, input_queue=None, predict_tower=None):
        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
        assert config.nr_tower > 1
        self.dequed_inputs = []

    @staticmethod
    def _average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            v = grad_and_vars[0][1]
            try:
                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
            except AssertionError:
                logger.error("Error while processing gradients of {}".format(v.name))
                raise
            ret.append((grad, v))
        return ret

    def _multi_tower_grads(self):
        logger.info("Training a model of {} tower".format(self.config.nr_tower))

        grad_list = []
        for i in range(self.config.nr_tower):
            with tf.device('/gpu:{}'.format(i)), \
                    tf.name_scope('tower{}'.format(i)) as scope:
                logger.info("Building graph for training tower {}...".format(i))
                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                self.dequed_inputs.append(model_inputs)

                self.model.build_graph(model_inputs, True)
                cost_var = self.model.get_cost()    # build tower

                # gate_gradienst=0 seems to be faster?
                grad_list.append(
                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                if i == 0:
                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                    tf.get_variable_scope().reuse_variables()
                    # avoid repeated summary from each device
                    backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
        restore_collection(backup)
        return grad_list

class SyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        grads = MultiGPUTrainer._average_grads(grad_list)
        grads = self.process_grads(grads)

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average())
        describe_model()
        self._build_predict_tower()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
        self.main_loop()

class AsyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()

        # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
        def scale(grads):
            return [(grad / self.config.nr_tower, var) for grad, var in grads]
        grad_list = map(scale, grad_list)
        grad_list = [self.process_grads(g) for g in grad_list]
        grads = grad_list[0]    # use grad from the first tower for the main iteration

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average())
        describe_model()

        # prepare train_op for the rest of the towers
        self.threads = []
        for k in range(1, self.config.nr_tower):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])
            f = lambda op=train_op: self.sess.run([op])    # avoid late-binding
            th = LoopThread(f)
            th.pause()
            th.start()
            self.threads.append(th)
        self.async_running = False
        self._build_predict_tower()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
        self.main_loop()

    def run_step(self):
        if not self.async_running:
            self.async_running = True
            for th in self.threads:    # resume all threads
                th.resume()
        self.sess.run([self.train_op])    # faster since train_op return None

    def _trigger_epoch(self):
        self.async_running = False
        for th in self.threads:
            th.pause()
        if self.summary_op is not None:
            summary_str = self.summary_op.eval()
            self._process_summary(summary_str)