Commit 4a6b480c authored by ppwwyyxx's avatar ppwwyyxx

flag select gpu

parent 838b1df7
......@@ -41,9 +41,8 @@ def get_model(inputs):
pool0 = MaxPooling('pool0', conv0, 2)
conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
pool1 = MaxPooling('pool1', conv1, 2)
conv2 = Conv2D('conv2', pool1, out_channel=32, kernel_shape=3)
fc0 = FullyConnected('fc0', conv2, 1024)
fc0 = FullyConnected('fc0', pool1, 1024)
fc0 = tf.nn.dropout(fc0, keep_prob)
# fc will have activation summary by default. disable this for the output layer
......@@ -56,16 +55,14 @@ def get_model(inputs):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(COST_VARS_KEY, cost)
# compute the number of failed samples, for ValidationErro to use at test time
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training accuracy
# monitor training error
tf.add_to_collection(
SUMMARY_VARS_KEY,
tf.sub(1.0, tf.reduce_mean(wrong), name='train_error'))
SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
......@@ -86,6 +83,7 @@ def get_config():
sess_config = tf.ConfigProto()
sess_config.device_count['GPU'] = 1
sess_config.allow_soft_placement = True
# prepare model
image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
......@@ -117,12 +115,6 @@ def get_config():
max_epoch=100,
)
def main(argv=None):
with tf.Graph().as_default():
from train import prepare, start_train
prepare()
config = get_config()
start_train(config)
if __name__ == '__main__':
tf.app.run()
from train import main
main(get_config)
......@@ -7,6 +7,7 @@ import tensorflow as tf
from utils import *
from dataflow import DataFlow
from itertools import count
import argparse
def prepare():
keep_prob = tf.placeholder(
......@@ -92,3 +93,17 @@ def start_train(config):
callbacks.trigger_step(feed, outputs, cost)
callbacks.trigger_epoch()
def main(get_config_func):
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='GPU(s) to use.') # nargs='*' in multi mode
args = parser.parse_args()
device = '/cpu:0'
if args.gpu:
device = '/gpu:{}'.format(args.gpu)
with tf.Graph().as_default():
with tf.device(device):
prepare()
config = get_config_func()
start_train(config)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment