Commit d382ea9d authored by Yuxin Wu's avatar Yuxin Wu

fix cifar bug

parent 28e42e11
......@@ -157,8 +157,6 @@ if __name__ == '__main__':
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
with tf.Graph().as_default():
with tf.device('/cpu:0'):
......
......@@ -96,8 +96,8 @@ def get_config():
lr = tf.train.exponential_decay(
learning_rate=1e-3,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 20,
decay_rate=0.1, staircase=True, name='learning_rate')
decay_steps=dataset_train.size() * 10,
decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
......
......@@ -6,6 +6,7 @@ import os, sys
import pickle
import numpy as np
import random
import six
from six.moves import urllib, range
import copy
import tarfile
......@@ -44,7 +45,10 @@ def read_cifar10(filenames):
ret = []
for fname in filenames:
fo = open(fname, 'rb')
dic = pickle.load(fo, encoding='bytes')
if six.PY3:
dic = pickle.load(fo, encoding='bytes')
else:
dic = pickle.load(fo)
data = dic[b'data']
label = dic[b'labels']
fo.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment