Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
25e9853a
Commit
25e9853a
authored
Sep 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
misc fix
parent
2e238998
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
7 additions
and
3 deletions
+7
-3
examples/README.md
examples/README.md
+0
-1
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+4
-2
tensorpack/train/config.py
tensorpack/train/config.py
+2
-0
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-0
No files found.
examples/README.md
View file @
25e9853a
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
Examples with __reproducible__ and meaningful performancce.
Examples with __reproducible__ and meaningful performancce.
+
[
An illustrative mnist example
](
mnist-convnet.py
)
+
[
An illustrative mnist example
](
mnist-convnet.py
)
+
[
A small Cifar10 ConvNet with 91% accuracy
](
cifar-convnet.py
)
+
[
A tiny SVHN ConvNet with 97.5% accuracy
](
svhn-digit-convnet.py
)
+
[
A tiny SVHN ConvNet with 97.5% accuracy
](
svhn-digit-convnet.py
)
+
[
Reproduce some reinforcement learning papers
](
Atari2600
)
+
[
Reproduce some reinforcement learning papers
](
Atari2600
)
+
[
char-rnn for fun
](
char-rnn
)
+
[
char-rnn for fun
](
char-rnn
)
...
...
tensorpack/callbacks/common.py
View file @
25e9853a
...
@@ -81,9 +81,10 @@ because {} will be saved".format(v.name, var_dict[name].name))
...
@@ -81,9 +81,10 @@ because {} will be saved".format(v.name, var_dict[name].name))
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
class
MinSaver
(
Callback
):
class
MinSaver
(
Callback
):
def
__init__
(
self
,
monitor_stat
,
reverse
=
True
):
def
__init__
(
self
,
monitor_stat
,
reverse
=
True
,
filename
=
None
):
self
.
monitor_stat
=
monitor_stat
self
.
monitor_stat
=
monitor_stat
self
.
reverse
=
reverse
self
.
reverse
=
reverse
self
.
filename
=
filename
self
.
min
=
None
self
.
min
=
None
def
_get_stat
(
self
):
def
_get_stat
(
self
):
...
@@ -107,7 +108,8 @@ class MinSaver(Callback):
...
@@ -107,7 +108,8 @@ class MinSaver(Callback):
"Cannot find a checkpoint state. Do you forget to use ModelSaver?"
)
"Cannot find a checkpoint state. Do you forget to use ModelSaver?"
)
path
=
chpt
.
model_checkpoint_path
path
=
chpt
.
model_checkpoint_path
newname
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
newname
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'max-'
if
self
.
reverse
else
'min-'
+
self
.
monitor_stat
)
self
.
filename
or
(
'max-'
if
self
.
reverse
else
'min-'
+
self
.
monitor_stat
+
'.tfmodel'
))
shutil
.
copy
(
path
,
newname
)
shutil
.
copy
(
path
,
newname
)
logger
.
info
(
"Model with {} '{}' saved."
.
format
(
logger
.
info
(
"Model with {} '{}' saved."
.
format
(
'maximum'
if
self
.
reverse
else
'minimum'
,
self
.
monitor_stat
))
'maximum'
if
self
.
reverse
else
'minimum'
,
self
.
monitor_stat
))
...
...
tensorpack/train/config.py
View file @
25e9853a
...
@@ -63,6 +63,8 @@ class TrainConfig(object):
...
@@ -63,6 +63,8 @@ class TrainConfig(object):
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
set_tower
(
self
,
nr_tower
=
None
,
tower
=
None
):
def
set_tower
(
self
,
nr_tower
=
None
,
tower
=
None
):
logger
.
warn
(
"config.set_tower is deprecated. set config.tower or config.nr_tower directly"
)
# this is a deprecated function
assert
nr_tower
is
None
or
tower
is
None
,
"Cannot set both nr_tower and tower!"
assert
nr_tower
is
None
or
tower
is
None
,
"Cannot set both nr_tower and tower!"
if
nr_tower
:
if
nr_tower
:
tower
=
list
(
range
(
nr_tower
))
tower
=
list
(
range
(
nr_tower
))
...
...
tensorpack/train/multigpu.py
View file @
25e9853a
...
@@ -22,6 +22,7 @@ class MultiGPUTrainer(QueueInputTrainer):
...
@@ -22,6 +22,7 @@ class MultiGPUTrainer(QueueInputTrainer):
""" Base class for multi-gpu training"""
""" Base class for multi-gpu training"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
super
(
MultiGPUTrainer
,
self
)
.
__init__
(
config
,
input_queue
,
predict_tower
)
super
(
MultiGPUTrainer
,
self
)
.
__init__
(
config
,
input_queue
,
predict_tower
)
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
self
.
dequed_inputs
=
[]
self
.
dequed_inputs
=
[]
@
staticmethod
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment