Commit 1898cd3c authored by Yuxin Wu

fix bug

parent 6306da7e
@@ -204,7 +204,7 @@ def get_config():
             HumanHyperParamSetter('entropy_beta'),
             HumanHyperParamSetter('explore_factor'),
             master,
-            StartProcOrThread(master)
+            StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
         ]),
         session_config=get_default_sess_config(0.5),
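For context on the one-character fix above: read in place, the old line left two adjacent call expressions in the callbacks list with no separating comma, which Python rejects at parse time. A minimal standalone reproduction (illustrative names, not the example's real code):

    # Minimal reproduction of the bug fixed above (names are illustrative).
    import ast

    try:
        ast.parse("[f(x)\n g(y)]")    # missing comma between list elements
    except SyntaxError as e:
        print("rejected:", e.msg)     # the old callbacks list would not parse

    ast.parse("[f(x),\n g(y)]")       # with the comma, the list parses fine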
@@ -10,6 +10,7 @@ import os
 from tensorpack import *
 import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.summary import *
+from tensorpack.utils.gpu import get_nr_gpu

 """
 A small convnet model for Cifar10 or Cifar100 dataset.
@@ -152,7 +153,10 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)
-    QueueInputTrainer(config).train()
-    #if args.gpu:
-        #config.nr_tower = len(args.gpu.split(','))
-        #AsyncMultiGPUTrainer(config).train()
+    if args.gpu:
+        config.nr_tower = len(args.gpu.split(','))
+    nr_gpu = get_nr_gpu()
+    if nr_gpu == 1:
+        QueueInputTrainer(config).train()
+    else:
+        SyncMultiGPUTrainer(config).train()
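The new code path picks a trainer based on get_nr_gpu(). As a rough sketch of what such a helper can do (an assumption; the real tensorpack.utils.gpu implementation may differ), it can count the devices listed in CUDA_VISIBLE_DEVICES:

    # Hedged sketch of a get_nr_gpu()-style helper, not the real implementation.
    import os

    def count_visible_gpus():
        env = os.environ.get('CUDA_VISIBLE_DEVICES')
        if env is None:
            return 1   # unknown; assume one device (the real helper may query the driver)
        return len([d for d in env.split(',') if d.strip()])

    print(count_visible_gpus())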
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from .base import Callback
from ..utils.concurrency import start_proc_mask_signal
from ..utils import logger

__all__ = ['StartProcOrThread']


class StartProcOrThread(Callback):
    def __init__(self, procs_threads):
        """
        Start extra threads and processes before training.
        :param procs_threads: a list of processes or threads
        """
        if not isinstance(procs_threads, list):
            procs_threads = [procs_threads]
        self._procs_threads = procs_threads

    def _before_train(self):
        logger.info("Starting all threads & procs ...")
        # avoid SIGINT being delivered to the child processes
        start_proc_mask_signal(self._procs_threads)
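The comment in _before_train refers to start_proc_mask_signal from tensorpack.utils.concurrency. A standalone sketch of the idea (a hypothetical helper, not the library's actual implementation): ignore SIGINT in the parent while children start, so forked processes inherit the ignore-handler and Ctrl-C only interrupts the main process:

    # Standalone sketch of SIGINT masking around process start-up.
    import signal
    import multiprocessing

    def start_masked(procs_threads):
        old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)  # children inherit SIG_IGN
        try:
            for p in procs_threads:
                p.start()
        finally:
            signal.signal(signal.SIGINT, old_handler)  # restore Ctrl-C for the main process

    if __name__ == '__main__':
        start_masked([multiprocessing.Process(target=print, args=('worker started',))])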
@@ -86,3 +86,7 @@ class Callbacks(Callback):
             with tm.timed_callback(display_name):
                 cb.trigger_epoch()
         tm.log()
+
+    def append(self, cb):
+        assert isinstance(cb, Callback)
+        self.cbs.append(cb)
@@ -68,7 +68,7 @@ class Trainer(object):
     def trigger_epoch(self):
         # by default, add these two stats
-        self.stat_holder.add_stat('global_step', self.global_step)
+        self.stat_holder.add_stat('global_step', get_global_step())
         self.stat_holder.add_stat('epoch_num', self.epoch_num)
         # trigger subclass
@@ -88,7 +88,7 @@ class Trainer(object):
             if val.WhichOneof('value') == 'simple_value':
                 val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
                 self.stat_holder.add_stat(val.tag, val.simple_value)
-        self.summary_writer.add_summary(summary, self.global_step)
+        self.summary_writer.add_summary(summary, get_global_step())

     def write_scalar_summary(self, name, val):
         self.summary_writer.add_summary(
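Both global_step replacements in this file swap a cached Python attribute for a fresh get_global_step() read. A pure-Python sketch of the staleness problem this avoids (an assumption about the rationale: a cached attribute can lag the real variable, e.g. after a checkpoint restore):

    # Illustrative only: a cached read goes stale once the underlying value moves.
    class Store:
        def __init__(self):
            self.step = 0       # stands in for the TF global_step variable

    store = Store()
    cached = store.step         # like caching self.global_step once
    store.step = 1000           # e.g. a restore advances the real step
    print(cached, store.step)   # 0 vs 1000: read the live value instead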
@@ -98,10 +98,8 @@ class Trainer(object):
     def finalize_graph(self):
         # some final operations that might modify the graph
         get_global_step_var()   # ensure there is such a var, before finalizing the graph
-        logger.info("Setup callbacks ...")
-        callbacks = self.config.callbacks
-        callbacks.setup_graph(weakref.proxy(self))
+        self.config.callbacks.setup_graph(weakref.proxy(self))

         if not hasattr(logger, 'LOG_DIR'):
             raise RuntimeError("logger directory wasn't set!")
@@ -122,8 +120,8 @@ class Trainer(object):
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
-                logger.info("Start training with global_step={}".format(get_global_step()))
                 callbacks.before_train()
+                logger.info("Start training with global_step={}".format(get_global_step()))
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch+1):
                     with timed_operation(
@@ -74,7 +74,7 @@ class TrainConfig(object):
         if self.extra_threads_procs:
             logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
             from ..callbacks.concurrency import StartProcOrThread
-            self.callbacks.cbs.append(StartProcOrThread(self.extra_threads_procs))
+            self.callbacks.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
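The deprecation above implies a migration for user code; a hedged sketch (the thread is illustrative, and the remaining TrainConfig arguments are elided):

    import threading
    from tensorpack import *   # the examples' import style; exposes Callbacks etc.
    from tensorpack.callbacks.concurrency import StartProcOrThread

    my_thread = threading.Thread(target=lambda: None)
    # Old style (now warns): TrainConfig(..., extra_threads_procs=[my_thread])
    # New style: wrap the thread in a callback and pass it with the others:
    callbacks = Callbacks([StartProcOrThread(my_thread)])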
@@ -54,7 +54,7 @@ class MultiGPUTrainer(QueueInputTrainer):
                     tf.variable_scope(global_scope, reuse=idx > 0), \
                     TowerContext('tower{}'.format(idx)) as scope:
                 logger.info("Building graph for training tower {}...".format(idx))
-                model_inputs = self._get_model_inputs()    # each tower dequeues from the input queue
+                model_inputs = self._get_dequeued_inputs()    # each tower dequeues from the input queue
                 self.dequed_inputs.append(model_inputs)
                 self.model.build_graph(model_inputs)
@@ -17,6 +17,7 @@ from ..tfutils import (get_vars_by_names, freeze_collection,
 from ..tfutils.summary import summary_moving_average, add_moving_summary
 from ..tfutils.modelutils import describe_model
 from ..predict import OnlinePredictor, build_multi_tower_prediction_graph
+from ..callbacks.concurrency import StartProcOrThread

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']
@@ -160,7 +161,7 @@ class QueueInputTrainer(Trainer):
         self.predict_tower = predict_tower or [0]
         self.dequed_inputs = None

-    def _get_model_inputs(self):
+    def _get_dequeued_inputs(self):
         """ Dequeue a datapoint from the input_queue and return it """
         ret = self.input_queue.dequeue(name='input_deque')
         if isinstance(ret, tf.Tensor):   # only one input
@@ -172,7 +173,7 @@ class QueueInputTrainer(Trainer):
     def _single_tower_grad(self):
         """ Get grad and cost for single-tower """
-        self.dequed_inputs = model_inputs = self._get_model_inputs()
+        self.dequed_inputs = model_inputs = self._get_dequeued_inputs()
         # test the overhead of the queue
         #with tf.device('/gpu:0'):
@@ -190,7 +191,7 @@ class QueueInputTrainer(Trainer):
     def _build_enque_thread(self):
         """ create a thread that keeps filling the queue """
         self.input_th = EnqueueThread(self)
-        self._extra_threads_procs.append(self.input_th)
+        self.config.callbacks.append(StartProcOrThread(self.input_th))

     def train(self):
         assert len(self.config.tower) == 1, \
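For illustration, the producer pattern EnqueueThread embodies, reduced to standard-library pieces (a sketch under the assumption that the real thread feeds a TensorFlow queue; queue.Queue is used here for brevity):

    # Standalone sketch of a queue-filling thread like the one registered above.
    import threading
    import queue

    def make_filler(q, datapoints):
        def fill():
            for dp in datapoints:
                q.put(dp)   # blocks when the queue is full, throttling the producer
        return threading.Thread(target=fill)

    q = queue.Queue(maxsize=50)
    th = make_filler(q, range(100))
    th.start()              # in tensorpack this start is now done by StartProcOrThread
    print(q.get())
    th.join()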