Commit 1898cd3c authored by Yuxin Wu

fix bug

parent 6306da7e
@@ -204,7 +204,7 @@ def get_config():
             HumanHyperParamSetter('entropy_beta'),
             HumanHyperParamSetter('explore_factor'),
             master,
-            StartProcOrThread(master)
+            StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
         ]),
         session_config=get_default_sess_config(0.5),
...
@@ -10,6 +10,7 @@ import os
 from tensorpack import *
 import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.summary import *
+from tensorpack.utils.gpu import get_nr_gpu

 """
 A small convnet model for Cifar10 or Cifar100 dataset.
@@ -152,7 +153,10 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)
-    QueueInputTrainer(config).train()
-    #if args.gpu:
-    #config.nr_tower = len(args.gpu.split(','))
-    #AsyncMultiGPUTrainer(config).train()
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    nr_gpu = get_nr_gpu()
+    if nr_gpu == 1:
+        QueueInputTrainer(config).train()
+    else:
+        SyncMultiGPUTrainer(config).train()
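The new logic above picks a trainer based on how many GPUs are visible. As a point of reference only, a helper like get_nr_gpu could be sketched as below, assuming it simply counts the devices listed in CUDA_VISIBLE_DEVICES; the actual implementation in tensorpack.utils.gpu may differ.

import os

def get_nr_gpu():
    # Hypothetical sketch only -- not the actual tensorpack.utils.gpu implementation.
    # Counts the devices listed in CUDA_VISIBLE_DEVICES; when the variable is unset,
    # a real helper would have to query the driver instead.
    env = os.environ.get('CUDA_VISIBLE_DEVICES')
    if not env:
        return 0
    return len(env.split(','))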
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from .base import Callback
from ..utils.concurrency import start_proc_mask_signal
from ..utils import logger

__all__ = ['StartProcOrThread']


class StartProcOrThread(Callback):
    def __init__(self, procs_threads):
        """
        Start extra threads and processes before training
        :param procs_threads: list of processes or threads
        """
        if not isinstance(procs_threads, list):
            procs_threads = [procs_threads]
        self._procs_threads = procs_threads

    def _before_train(self):
        logger.info("Starting all threads & procs ...")
        # avoid sigint get handled by other processes
        start_proc_mask_signal(self._procs_threads)
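A minimal usage sketch for the new callback, mirroring how the config hunk at the top of this commit registers it. Everything except StartProcOrThread, Callbacks, TrainConfig and QueueInputTrainer (the dataflow, model, optimizer and the worker thread) is a placeholder.

import tensorflow as tf
from tensorpack import *

# Placeholders: dataset_train, Model and my_worker_thread are assumed to be defined elsewhere.
config = TrainConfig(
    dataset=dataset_train,
    optimizer=tf.train.AdamOptimizer(1e-3),
    callbacks=Callbacks([
        StartProcOrThread(my_worker_thread),  # started once, in _before_train()
    ]),
    model=Model(),
    max_epoch=100,
)
QueueInputTrainer(config).train()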
@@ -86,3 +86,7 @@ class Callbacks(Callback):
                 with tm.timed_callback(display_name):
                     cb.trigger_epoch()
         tm.log()
+
+    def append(self, cb):
+        assert isinstance(cb, Callback)
+        self.cbs.append(cb)
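The new append method lets other components attach a callback after the container has been built, with the same type check as construction; the TrainConfig and QueueInputTrainer hunks below rely on it. A tiny illustration with placeholder names:

# Placeholders: `config` is an existing TrainConfig, `my_thread` a threading.Thread.
config.callbacks.append(StartProcOrThread(my_thread))  # OK: StartProcOrThread is a Callback
config.callbacks.append(my_thread)                      # would fail the isinstance assert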
@@ -68,7 +68,7 @@ class Trainer(object):
     def trigger_epoch(self):
         # by default, add this two stat
-        self.stat_holder.add_stat('global_step', self.global_step)
+        self.stat_holder.add_stat('global_step', get_global_step())
         self.stat_holder.add_stat('epoch_num', self.epoch_num)
         # trigger subclass
@@ -88,7 +88,7 @@ class Trainer(object):
             if val.WhichOneof('value') == 'simple_value':
                 val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
                 self.stat_holder.add_stat(val.tag, val.simple_value)
-        self.summary_writer.add_summary(summary, self.global_step)
+        self.summary_writer.add_summary(summary, get_global_step())

     def write_scalar_summary(self, name, val):
         self.summary_writer.add_summary(
@@ -98,10 +98,8 @@ class Trainer(object):
     def finalize_graph(self):
         # some final operations that might modify the graph
-        get_global_step_var()   # ensure there is such var, before finalizing the graph
         logger.info("Setup callbacks ...")
-        callbacks = self.config.callbacks
-        callbacks.setup_graph(weakref.proxy(self))
+        self.config.callbacks.setup_graph(weakref.proxy(self))
         if not hasattr(logger, 'LOG_DIR'):
             raise RuntimeError("logger directory wasn't set!")
@@ -122,8 +120,8 @@ class Trainer(object):
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
-                logger.info("Start training with global_step={}".format(get_global_step()))
                 callbacks.before_train()
+                logger.info("Start training with global_step={}".format(get_global_step()))
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch+1):
                     with timed_operation(
...
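For context, the trainer now re-reads the global step from the graph via get_global_step() instead of relying on the cached self.global_step attribute. A rough sketch of what such helpers typically look like with a TF 1.x-style API is below; this is an assumption, not tensorpack's actual tfutils code.

import tensorflow as tf

def get_global_step_var():
    # return the graph's global-step variable, creating it if necessary
    return tf.train.get_or_create_global_step()

def get_global_step():
    # read the variable's current value through the default session
    return tf.train.global_step(tf.get_default_session(), get_global_step_var())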
@@ -74,7 +74,7 @@ class TrainConfig(object):
         if self.extra_threads_procs:
             logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
             from ..callbacks.concurrency import StartProcOrThread
-            self.callbacks.cbs.append(StartProcOrThread(self.extra_threads_procs))
+            self.callbacks.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
...
@@ -54,7 +54,7 @@ class MultiGPUTrainer(QueueInputTrainer):
                     tf.variable_scope(global_scope, reuse=idx > 0), \
                     TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
-                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
+                model_inputs = self._get_dequeued_inputs()    # each tower dequeue from input queue
                 self.dequed_inputs.append(model_inputs)
                 self.model.build_graph(model_inputs)
...
@@ -17,6 +17,7 @@ from ..tfutils import (get_vars_by_names, freeze_collection,
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
+from ..callbacks.concurrency import StartProcOrThread

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']
@@ -160,7 +161,7 @@ class QueueInputTrainer(Trainer):
         self.predict_tower = predict_tower or [0]
         self.dequed_inputs = None

-    def _get_model_inputs(self):
+    def _get_dequeued_inputs(self):
         """ Dequeue a datapoint from input_queue and return"""
         ret = self.input_queue.dequeue(name='input_deque')
         if isinstance(ret, tf.Tensor):  # only one input
@@ -172,7 +173,7 @@ class QueueInputTrainer(Trainer):
     def _single_tower_grad(self):
         """ Get grad and cost for single-tower"""
-        self.dequed_inputs = model_inputs = self._get_model_inputs()
+        self.dequed_inputs = model_inputs = self._get_dequeued_inputs()

         # test the overhead of queue
         #with tf.device('/gpu:0'):
@@ -190,7 +191,7 @@ class QueueInputTrainer(Trainer):
     def _build_enque_thread(self):
         """ create a thread that keeps filling the queue """
         self.input_th = EnqueueThread(self)
-        self._extra_threads_procs.append(self.input_th)
+        self.config.callbacks.append(StartProcOrThread(self.input_th))

     def train(self):
         assert len(self.config.tower) == 1, \
...
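With the last hunk, the enqueue thread no longer goes onto the deprecated _extra_threads_procs list: it is wrapped in StartProcOrThread and appended to the config's callbacks, so it is started by _before_train() through start_proc_mask_signal. As a rough idea of what that helper does, here is a sketch under the assumption that it blocks SIGINT around starting the children, per the comment in concurrency.py above; this is not the actual tensorpack.utils.concurrency code.

import signal

def start_proc_mask_signal(procs):
    # Hypothetical sketch: start each thread/process while SIGINT is blocked,
    # so that Ctrl-C is handled only by the main process.
    if not isinstance(procs, list):
        procs = [procs]
    can_mask = hasattr(signal, 'pthread_sigmask')  # POSIX only
    if can_mask:
        signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT])
    try:
        for p in procs:
            p.start()
    finally:
        if can_mask:
            signal.pthread_sigmask(signal.SIG_UNBLOCK, [signal.SIGINT])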