Commit 41bf8ffe authored by Yuxin Wu's avatar Yuxin Wu

add multigpu to cifar10_convnet

parent b6370d50
......@@ -20,6 +20,8 @@ from tensorpack.dataflow import imgaug
"""
CIFAR10 90% validation accuracy after 70k step.
91% validation accuracy after 36k step with 3 GPU.
"""
BATCH_SIZE = 128
......@@ -126,10 +128,11 @@ def get_config():
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
nr_gpu = get_nr_gpu()
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 30,
decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 15,
decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
......
......@@ -8,7 +8,8 @@ from abc import abstractmethod, ABCMeta
import operator
from .base import Callback
from ..utils import logger, get_op_var_name
from ..utils import logger
from ..tfutils import get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter']
......
......@@ -32,3 +32,9 @@ def get_global_step():
tf.get_default_session(),
get_global_step_var())
def get_op_var_name(name):
if name.endswith(':0'):
return name[:-2], name
else:
return name, name + ':0'
......@@ -11,7 +11,7 @@ import numpy as np
from . import logger
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized',
'get_op_var_name']
'get_nr_gpu']
#def expand_dim_if_necessary(var, dp):
# """
......@@ -79,8 +79,7 @@ def get_rng(self):
seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
def get_op_var_name(name):
if name.endswith(':0'):
return name[:-2], name
else:
return name, name + ':0'
def get_nr_gpu():
env = os.environ['CUDA_VISIBLE_DEVICES']
assert env is not None
return len(env.split(','))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment