Repository: seminar-breakout (Shashank Suhas)

Commit f4507d45
Authored Feb 24, 2016 by Yuxin Wu
Parent: f18314d6

    dump script & prefetch size
Showing 4 changed files with 14 additions and 9 deletions:

    example_cifar10.py               +8  -4
    scripts/dump_train_config.py     +2  -4
    tensorpack/callbacks/common.py   +1  -1
    tensorpack/dataflow/prefetch.py  +3  -0
example_cifar10.py

@@ -91,10 +91,6 @@ class Model(ModelDesc):
         return tf.add_n([cost, wd_cost], name='cost')

 def get_config():
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(
-        os.path.join('train_log', basename[:basename.rfind('.')]))
-
     # prepare dataset
     dataset_train = dataset.Cifar10('train')
     augmentors = [
@@ -102,10 +98,13 @@ def get_config():
         imgaug.Flip(horiz=True),
         imgaug.BrightnessAdd(63),
         imgaug.Contrast((0.2, 1.8)),
         #imgaug.GaussianDeform([(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
                               #(30,30), 0.2, 3),
         imgaug.MeanVarianceNormalize(all_channel=True)
     ]
     dataset_train = AugmentImageComponent(dataset_train, augmentors)
     dataset_train = BatchData(dataset_train, 128)
+    #dataset_train = PrefetchData(dataset_train, 3, 2)
     step_per_epoch = dataset_train.size()

     augmentors = [
@@ -145,6 +144,11 @@ if __name__ == '__main__':
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
+
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(
+        os.path.join('train_log', basename[:basename.rfind('.')]))
+
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     else:
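Net effect for this file: the log-directory setup (basename / logger.set_logger_dir) moves out of get_config() and into the __main__ block right after argument parsing, and a commented-out PrefetchData wrapper is left in the training pipeline as a hint. For orientation, the pipeline with that hint enabled would read roughly as the sketch below; it assumes the example's existing tensorpack imports are in scope and that the two trailing PrefetchData arguments are the prefetch queue length and the worker count (neither is spelled out in this commit).

    # Sketch only, condensed from the hunks above (not a verbatim copy of the file).
    dataset_train = dataset.Cifar10('train')
    dataset_train = AugmentImageComponent(dataset_train, augmentors)  # augmentors list as in the diff
    dataset_train = BatchData(dataset_train, 128)
    dataset_train = PrefetchData(dataset_train, 3, 2)  # the line the commit leaves commented out
    step_per_epoch = dataset_train.size()  # PrefetchData changes size(); see tensorpack/dataflow/prefetch.py below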
scripts/dump_train_config.py

@@ -9,10 +9,10 @@ import tensorflow as tf
 import imp
 import tqdm
 import os

 from tensorpack.utils import logger
 from tensorpack.utils.utils import mkdir_p

 parser = argparse.ArgumentParser()
 parser.add_argument(dest='config')
 parser.add_argument('-o', '--output', help='output directory to dump dataset image')
@@ -23,8 +23,6 @@ parser.add_argument('-n', '--number', help='number of images to dump',
                     default=10, type=int)
 args = parser.parse_args()

 get_config_func = imp.load_source('config_script', args.config).get_config
 config = get_config_func()
@@ -39,7 +37,7 @@ if args.output:
         for bi, img in enumerate(imgbatch):
             cnt += 1
             fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
-            cv2.imwrite(fname, img * 255)
+            cv2.imwrite(fname, img)

 NR_DP_TEST = 100
 logger.info("Testing dataflow speed:")
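Taken together, the dump-script hunks leave a small utility that loads get_config() from a training script via imp.load_source and writes images from its dataflow into the output directory, now without the old * 255 scaling; the dropped scaling suggests the dataflow is expected to yield values already in a writable 0-255 range rather than [0, 1] floats (an inference, not stated in the commit). A hypothetical invocation, using only the arguments visible above, would be: python scripts/dump_train_config.py example_cifar10.py -o dumped_imgs -n 32. The dump loop, restated as a self-contained sketch:

    import os
    import cv2  # the script writes images with OpenCV, as seen in the hunk above

    def dump_images(imgbatch, output_dir, cnt=0):
        # Hypothetical helper name; mirrors the loop in the diff: number each image
        # and write it as <count>-<index-in-batch>.png under output_dir.
        for bi, img in enumerate(imgbatch):
            cnt += 1
            fname = os.path.join(output_dir, '{:03d}-{}.png'.format(cnt, bi))
            cv2.imwrite(fname, img)  # no * 255: img assumed to already be in 0-255 range
        return cnt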
tensorpack/callbacks/common.py

@@ -15,11 +15,11 @@ __all__ = ['PeriodicSaver']

 class PeriodicSaver(PeriodicCallback):
     def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
         super(PeriodicSaver, self).__init__(period)
-        self.path = os.path.join(logger.LOG_DIR, 'model')
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq

     def _before_train(self):
+        self.path = os.path.join(logger.LOG_DIR, 'model')
         self.saver = tf.train.Saver(
             max_to_keep=self.keep_recent,
             keep_checkpoint_every_n_hours=self.keep_freq)
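The only change here moves the checkpoint-path computation from __init__ into _before_train. A plausible reason, not stated in the commit but consistent with example_cifar10.py now calling logger.set_logger_dir only inside __main__: logger.LOG_DIR is a module-level value that set_logger_dir rewrites, so a PeriodicSaver constructed before the log directory is configured would otherwise capture a stale path. Deferring the lookup to training start tolerates that ordering, roughly:

    # Sketch of the ordering this change tolerates. The PeriodicSaver import path is
    # inferred from the file location; the trainer is assumed to call _before_train
    # once training starts.
    from tensorpack.utils import logger
    from tensorpack.callbacks.common import PeriodicSaver

    saver_cb = PeriodicSaver(period=1)                   # logger.LOG_DIR may not be final yet
    logger.set_logger_dir('train_log/example_cifar10')   # happens later, in __main__
    # When the trainer calls saver_cb._before_train(), self.path is built from the
    # now-configured logger.LOG_DIR, i.e. train_log/example_cifar10/model.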
tensorpack/dataflow/prefetch.py

@@ -39,6 +39,9 @@ class PrefetchData(DataFlow):
         self.nr_proc = nr_proc
         self.nr_prefetch = nr_prefetch

+    def size(self):
+        return self.ds.size() * self.nr_proc
+
     def get_data(self):
         queue = multiprocessing.Queue(self.nr_prefetch)
         procs = [PrefetchProcess(self.ds, queue)
                  for _ in range(self.nr_proc)]
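The added size() is the "prefetch size" part of the commit message: get_data() spawns nr_proc PrefetchProcess workers that each iterate the wrapped dataflow on their own, so one pass over PrefetchData yields roughly nr_proc copies of the underlying data, and size() now reports exactly that. A short sketch of the consequence (keyword names taken from the attributes assigned in __init__ above; the tensorpack.dataflow import path is assumed):

    from tensorpack.dataflow import dataset, BatchData, PrefetchData  # assumed import path

    flow = BatchData(dataset.Cifar10('train'), 128)
    prefetched = PrefetchData(flow, nr_prefetch=3, nr_proc=2)
    # Each of the 2 worker processes runs `flow` independently, so an epoch over
    # `prefetched` covers the batches twice -- hence the new definition:
    print(prefetched.size() == flow.size() * 2)  # True

Note the interaction with example_cifar10.py above, where step_per_epoch = dataset_train.size(): enabling the commented-out PrefetchData line would also scale the reported steps per epoch by the worker count.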