Shashank Suhas / seminar-breakout / Commits

Commit d646972d
Authored Apr 14, 2016 by Yuxin Wu
Parent: 9f1af4c8

    better prefetch & periodic callback as wrapper

Showing 12 changed files with 54 additions and 51 deletions (+54 -51):
examples/cifar10_convnet.py         +3  -4
examples/cifar10_resnet.py          +1  -1
examples/load_alexnet.py            +1  -1
examples/mnist_convnet.py           +1  -1
examples/svhn_digit_convnet.py      +1  -1
tensorpack/callbacks/base.py        +12 -6
tensorpack/callbacks/common.py      +5  -7
tensorpack/callbacks/group.py       +1  -1
tensorpack/callbacks/summary.py     +1  -1
tensorpack/dataflow/prefetch.py     +25 -26
tensorpack/tfutils/gradproc.py      +1  -1
tensorpack/utils/utils.py           +2  -1
examples/cifar10_convnet.py

 #!/usr/bin/env python
 # -*- coding: UTF-8 -*-
-# File: argscope_test.py
+# File: cifar10_convnet.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import tensorflow as tf

@@ -116,8 +116,7 @@ def get_config():
     step_per_epoch = dataset_train.size()
     dataset_test = get_data('test')

-    sess_config = get_default_sess_config()
-    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
+    sess_config = get_default_sess_config(0.5)

     nr_gpu = get_nr_gpu()
     lr = tf.train.exponential_decay(

@@ -132,7 +131,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ClassificationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
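Aside from the renamed callback, the only change here is that the GPU memory fraction is now passed directly to get_default_sess_config(0.5) instead of being patched onto the returned config afterwards. A rough sketch of what such a helper would have to do, inferred from this call site rather than from the tensorpack source:

import tensorflow as tf

def get_default_sess_config(mem_fraction=0.5):
    # Assumed shape of the helper: cap the per-process GPU memory fraction
    # on a plain ConfigProto; the default value here is a guess.
    conf = tf.ConfigProto()
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.allow_soft_placement = True
    return conf

sess_config = get_default_sess_config(0.5)  # as used in the updated example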
examples/cifar10_resnet.py

@@ -166,7 +166,7 @@ def get_config():
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ClassificationError(dataset_test, prefix='test'),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
examples/load_alexnet.py

@@ -106,7 +106,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             #ValidationError(dataset_test, prefix='test'),
         ]),
         session_config=sess_config,
examples/mnist_convnet.py

@@ -105,7 +105,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ValidationStatPrinter(dataset_test, ['cost:0']),
             ClassificationError(dataset_test, prefix='validation'),
         ]),
examples/svhn_digit_convnet.py

@@ -109,7 +109,7 @@ def get_config():
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
             StatPrinter(),
-            PeriodicSaver(),
+            ModelSaver(),
             ClassificationError(test, prefix='test'),
         ]),
         session_config=sess_config,
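All five example configs make the same substitution: PeriodicSaver() becomes ModelSaver(). Per the callbacks/common.py change further down, ModelSaver now saves after every epoch; if saving only every N epochs is still wanted, the reworked PeriodicCallback (see callbacks/base.py below) can presumably wrap it. A hypothetical usage sketch, not taken from the repository:

# Hypothetical: keep an every-3-epochs save by wrapping the new callback.
callbacks = Callbacks([
    StatPrinter(),
    PeriodicCallback(ModelSaver(), 3),
    ClassificationError(dataset_test, prefix='test'),
])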
tensorpack/callbacks/base.py

@@ -81,18 +81,24 @@ class Callback(object):
 class PeriodicCallback(Callback):
     """
     A callback to be triggered after every `period` epochs.
+    Doesn't work for trigger_step
     """
-    def __init__(self, period):
+    def __init__(self, cb, period):
         """
+        :param cb: a `Callback`
         :param period: int
         """
+        self.cb = cb
         self.period = int(period)

+    def _before_train(self):
+        self.cb.before_train(self.trainer)
+
+    def _after_train(self):
+        self.cb.after_train()
+
     def _trigger_epoch(self):
+        self.cb.epoch_num = self.epoch_num - 1
         if self.epoch_num % self.period == 0:
-            self._trigger_periodic()
-
-    @abstractmethod
-    def _trigger_periodic(self):
-        pass
+            self.cb.trigger_epoch()
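PeriodicCallback thus changes from an abstract base class (subclasses had to implement _trigger_periodic) into a concrete wrapper: it forwards before_train/after_train to an inner callback and fires that callback's trigger_epoch only every `period` epochs. A simplified, self-contained illustration of the same wrapper pattern; the class and method names below are stand-ins, not the tensorpack API:

class Callback(object):
    # Minimal stand-in for a callback interface.
    def before_train(self, trainer): pass
    def after_train(self): pass
    def trigger_epoch(self): pass

class SaveModel(Callback):
    def trigger_epoch(self):
        print("saving model")

class Periodic(Callback):
    # Wrap another callback so it only fires every `period` epochs.
    def __init__(self, cb, period):
        self.cb = cb
        self.period = int(period)
        self.epoch_num = 0

    def before_train(self, trainer):
        self.cb.before_train(trainer)

    def after_train(self):
        self.cb.after_train()

    def trigger_epoch(self):
        self.epoch_num += 1
        if self.epoch_num % self.period == 0:
            self.cb.trigger_epoch()

cb = Periodic(SaveModel(), period=3)
cb.before_train(trainer=None)
for _ in range(9):
    cb.trigger_epoch()      # prints "saving model" at epochs 3, 6 and 9
cb.after_train()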
tensorpack/callbacks/common.py

@@ -6,22 +6,20 @@ import tensorflow as tf
 import os
 import re

-from .base import Callback, PeriodicCallback
+from .base import Callback
 from ..utils import *

-__all__ = ['PeriodicSaver']
+__all__ = ['ModelSaver']

-class PeriodicSaver(PeriodicCallback):
+class ModelSaver(Callback):
     """
     Save the model to logger directory.
     """
-    def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
+    def __init__(self, keep_recent=10, keep_freq=0.5):
         """
-        :param period: number of epochs to save models.
         :param keep_recent: see `tf.train.Saver` documentation.
         :param keep_freq: see `tf.train.Saver` documentation.
         """
-        super(PeriodicSaver, self).__init__(period)
         self.keep_recent = keep_recent
         self.keep_freq = keep_freq

@@ -48,7 +46,7 @@ class PeriodicSaver(PeriodicCallback):
             var_dict[name] = v
         return var_dict

-    def _trigger_periodic(self):
+    def _trigger_epoch(self):
         self.saver.save(
             tf.get_default_session(),
             self.path,
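ModelSaver keeps the keep_recent and keep_freq arguments, documented as tf.train.Saver options; presumably they map onto max_to_keep and keep_checkpoint_every_n_hours, the two real Saver parameters with matching semantics. A minimal sketch of that assumed mapping:

import tensorflow as tf

w = tf.Variable(0, name='w')          # a Saver needs at least one variable
keep_recent, keep_freq = 10, 0.5      # ModelSaver's defaults
saver = tf.train.Saver(
    max_to_keep=keep_recent,                  # keep the N most recent checkpoints
    keep_checkpoint_every_n_hours=keep_freq)  # plus one every keep_freq hours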
tensorpack/callbacks/group.py

@@ -80,7 +80,7 @@ class TestCallbackContext(object):
         ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
         if ckpt is None:
             raise RuntimeError(
-                "Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?")
+                "Cannot find a checkpoint state. Do you forget to use ModelSaver before all TestCallback?")
         logger.info(
             "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
         self.saver.restore(self.sess, ckpt.model_checkpoint_path)
tensorpack/callbacks/summary.py

@@ -8,7 +8,7 @@ import os
 import operator
 import pickle

-from .base import Callback, PeriodicCallback
+from .base import Callback
 from ..utils import *

 __all__ = ['StatHolder', 'StatPrinter']
tensorpack/dataflow/prefetch.py

@@ -4,7 +4,7 @@
 import multiprocessing

-from .base import DataFlow
+from .base import ProxyDataFlow
 from ..utils.concurrency import ensure_procs_terminate

 __all__ = ['PrefetchData']

@@ -30,7 +30,7 @@ class PrefetchProcess(multiprocessing.Process):
         finally:
             self.queue.put(Sentinel())

-class PrefetchData(DataFlow):
+class PrefetchData(ProxyDataFlow):
     """
     Prefetch data from a `DataFlow` using multiprocessing
     """

@@ -40,35 +40,34 @@ class PrefetchData(DataFlow):
         :param nr_prefetch: size of the queue to hold prefetched datapoints.
         :param nr_proc: number of processes to use.
         """
-        self.ds = ds
-        self._size = self.ds.size()
+        super(PrefetchData, self).__init__(ds)
+        self._size = self.size()
         self.nr_proc = nr_proc
         self.nr_prefetch = nr_prefetch
-        self.queue = multiprocessing.Queue(self.nr_prefetch)
-        self.procs = [PrefetchProcess(self.ds, self.queue)
-                      for _ in range(self.nr_proc)]
-        ensure_procs_terminate(self.procs)
-        for x in self.procs:
-            x.start()
+
+    def size(self):
+        return self._size

     def get_data(self):
+        queue = multiprocessing.Queue(self.nr_prefetch)
+        procs = [PrefetchProcess(self.ds, queue)
+                 for _ in range(self.nr_proc)]
+        ensure_procs_terminate(procs)
+        [x.start() for x in procs]
         end_cnt = 0
         tot_cnt = 0
-        while True:
-            dp = self.queue.get()
-            if isinstance(dp, Sentinel):
-                end_cnt += 1
-                if end_cnt == self.nr_proc:
-                    break
-                continue
-            tot_cnt += 1
-            yield dp
-            if tot_cnt == self._size:
-                break
-
-    def __del__(self):
-        self.queue.close()
-        for x in self.procs:
-            x.terminate()
+        try:
+            while True:
+                dp = queue.get()
+                if isinstance(dp, Sentinel):
+                    end_cnt += 1
+                    if end_cnt == self.nr_proc:
+                        break
+                    continue
+                tot_cnt += 1
+                yield dp
+                if tot_cnt == self._size:
+                    break
+        finally:
+            queue.close()
+            [x.terminate() for x in procs]
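The substantive change: the queue and the worker processes are no longer created once in __init__ and torn down in __del__; each call to get_data() spawns its own workers, counts one Sentinel per worker to detect when all of them are exhausted, and cleans up in a finally block, so an abandoned iterator no longer leaks a shared queue. A self-contained sketch of that pattern using only the standard library; worker, Sentinel and prefetch below are simplified stand-ins for PrefetchProcess and PrefetchData, not the tensorpack code:

import multiprocessing

class Sentinel(object):
    # Marker a worker puts on the queue when it has produced everything.
    pass

def worker(items, queue):
    try:
        for x in items:
            queue.put(x)
    finally:
        queue.put(Sentinel())

def prefetch(items, nr_proc=2, nr_prefetch=4):
    # Spawn the workers inside the generator, stop after nr_proc sentinels,
    # and always clean up, even if the consumer stops early.
    queue = multiprocessing.Queue(nr_prefetch)
    procs = [multiprocessing.Process(target=worker, args=(items, queue))
             for _ in range(nr_proc)]
    for p in procs:
        p.start()
    end_cnt = 0
    try:
        while True:
            dp = queue.get()
            if isinstance(dp, Sentinel):
                end_cnt += 1
                if end_cnt == nr_proc:
                    break
                continue
            yield dp
    finally:
        queue.close()
        for p in procs:
            p.terminate()

if __name__ == '__main__':
    print(sorted(prefetch(range(5))))   # each of the 2 workers yields 0..4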
tensorpack/tfutils/gradproc.py

@@ -61,7 +61,7 @@ class ScaleGradient(GradientProcessor):
         self.multipliers = multipliers

     def _process(self, grads):
-        # TODO use None for zero to speed up?
+        # TODO use None for zero can speed up (or not)?
         ret = []
         for grad, var in grads:
             varname = var.op.name
tensorpack/utils/utils.py

@@ -76,7 +76,8 @@ class memoized(object):
         return functools.partial(self.__call__, obj)

 def get_rng(self):
-    seed = (id(self) + os.getpid()) % 4294967295
+    seed = (id(self) + os.getpid() +
+            int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
     return np.random.RandomState(seed)

 def get_nr_gpu():
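The old seed, id(self) + os.getpid(), is not guaranteed to be unique across objects and runs (object ids get reused, and so do pids), so two RNGs could end up seeded identically; mixing in the current timestamp makes that far less likely. This relies on datetime being importable in utils.py; a standalone reproduction of the new seed computation:

import os
from datetime import datetime

import numpy as np

def make_seed(obj):
    # Mix object id, process id and the current time so that two processes
    # (or two calls) are very unlikely to share an RNG seed.
    return (id(obj) + os.getpid() +
            int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295

rng = np.random.RandomState(make_seed(object()))
print(rng.randint(10))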