Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5d529d03
Commit
5d529d03
authored
Sep 15, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix bug in ProxyCallback when before_run is used
parent
087e66db
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
8 deletions
+9
-8
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+3
-2
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+1
-1
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+1
-2
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+1
-1
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+3
-2
No files found.
examples/ResNet/imagenet-resnet.py
View file @
5d529d03
...
...
@@ -79,9 +79,8 @@ class Model(ModelDesc):
def
get_data
(
name
,
batch
):
isTrain
=
name
==
'train'
augmentors
=
fbresnet_augmentor
(
isTrain
)
datadir
=
args
.
data
return
get_imagenet_dataflow
(
datadir
,
name
,
batch
,
augmentors
,
dir_structure
=
'original'
)
args
.
data
,
name
,
batch
,
augmentors
,
dir_structure
=
'original'
)
def
get_config
(
model
,
fake
=
False
):
...
...
@@ -106,8 +105,10 @@ def get_config(model, fake=False):
infs
=
[
ClassificationError
(
'wrong-top1'
,
'val-error-top1'
),
ClassificationError
(
'wrong-top5'
,
'val-error-top5'
)]
if
nr_tower
==
1
:
# single-GPU inference with queue prefetch
callbacks
.
append
(
InferenceRunner
(
QueueInput
(
dataset_val
),
infs
))
else
:
# multi-GPU inference (with mandatory queue prefetch)
callbacks
.
append
(
DataParallelInferenceRunner
(
dataset_val
,
infs
,
list
(
range
(
nr_tower
))))
...
...
tensorpack/callbacks/base.py
View file @
5d529d03
...
...
@@ -250,7 +250,7 @@ class ProxyCallback(Callback):
self
.
cb
.
after_epoch
()
def
_before_run
(
self
,
ctx
):
self
.
cb
.
_before_run
(
ctx
)
return
self
.
cb
.
_before_run
(
ctx
)
def
_after_run
(
self
,
ctx
,
run_values
):
self
.
cb
.
_after_run
(
ctx
,
run_values
)
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
5d529d03
...
...
@@ -146,8 +146,7 @@ class ILSVRC12Files(RNGDataFlow):
class
ILSVRC12
(
ILSVRC12Files
):
"""
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax].
Produces uint8 ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999].
"""
def
__init__
(
self
,
dir
,
name
,
meta_dir
=
None
,
shuffle
=
None
,
dir_structure
=
'original'
):
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
5d529d03
...
...
@@ -41,7 +41,7 @@ class SVHNDigit(RNGDataFlow):
if
not
os
.
path
.
isfile
(
filename
):
url
=
SVHN_URL
+
os
.
path
.
basename
(
filename
)
logger
.
info
(
"File {} not found!"
.
format
(
filename
))
logger
.
info
(
"Downloading from {}."
.
format
(
url
))
logger
.
info
(
"Downloading from {}
..
."
.
format
(
url
))
download
(
url
,
os
.
path
.
dirname
(
filename
))
logger
.
info
(
"Loading {} ..."
.
format
(
filename
))
data
=
scipy
.
io
.
loadmat
(
filename
)
...
...
tensorpack/models/regularize.py
View file @
5d529d03
...
...
@@ -24,8 +24,9 @@ l1_regularizer = tf.contrib.layers.l1_regularizer
def
regularize_cost
(
regex
,
func
,
name
=
'regularize_cost'
):
"""
Apply a regularizer on trainable variables matching the regex.
In replicated mode, will only regularize variables within the current tower.
Apply a regularizer on trainable variables matching the regex, and print
the matched variables (only print once in multi-tower training).
In replicated mode, it will only regularize variables within the current tower.
Args:
regex (str): a regex to match variable names, e.g. "conv.*/W"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment