Commit 8c674cc4 authored by ppwwyyxx's avatar ppwwyyxx

cifar test and load

parent f16643ac
...@@ -13,6 +13,7 @@ from tensorpack.utils import * ...@@ -13,6 +13,7 @@ from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import * from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import * from tensorpack.utils.summary import *
from tensorpack.utils.callback import * from tensorpack.utils.callback import *
from tensorpack.utils.sessinit import *
from tensorpack.utils.validation_callback import * from tensorpack.utils.validation_callback import *
from tensorpack.dataflow.dataset import Cifar10 from tensorpack.dataflow.dataset import Cifar10
from tensorpack.dataflow import * from tensorpack.dataflow import *
...@@ -83,7 +84,9 @@ def get_config(): ...@@ -83,7 +84,9 @@ def get_config():
dataset_train = Cifar10('train') dataset_train = Cifar10('train')
dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24))) dataset_train = MapData(dataset_train, lambda img: cv2.resize(img, (24, 24)))
dataset_train = BatchData(dataset_train, 128) dataset_train = BatchData(dataset_train, 128)
#dataset_test = BatchData(Cifar10('test'), 128) dataset_test = Cifar10('test')
dataset_test = MapData(dataset_test, lambda img: cv2.resize(img, (24, 24)))
dataset_test = BatchData(dataset_test, 128)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
#step_per_epoch = 20 #step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20) #dataset_test = FixedSizeData(dataset_test, 20)
...@@ -115,7 +118,7 @@ def get_config(): ...@@ -115,7 +118,7 @@ def get_config():
callback=Callbacks([ callback=Callbacks([
SummaryWriter(), SummaryWriter(),
PeriodicSaver(), PeriodicSaver(),
#ValidationError(dataset_test, prefix='test'), ValidationError(dataset_test, prefix='test'),
]), ]),
session_config=sess_config, session_config=sess_config,
inputs=input_vars, inputs=input_vars,
...@@ -129,6 +132,8 @@ if __name__ == '__main__': ...@@ -129,6 +132,8 @@ if __name__ == '__main__':
from tensorpack import train from tensorpack import train
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
global args
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...@@ -136,4 +141,8 @@ if __name__ == '__main__': ...@@ -136,4 +141,8 @@ if __name__ == '__main__':
with tf.Graph().as_default(): with tf.Graph().as_default():
train.prepare() train.prepare()
config = get_config() config = get_config()
if args.load:
config['session_init'] = SaverRestore(args.load)
sess_init = NewSession()
train.start_train(config) train.start_train(config)
...@@ -2,12 +2,6 @@ ...@@ -2,12 +2,6 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: cifar10.py # File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys import os, sys
import cPickle import cPickle
import numpy import numpy
......
...@@ -166,6 +166,8 @@ class TestCallbacks(Callback): ...@@ -166,6 +166,8 @@ class TestCallbacks(Callback):
cb.before_train() cb.before_train()
def trigger_epoch(self): def trigger_epoch(self):
if not self.cbs:
return
tm = CallbackTimeLogger() tm = CallbackTimeLogger()
with self.graph.as_default(), self.sess.as_default(): with self.graph.as_default(), self.sess.as_default():
s = time.time() s = time.time()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment