Commit 9bf42054 authored by ppwwyyxx's avatar ppwwyyxx

gpu option. before queue

parent 4a6b480c
...@@ -74,7 +74,7 @@ def get_model(inputs): ...@@ -74,7 +74,7 @@ def get_model(inputs):
def get_config(): def get_config():
IMAGE_SIZE = 28 IMAGE_SIZE = 28
LOG_DIR = 'train_log' LOG_DIR = os.path.join('train_log', os.path.basename(__file__)[:-3])
BATCH_SIZE = 128 BATCH_SIZE = 128
logger.set_file(os.path.join(LOG_DIR, 'training.log')) logger.set_file(os.path.join(LOG_DIR, 'training.log'))
...@@ -83,6 +83,7 @@ def get_config(): ...@@ -83,6 +83,7 @@ def get_config():
sess_config = tf.ConfigProto() sess_config = tf.ConfigProto()
sess_config.device_count['GPU'] = 1 sess_config.device_count['GPU'] = 1
sess_config.gpu_options.allocator_type = 'BFC'
sess_config.allow_soft_placement = True sess_config.allow_soft_placement = True
# prepare model # prepare model
......
...@@ -93,17 +93,16 @@ def start_train(config): ...@@ -93,17 +93,16 @@ def start_train(config):
callbacks.trigger_step(feed, outputs, cost) callbacks.trigger_step(feed, outputs, cost)
callbacks.trigger_epoch() callbacks.trigger_epoch()
sess.close()
def main(get_config_func): def main(get_config_func):
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
args = parser.parse_args() args = parser.parse_args()
device = '/cpu:0'
if args.gpu: if args.gpu:
device = '/gpu:{}'.format(args.gpu) os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with tf.Graph().as_default(): with tf.Graph().as_default():
with tf.device(device):
prepare() prepare()
config = get_config_func() config = get_config_func()
start_train(config) start_train(config)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment