Commit 20887a79 authored by ppwwyyxx

add prepare

parent 28599036
@@ -17,9 +17,6 @@ from utils import *
 from dataflow.dataset import Mnist
 from dataflow import *
 
-IMAGE_SIZE = 28
-LOG_DIR = 'train_log'
-
 def get_model(inputs):
     """
     Args:
@@ -33,11 +30,11 @@ def get_model(inputs):
         cost: scalar variable
     """
     # use this variable in dropout! Tensorpack will automatically set it to 1 at test time
-    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
+    keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
 
     image, label = inputs
-    image = tf.reshape(image, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
+    image = tf.expand_dims(image, 3)
 
     conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5,
                    padding='valid')
     pool0 = MaxPooling('pool0', conv0, 2)
@@ -76,48 +73,52 @@ def get_model(inputs):
     return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
 
-def main(argv=None):
+def get_config():
+    IMAGE_SIZE = 28
+    LOG_DIR = 'train_log'
     BATCH_SIZE = 128
+
+    dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
+    dataset_test = BatchData(Mnist('test'), 256, remainder=True)
+
+    sess_config = tf.ConfigProto()
+    sess_config.device_count['GPU'] = 1
+
+    # prepare model
+    image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
+    label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
+    input_vars = [image_var, label_var]
+    output_vars, cost_var = get_model(input_vars)
+    add_histogram_summary('.*/W')   # monitor histogram of all W
+
+    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    lr = tf.train.exponential_decay(
+        learning_rate=1e-4,
+        global_step=global_step_var,
+        decay_steps=dataset_train.size() * 50,
+        decay_rate=0.1, staircase=True, name='learning_rate')
+    tf.scalar_summary('learning_rate', lr)
+
+    return dict(
+        dataset_train=dataset_train,
+        optimizer=tf.train.AdamOptimizer(lr),
+        callbacks=[
+            SummaryWriter(LOG_DIR),
+            ValidationError(dataset_test, prefix='test'),
+            PeriodicSaver(LOG_DIR),
+        ],
+        session_config=sess_config,
+        inputs=input_vars,
+        outputs=output_vars,
+        cost=cost_var,
+        max_epoch=100,
+    )
+
+def main(argv=None):
     with tf.Graph().as_default():
-        dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
-        dataset_test = BatchData(Mnist('test'), 256, remainder=True)
-
-        sess_config = tf.ConfigProto()
-        sess_config.device_count['GPU'] = 1
-
-        # prepare model
-        image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
-        label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
-        input_vars = [image_var, label_var]
-        output_vars, cost_var = get_model(input_vars)
-        add_histogram_summary('.*/W')   # monitor histogram of all W
-
-        global_step_var = tf.Variable(
-            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        lr = tf.train.exponential_decay(
-            learning_rate=1e-4,
-            global_step=global_step_var,
-            decay_steps=dataset_train.size() * 50,
-            decay_rate=0.1, staircase=True, name='learning_rate')
-        tf.scalar_summary('learning_rate', lr)
-
-        config = dict(
-            dataset_train=dataset_train,
-            optimizer=tf.train.AdamOptimizer(lr),
-            callbacks=[
-                ValidationError(
-                    dataset_test,
-                    prefix='test'),
-                PeriodicSaver(LOG_DIR, period=1),
-                SummaryWriter(LOG_DIR),
-            ],
-            session_config=sess_config,
-            inputs=input_vars,
-            outputs=output_vars,
-            cost=cost_var,
-            max_epoch=100,
-        )
-
-        from train import start_train
+        from train import prepare, start_train
+        prepare()
+        config = get_config()
         start_train(config)
 
 if __name__ == '__main__':
...
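
Note: the change in get_model() above relies on TensorFlow's name-based tensor lookup: a tensor's name is its op's name plus an output index, so the placeholder that prepare() registers under DROPOUT_PROB_OP_NAME can later be fetched via DROPOUT_PROB_VAR_NAME without being passed around. A minimal sketch of that pattern, assuming illustrative values for the two constants (they live in utils and are not part of this diff):

import tensorflow as tf

# Assumed values for illustration; the real constants are defined in utils.
DROPOUT_PROB_OP_NAME = 'dropout_prob'
DROPOUT_PROB_VAR_NAME = DROPOUT_PROB_OP_NAME + ':0'   # tensor name = "<op name>:<output index>"

with tf.Graph().as_default():
    # what prepare() does: create the placeholder once, under a well-known op name
    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
    # what get_model() now does: fetch the same tensor back by name
    fetched = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
    assert fetched is keep_prob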
@@ -7,6 +7,13 @@ import tensorflow as tf
 from utils import *
 from itertools import count
 
+def prepare():
+    keep_prob = tf.placeholder(
+        tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
+    global_step_var = tf.Variable(
+        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+
 def start_train(config):
     """
     Start training with the given config
@@ -40,11 +47,7 @@ def start_train(config):
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
 
-    try:
-        global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
-    except KeyError:   # not created
-        global_step_var = tf.Variable(
-            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+    global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     # add some summary ops to the graph
     averager = tf.train.ExponentialMovingAverage(
...
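
Note: start_train() no longer creates the global step on demand; after this commit, the get_tensor_by_name() lookup raises KeyError unless prepare() has already run in the current graph. A sketch of the new contract, again assuming an illustrative value for the constant from utils:

import tensorflow as tf

# Assumed value for illustration; the real constant is defined in utils.
GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = GLOBAL_STEP_OP_NAME + ':0'

with tf.Graph().as_default() as G:
    try:
        G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)   # nothing registered yet
    except KeyError:
        print('global_step missing: call prepare() before start_train()')

    # what prepare() does:
    tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
    # the lookup in start_train() now succeeds:
    global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)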