Commit 41bf8ffe authored by Yuxin Wu's avatar Yuxin Wu

add multigpu to cifar10_convnet

parent b6370d50
...@@ -20,6 +20,8 @@ from tensorpack.dataflow import imgaug ...@@ -20,6 +20,8 @@ from tensorpack.dataflow import imgaug
""" """
CIFAR10 90% validation accuracy after 70k step. CIFAR10 90% validation accuracy after 70k step.
91% validation accuracy after 36k step with 3 GPU.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -126,10 +128,11 @@ def get_config(): ...@@ -126,10 +128,11 @@ def get_config():
sess_config = get_default_sess_config() sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5 sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
nr_gpu = get_nr_gpu()
lr = tf.train.exponential_decay( lr = tf.train.exponential_decay(
learning_rate=1e-2, learning_rate=1e-2,
global_step=get_global_step_var(), global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 30, decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 15,
decay_rate=0.5, staircase=True, name='learning_rate') decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
......
...@@ -8,7 +8,8 @@ from abc import abstractmethod, ABCMeta ...@@ -8,7 +8,8 @@ from abc import abstractmethod, ABCMeta
import operator import operator
from .base import Callback from .base import Callback
from ..utils import logger, get_op_var_name from ..utils import logger
from ..tfutils import get_op_var_name
__all__ = ['HyperParamSetter', 'HumanHyperParamSetter', __all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
'ScheduledHyperParamSetter'] 'ScheduledHyperParamSetter']
......
...@@ -32,3 +32,9 @@ def get_global_step(): ...@@ -32,3 +32,9 @@ def get_global_step():
tf.get_default_session(), tf.get_default_session(),
get_global_step_var()) get_global_step_var())
def get_op_var_name(name):
if name.endswith(':0'):
return name[:-2], name
else:
return name, name + ':0'
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from . import logger from . import logger
__all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized', __all__ = ['timed_operation', 'change_env', 'get_rng', 'memoized',
'get_op_var_name'] 'get_nr_gpu']
#def expand_dim_if_necessary(var, dp): #def expand_dim_if_necessary(var, dp):
# """ # """
...@@ -79,8 +79,7 @@ def get_rng(self): ...@@ -79,8 +79,7 @@ def get_rng(self):
seed = (id(self) + os.getpid()) % 4294967295 seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed) return np.random.RandomState(seed)
def get_op_var_name(name): def get_nr_gpu():
if name.endswith(':0'): env = os.environ['CUDA_VISIBLE_DEVICES']
return name[:-2], name assert env is not None
else: return len(env.split(','))
return name, name + ':0'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment