Commit e9a6a5af authored by Yuxin Wu

train/ directory

parent 2264b5a3
@@ -8,7 +8,7 @@ import argparse
 import numpy as np
 import os
 
-from tensorpack.train import TrainConfig, start_train
+from tensorpack.train import TrainConfig, QueueInputTrainer
 from tensorpack.models import *
 from tensorpack.callbacks import *
 from tensorpack.utils import *
@@ -158,4 +158,4 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    start_train(config)
+    QueueInputTrainer(config).train()
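The examples switch from the start_train helper to explicit trainer classes, so the caller now chooses the input pipeline. A minimal sketch of the new call pattern, assuming get_config() returns a TrainConfig as in these example scripts:

config = get_config()
# old entry point, removed by this commit:
#   start_train(config)
# new entry points: pick a trainer explicitly.
SimpleTrainer(config).train()         # feeds each datapoint through feed_dict
# QueueInputTrainer(config).train()   # prefetches datapoints through a TF queue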
@@ -10,7 +10,7 @@ import numpy as np
 import os, sys
 import argparse
 
-from tensorpack.train import TrainConfig, start_train
+from tensorpack.train import TrainConfig, SimpleTrainer
 from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
@@ -92,7 +92,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    step_per_epoch = 20
+    #step_per_epoch = 20
 
     # prepare session
     sess_config = get_default_sess_config()
@@ -131,5 +131,5 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    start_train(config)
+    SimpleTrainer(config).train()
@@ -47,7 +47,6 @@ def BatchNorm(x, is_training, gamma_init=1.0):
         x.set_shape(hack_shape)
     batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
-    print batch_mean
 
     ema = tf.train.ExponentialMovingAverage(decay=0.999)
     ema_apply_op = ema.apply([batch_mean, batch_var])
...
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
import os
import os.path

def global_import(name):
    # import the submodule and re-export its public names at package level
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]

for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
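The walk_packages loop above re-exports every public name of the submodules at package level, so the trainer classes can be imported directly from tensorpack.train without naming the module that defines them:

# both resolve through the auto-import in __init__.py
from tensorpack.train import TrainConfig          # defined in config.py
from tensorpack.train import QueueInputTrainer    # defined in the trainer module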
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from abc import ABCMeta, abstractmethod
import tqdm
import re

from .config import TrainConfig
from ..utils import *
from ..callbacks import StatHolder
from ..utils.modelutils import describe_model

__all__ = ['Trainer']

class Trainer(object):
    __metaclass__ = ABCMeta

    def __init__(self, config):
        """
        config: a `TrainConfig` instance
        """
        assert isinstance(config, TrainConfig), type(config)
        self.config = config
        tf.add_to_collection(MODEL_KEY, config.model)

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def run_step(self):
        pass

    def trigger_epoch(self):
        self.global_step += self.config.step_per_epoch
        self._trigger_epoch()
        self.config.callbacks.trigger_epoch()
        self.summary_writer.flush()
        logger.stat_holder.finalize()

    @abstractmethod
    def _trigger_epoch(self):
        pass

    def _init_summary(self):
        if not hasattr(logger, 'LOG_DIR'):
            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
        self.summary_writer = tf.train.SummaryWriter(
            logger.LOG_DIR, graph_def=self.sess.graph_def)
        logger.writer = self.summary_writer
        self.summary_op = tf.merge_all_summaries()
        # create an empty StatHolder
        logger.stat_holder = StatHolder(logger.LOG_DIR, [])

    def _process_summary(self, summary_str):
        summary = tf.Summary.FromString(summary_str)
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                val.tag = re.sub('tower[0-9]*/', '', val.tag)   # TODO move to subclasses
                logger.stat_holder.add_stat(val.tag, val.simple_value)
        self.summary_writer.add_summary(summary, self.global_step)

    def main_loop(self):
        callbacks = self.config.callbacks
        with self.sess.as_default():
            try:
                self._init_summary()
                self.global_step = get_global_step()
                logger.info("Start training with global_step={}".format(self.global_step))

                callbacks.before_train()
                tf.get_default_graph().finalize()
                # run exactly max_epoch epochs (xrange's end bound is exclusive)
                for epoch in xrange(1, self.config.max_epoch + 1):
                    with timed_operation(
                            'Epoch {}, global_step={}'.format(
                                epoch, self.global_step + self.config.step_per_epoch)):
                        for step in tqdm.trange(
                                self.config.step_per_epoch,
                                leave=True, mininterval=0.5,
                                dynamic_ncols=True, ascii=True):
                            if self.coord.should_stop():
                                return
                            self.run_step()
                            callbacks.trigger_step()
                        self.trigger_epoch()
            except (KeyboardInterrupt, Exception):
                raise
            finally:
                self.coord.request_stop()
                # Do I need to run queue.close?
                callbacks.after_train()
                self.summary_writer.close()
                self.sess.close()

    def init_session_and_coord(self):
        describe_model()
        self.sess = tf.Session(config=self.config.session_config)
        self.config.session_init.init(self.sess)

        # start training:
        self.coord = tf.train.Coordinator()
        tf.train.start_queue_runners(
            sess=self.sess, coord=self.coord, daemon=True, start=True)
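Trainer leaves train, run_step, and _trigger_epoch abstract; a concrete subclass builds its graph and train op, then hands control to main_loop. A hypothetical minimal subclass for illustration (self.train_op and the input feeding are assumptions; the real SimpleTrainer in this commit also feeds a datapoint on every step):

class ToyTrainer(Trainer):
    """Illustrative sketch only; not part of this commit."""
    def train(self):
        # build the model graph and self.train_op here, then:
        self.init_session_and_coord()   # creates self.sess / self.coord
        self.main_loop()                # epochs of run_step() + callbacks

    def run_step(self):
        # one parameter update; a real trainer also feeds input data here
        self.sess.run([self.train_op])

    def _trigger_epoch(self):
        # optionally evaluate the merged summaries set up by _init_summary
        if self.summary_op is not None:
            self._process_summary(self.summary_op.eval())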
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: config.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf

from ..callbacks import Callbacks
from ..models import ModelDesc
from ..utils import *
from ..dataflow import DataFlow

__all__ = ['TrainConfig']

class TrainConfig(object):
    """
    Config for training a model with a single loss
    """
    def __init__(self, **kwargs):
        """
        Args:
            dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
            optimizer: a tf.train.Optimizer instance defining the optimizer for training.
            callbacks: a tensorpack.utils.callback.Callbacks instance. Defines
                the callbacks to perform during training. Has to contain a
                SummaryWriter and a PeriodicSaver.
            session_config: a tf.ConfigProto instance to instantiate the
                session. Defaults to a session running 1 GPU.
            session_init: a tensorpack.utils.sessinit.SessionInit instance to
                initialize variables of a session. Defaults to a new session.
            model: a ModelDesc instance.
            step_per_epoch: the number of steps (parameter updates) to perform
                in each epoch. Defaults to dataset.size().
            max_epoch: maximum number of epochs to run training. Defaults to 100.
            nr_tower: int. number of towers. Defaults to 1.
        """
        def assert_type(v, tp):
            assert isinstance(v, tp), v.__class__
        self.dataset = kwargs.pop('dataset')
        assert_type(self.dataset, DataFlow)
        self.optimizer = kwargs.pop('optimizer')
        assert_type(self.optimizer, tf.train.Optimizer)
        self.callbacks = kwargs.pop('callbacks')
        assert_type(self.callbacks, Callbacks)
        self.model = kwargs.pop('model')
        assert_type(self.model, ModelDesc)
        self.session_config = kwargs.pop('session_config', get_default_sess_config())
        assert_type(self.session_config, tf.ConfigProto)
        self.session_init = kwargs.pop('session_init', NewSession())
        assert_type(self.session_init, SessionInit)
        self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
        self.max_epoch = int(kwargs.pop('max_epoch', 100))
        assert self.step_per_epoch > 0 and self.max_epoch > 0
        self.nr_tower = int(kwargs.pop('nr_tower', 1))
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
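A hedged construction example, reusing the MNIST DataFlow from the example above; the SummaryWriter, PeriodicSaver, and Model constructor calls are placeholders for whatever the training script defines, not APIs fixed by this commit:

config = TrainConfig(
    dataset=BatchData(dataset.Mnist('train'), 128),             # any DataFlow
    optimizer=tf.train.AdamOptimizer(1e-4),
    callbacks=Callbacks([SummaryWriter(), PeriodicSaver()]),    # must contain these two
    model=Model(),                                              # a ModelDesc subclass
    max_epoch=100)                                              # step_per_epoch defaults to dataset.size()
SimpleTrainer(config).train()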
@@ -4,65 +4,16 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import tensorflow as tf
-from itertools import count
 import copy
-import argparse
 import re
-import tqdm
-from abc import ABCMeta
 
-from .models import ModelDesc
-from .dataflow.common import RepeatedData
-from .utils import *
-from .utils.concurrency import EnqueueThread
-from .callbacks import *
-from .utils.summary import summary_moving_average
-from .utils.modelutils import describe_model
-from .utils import logger
-from .dataflow import DataFlow
+from .base import Trainer
+from ..dataflow.common import RepeatedData
+from ..utils import *
+from ..utils.concurrency import EnqueueThread
+from ..utils.summary import summary_moving_average
 
-class TrainConfig(object):
-    """
-    Config for training a model with a single loss
-    """
-    def __init__(self, **kwargs):
-        """
-        Args:
-            dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
-            optimizer: a tf.train.Optimizer instance defining the optimizer for trainig.
-            callbacks: a tensorpack.utils.callback.Callbacks instance. Define
-                the callbacks to perform during training. has to contain a
-                SummaryWriter and a PeriodicSaver
-            session_config: a tf.ConfigProto instance to instantiate the
-                session. default to a session running 1 GPU.
-            session_init: a tensorpack.utils.sessinit.SessionInit instance to
-                initialize variables of a session. default to a new session.
-            model: a ModelDesc instance
-            step_per_epoch: the number of steps (parameter updates) to perform
-                in each epoch. default to dataset.size()
-            max_epoch: maximum number of epoch to run training. default to 100
-            nr_tower: int. number of towers. default to 1.
-        """
-        def assert_type(v, tp):
-            assert isinstance(v, tp), v.__class__
-        self.dataset = kwargs.pop('dataset')
-        assert_type(self.dataset, DataFlow)
-        self.optimizer = kwargs.pop('optimizer')
-        assert_type(self.optimizer, tf.train.Optimizer)
-        self.callbacks = kwargs.pop('callbacks')
-        assert_type(self.callbacks, Callbacks)
-        self.model = kwargs.pop('model')
-        assert_type(self.model, ModelDesc)
-        self.session_config = kwargs.pop('session_config', get_default_sess_config())
-        assert_type(self.session_config, tf.ConfigProto)
-        self.session_init = kwargs.pop('session_init', NewSession())
-        assert_type(self.session_init, SessionInit)
-        self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
-        self.max_epoch = int(kwargs.pop('max_epoch', 100))
-        assert self.step_per_epoch > 0 and self.max_epoch > 0
-        self.nr_tower = int(kwargs.pop('nr_tower', 1))
-        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
+__all__ = ['SimpleTrainer', 'QueueInputTrainer']
 
 def summary_grads(grads):
     for grad, var in grads:
@@ -88,95 +39,6 @@ def scale_grads(grads, multiplier):
         ret.append((grad, var))
     return ret
-class Trainer(object):
-    __metaclass__ = ABCMeta
-
-    def __init__(self, config):
-        """
-        Config: a `TrainConfig` instance
-        """
-        assert isinstance(config, TrainConfig), type(config)
-        self.config = config
-        tf.add_to_collection(MODEL_KEY, config.model)
-
-    @abstractmethod
-    def train(self):
-        pass
-
-    @abstractmethod
-    def run_step(self):
-        pass
-
-    def trigger_epoch(self):
-        self.global_step += self.config.step_per_epoch
-        self._trigger_epoch()
-        self.config.callbacks.trigger_epoch()
-        self.summary_writer.flush()
-        logger.stat_holder.finalize()
-
-    @abstractmethod
-    def _trigger_epoch(self):
-        pass
-
-    def _init_summary(self):
-        if not hasattr(logger, 'LOG_DIR'):
-            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
-        self.summary_writer = tf.train.SummaryWriter(
-            logger.LOG_DIR, graph_def=self.sess.graph_def)
-        logger.writer = self.summary_writer
-        self.summary_op = tf.merge_all_summaries()
-        # create an empty StatHolder
-        logger.stat_holder = StatHolder(logger.LOG_DIR, [])
-
-    def _process_summary(self, summary_str):
-        summary = tf.Summary.FromString(summary_str)
-        for val in summary.value:
-            if val.WhichOneof('value') == 'simple_value':
-                val.tag = re.sub('tower[0-9]*/', '', val.tag)   # TODO move to subclasses
-                logger.stat_holder.add_stat(val.tag, val.simple_value)
-        self.summary_writer.add_summary(summary, self.global_step)
-
-    def main_loop(self):
-        callbacks = self.config.callbacks
-        with self.sess.as_default():
-            try:
-                self._init_summary()
-                self.global_step = get_global_step()
-                logger.info("Start training with global_step={}".format(self.global_step))
-
-                callbacks.before_train()
-                tf.get_default_graph().finalize()
-                for epoch in xrange(1, self.config.max_epoch):
-                    with timed_operation(
-                            'Epoch {}, global_step={}'.format(
-                                epoch, self.global_step + self.config.step_per_epoch)):
-                        for step in tqdm.trange(
-                                self.config.step_per_epoch,
-                                leave=True, mininterval=0.5,
-                                dynamic_ncols=True, ascii=True):
-                            if self.coord.should_stop():
-                                return
-                            self.run_step()
-                            callbacks.trigger_step()
-                        self.trigger_epoch()
-            except (KeyboardInterrupt, Exception):
-                raise
-            finally:
-                self.coord.request_stop()
-                # Do I need to run queue.close?
-                callbacks.after_train()
-                self.summary_writer.close()
-                self.sess.close()
-
-    def init_session_and_coord(self):
-        self.sess = tf.Session(config=self.config.session_config)
-        self.config.session_init.init(self.sess)
-
-        # start training:
-        self.coord = tf.train.Coordinator()
-        tf.train.start_queue_runners(
-            sess=self.sess, coord=self.coord, daemon=True, start=True)
 
 class SimpleTrainer(Trainer):
     def run_step(self):
@@ -200,7 +62,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)
 
-        describe_model()
         self.init_session_and_coord()
 
         # create an infinte data producer
        self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
@@ -280,7 +141,6 @@ class QueueInputTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)
 
-        describe_model()
         self.init_session_and_coord()
 
         # create a thread that keeps filling the queue
@@ -299,11 +159,5 @@
 def start_train(config):
-    #if config.model.get_input_queue() is not None:
-        ## XXX get_input_queue is called twice
-        #tr = QueueInputTrainer()
-    #else:
-        #tr = SimpleTrainer()
     tr = SimpleTrainer(config)
-    #tr = QueueInputTrainer(config)
     tr.train()
@@ -7,7 +7,6 @@ import threading
 from contextlib import contextmanager
 import tensorflow as tf
 
-from .utils import expand_dim_if_necessary
 from .naming import *
 from . import logger
@@ -44,7 +43,6 @@ class EnqueueThread(threading.Thread):
                         return
                     feed = dict(zip(self.input_vars, dp))
                     self.sess.run([self.op], feed_dict=feed)
-                    #print '\nExauhsted!!!'
         except tf.errors.CancelledError as e:
             pass
         except Exception:
...
@@ -4,19 +4,19 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import os
 
-def expand_dim_if_necessary(var, dp):
-    """
-    Args:
-        var: a tensor
-        dp: a numpy array
-    Return a reshaped version of dp, if that makes it match the valid dimension of var
-    """
-    shape = var.get_shape().as_list()
-    valid_shape = [k for k in shape if k]
-    if dp.shape == tuple(valid_shape):
-        new_shape = [k if k else 1 for k in shape]
-        dp = dp.reshape(new_shape)
-    return dp
+#def expand_dim_if_necessary(var, dp):
+#    """
+#    Args:
+#        var: a tensor
+#        dp: a numpy array
+#    Return a reshaped version of dp, if that makes it match the valid dimension of var
+#    """
+#    shape = var.get_shape().as_list()
+#    valid_shape = [k for k in shape if k]
+#    if dp.shape == tuple(valid_shape):
+#        new_shape = [k if k else 1 for k in shape]
+#        dp = dp.reshape(new_shape)
+#    return dp
 
 def mkdir_p(dirname):
...
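The helper retired above reshaped a datapoint so it matched a placeholder whose unknown dimensions are None. A self-contained restatement of its behavior for illustration, with hypothetical shapes:

import numpy as np
import tensorflow as tf

def expand_dim_if_necessary(var, dp):
    # restatement of the removed helper, for illustration only
    shape = var.get_shape().as_list()
    valid_shape = [k for k in shape if k]           # drop the None dimensions
    if dp.shape == tuple(valid_shape):
        new_shape = [k if k else 1 for k in shape]  # None -> 1
        dp = dp.reshape(new_shape)
    return dp

var = tf.placeholder(tf.float32, shape=[None, 28, 28])
dp = np.zeros((28, 28), dtype='float32')
print(expand_dim_if_necessary(var, dp).shape)       # (1, 28, 28)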