Commit 20887a79 authored by ppwwyyxx

add prepare

parent 28599036
......@@ -17,9 +17,6 @@ from utils import *
from dataflow.dataset import Mnist
from dataflow import *
IMAGE_SIZE = 28
LOG_DIR = 'train_log'
def get_model(inputs):
"""
Args:
......@@ -33,11 +30,11 @@ def get_model(inputs):
cost: scalar variable
"""
# use this variable in dropout! Tensorpack will automatically set it to 1 at test time
keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
keep_prob = tf.get_default_graph().get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
image, label = inputs
image = tf.reshape(image, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
image = tf.expand_dims(image, 3)
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5,
padding='valid')
pool0 = MaxPooling('pool0', conv0, 2)
......@@ -76,48 +73,52 @@ def get_model(inputs):
return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
def main(argv=None):
def get_config():
    """
    Build the MNIST input placeholders, the model and the learning-rate
    schedule in the current default graph, and return the training config.

    The global step tensor named GLOBAL_STEP_VAR_NAME must already exist in
    the default graph when this is called; it is looked up, not created.

    Returns:
        dict with the keys dataset_train, optimizer, callbacks,
        session_config, inputs, outputs, cost and max_epoch.
    """
    image_size = 28
    log_dir = 'train_log'
    batch_size = 128

    train_data = BatchData(Mnist('train'), batch_size)
    test_data = BatchData(Mnist('test'), 256, remainder=True)

    session_conf = tf.ConfigProto()
    session_conf.device_count['GPU'] = 1

    # prepare model
    image_ph = tf.placeholder(
        tf.float32, shape=(None, image_size, image_size), name='input')
    label_ph = tf.placeholder(tf.int32, shape=(None,), name='label')
    model_inputs = [image_ph, label_ph]
    model_outputs, model_cost = get_model(model_inputs)
    add_histogram_summary('.*/W')   # monitor histogram of all W

    # Staircase decay: multiply the rate by 0.1 once every
    # `train_data.size() * 50` global steps
    # (presumably 50 epochs, if size() is batches per epoch -- confirm).
    step_tensor = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
    learning_rate = tf.train.exponential_decay(
        learning_rate=1e-4,
        global_step=step_tensor,
        decay_steps=train_data.size() * 50,
        decay_rate=0.1, staircase=True, name='learning_rate')
    tf.scalar_summary('learning_rate', learning_rate)

    config = {
        'dataset_train': train_data,
        'optimizer': tf.train.AdamOptimizer(learning_rate),
        'callbacks': [
            SummaryWriter(log_dir),
            ValidationError(test_data, prefix='test'),
            PeriodicSaver(log_dir),
        ],
        'session_config': session_conf,
        'inputs': model_inputs,
        'outputs': model_outputs,
        'cost': model_cost,
        'max_epoch': 100,
    }
    return config
def main(argv=None):
with tf.Graph().as_default():
dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
dataset_test = BatchData(Mnist('test'), 256, remainder=True)
sess_config = tf.ConfigProto()
sess_config.device_count['GPU'] = 1
# prepare model
image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
input_vars = [image_var, label_var]
output_vars, cost_var = get_model(input_vars)
add_histogram_summary('.*/W') # monitor histogram of all W
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
lr = tf.train.exponential_decay(
learning_rate=1e-4,
global_step=global_step_var,
decay_steps=dataset_train.size() * 50,
decay_rate=0.1, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
config = dict(
dataset_train=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
callbacks=[
ValidationError(
dataset_test,
prefix='test'),
PeriodicSaver(LOG_DIR, period=1),
SummaryWriter(LOG_DIR),
],
session_config=sess_config,
inputs=input_vars,
outputs=output_vars,
cost=cost_var,
max_epoch=100,
)
from train import start_train
from train import prepare, start_train
prepare()
config = get_config()
start_train(config)
if __name__ == '__main__':
......
......@@ -7,6 +7,13 @@ import tensorflow as tf
from utils import *
from itertools import count
def prepare():
    """
    Create the shared graph nodes that the rest of the code fetches by name.

    Adds two nodes to the current default graph:
      * a scalar float32 placeholder named DROPOUT_PROB_OP_NAME
      * a non-trainable variable named GLOBAL_STEP_OP_NAME, initialised to 0

    Call this inside the target graph before building the model or calling
    start_train: those look the nodes up with get_tensor_by_name (presumably
    via the matching *_VAR_NAME constants from utils) and no longer create
    them on demand.

    Returns nothing; the function is used only for its effect on the graph.
    """
    # The results are deliberately not bound to names: the nodes are only
    # ever retrieved through the graph by name.
    tf.placeholder(tf.float32, shape=(), name=DROPOUT_PROB_OP_NAME)
    tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
def start_train(config):
"""
Start training with the given config
......@@ -40,11 +47,7 @@ def start_train(config):
for v in output_vars:
G.add_to_collection(OUTPUT_VARS_KEY, v)
try:
global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
except KeyError: # not created
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
# add some summary ops to the graph
averager = tf.train.ExponentialMovingAverage(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment