Commit 8c674cc4 authored by ppwwyyxx's avatar ppwwyyxx

cifar test and load

parent f16643ac
......@@ -13,6 +13,7 @@ from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.utils.sessinit import *
from tensorpack.utils.validation_callback import *
from tensorpack.dataflow.dataset import Cifar10
from tensorpack.dataflow import *
......@@ -83,7 +84,9 @@ def get_config():
dataset_train = Cifar10('train')
dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24)))
dataset_train = BatchData(dataset_train, 128)
#dataset_test = BatchData(Cifar10('test'), 128)
dataset_test = Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
dataset_test = BatchData(dataset_test, 128)
step_per_epoch = dataset_train.size()
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
......@@ -115,7 +118,7 @@ def get_config():
callback=Callbacks([
SummaryWriter(),
PeriodicSaver(),
#ValidationError(dataset_test, prefix='test'),
ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
inputs=input_vars,
......@@ -129,6 +132,8 @@ if __name__ == '__main__':
from tensorpack import train
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
global args
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......@@ -136,4 +141,8 @@ if __name__ == '__main__':
with tf.Graph().as_default():
train.prepare()
config = get_config()
if args.load:
config['session_init'] = SaverRestore(args.load)
sess_init = NewSession()
train.start_train(config)
......@@ -2,12 +2,6 @@
# -*- coding: UTF-8 -*-
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys
import cPickle
import numpy
......
......@@ -166,6 +166,8 @@ class TestCallbacks(Callback):
cb.before_train()
def trigger_epoch(self):
if not self.cbs:
return
tm = CallbackTimeLogger()
with self.graph.as_default(), self.sess.as_default():
s = time.time()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment