Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8c674cc4
Commit
8c674cc4
authored
Dec 30, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
cifar test and load
parent
f16643ac
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
8 deletions
+13
-8
example_cifar10.py
example_cifar10.py
+11
-2
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+0
-6
tensorpack/utils/callback.py
tensorpack/utils/callback.py
+2
-0
No files found.
example_cifar10.py
View file @
8c674cc4
...
@@ -13,6 +13,7 @@ from tensorpack.utils import *
...
@@ -13,6 +13,7 @@ from tensorpack.utils import *
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.symbolic_functions
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.summary
import
*
from
tensorpack.utils.callback
import
*
from
tensorpack.utils.callback
import
*
from
tensorpack.utils.sessinit
import
*
from
tensorpack.utils.validation_callback
import
*
from
tensorpack.utils.validation_callback
import
*
from
tensorpack.dataflow.dataset
import
Cifar10
from
tensorpack.dataflow.dataset
import
Cifar10
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow
import
*
...
@@ -83,7 +84,9 @@ def get_config():
...
@@ -83,7 +84,9 @@ def get_config():
dataset_train
=
Cifar10
(
'train'
)
dataset_train
=
Cifar10
(
'train'
)
dataset_train
=
MapData
(
dataset_train
,
lambda
img
:
cv2
.
resize
(
img
,
(
24
,
24
)))
dataset_train
=
MapData
(
dataset_train
,
lambda
img
:
cv2
.
resize
(
img
,
(
24
,
24
)))
dataset_train
=
BatchData
(
dataset_train
,
128
)
dataset_train
=
BatchData
(
dataset_train
,
128
)
#dataset_test = BatchData(Cifar10('test'), 128)
dataset_test
=
Cifar10
(
'test'
)
dataset_test
=
MapData
(
dataset_test
,
lambda
img
:
cv2
.
resize
(
img
,
(
24
,
24
)))
dataset_test
=
BatchData
(
dataset_test
,
128
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
dataset_train
.
size
()
#step_per_epoch = 20
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
#dataset_test = FixedSizeData(dataset_test, 20)
...
@@ -115,7 +118,7 @@ def get_config():
...
@@ -115,7 +118,7 @@ def get_config():
callback
=
Callbacks
([
callback
=
Callbacks
([
SummaryWriter
(),
SummaryWriter
(),
PeriodicSaver
(),
PeriodicSaver
(),
#
ValidationError(dataset_test, prefix='test'),
ValidationError
(
dataset_test
,
prefix
=
'test'
),
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
inputs
=
input_vars
,
inputs
=
input_vars
,
...
@@ -129,6 +132,8 @@ if __name__ == '__main__':
...
@@ -129,6 +132,8 @@ if __name__ == '__main__':
from
tensorpack
import
train
from
tensorpack
import
train
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
global
args
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
...
@@ -136,4 +141,8 @@ if __name__ == '__main__':
...
@@ -136,4 +141,8 @@ if __name__ == '__main__':
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
train
.
prepare
()
train
.
prepare
()
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
[
'session_init'
]
=
SaverRestore
(
args
.
load
)
sess_init
=
NewSession
()
train
.
start_train
(
config
)
train
.
start_train
(
config
)
tensorpack/dataflow/dataset/cifar10.py
View file @
8c674cc4
...
@@ -2,12 +2,6 @@
...
@@ -2,12 +2,6 @@
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: cifar10.py
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
os
,
sys
import
os
,
sys
import
cPickle
import
cPickle
import
numpy
import
numpy
...
...
tensorpack/utils/callback.py
View file @
8c674cc4
...
@@ -166,6 +166,8 @@ class TestCallbacks(Callback):
...
@@ -166,6 +166,8 @@ class TestCallbacks(Callback):
cb
.
before_train
()
cb
.
before_train
()
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
if
not
self
.
cbs
:
return
tm
=
CallbackTimeLogger
()
tm
=
CallbackTimeLogger
()
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
s
=
time
.
time
()
s
=
time
.
time
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment