Commit d382ea9d authored by Yuxin Wu's avatar Yuxin Wu

fix cifar bug

parent 28e42e11
...@@ -157,8 +157,6 @@ if __name__ == '__main__': ...@@ -157,8 +157,6 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
with tf.Graph().as_default(): with tf.Graph().as_default():
with tf.device('/cpu:0'): with tf.device('/cpu:0'):
......
...@@ -96,8 +96,8 @@ def get_config(): ...@@ -96,8 +96,8 @@ def get_config():
lr = tf.train.exponential_decay( lr = tf.train.exponential_decay(
learning_rate=1e-3, learning_rate=1e-3,
global_step=get_global_step_var(), global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 20, decay_steps=dataset_train.size() * 10,
decay_rate=0.1, staircase=True, name='learning_rate') decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return TrainConfig( return TrainConfig(
......
...@@ -6,6 +6,7 @@ import os, sys ...@@ -6,6 +6,7 @@ import os, sys
import pickle import pickle
import numpy as np import numpy as np
import random import random
import six
from six.moves import urllib, range from six.moves import urllib, range
import copy import copy
import tarfile import tarfile
...@@ -44,7 +45,10 @@ def read_cifar10(filenames): ...@@ -44,7 +45,10 @@ def read_cifar10(filenames):
ret = [] ret = []
for fname in filenames: for fname in filenames:
fo = open(fname, 'rb') fo = open(fname, 'rb')
if six.PY3:
dic = pickle.load(fo, encoding='bytes') dic = pickle.load(fo, encoding='bytes')
else:
dic = pickle.load(fo)
data = dic[b'data'] data = dic[b'data']
label = dic[b'labels'] label = dic[b'labels']
fo.close() fo.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment