Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1c3d8741
Commit
1c3d8741
authored
May 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
minsaver , jsonstat
parent
7207816d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
43 additions
and
11 deletions
+43
-11
examples/Inception/inception-bn.py
examples/Inception/inception-bn.py
+6
-5
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+20
-2
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+10
-4
tensorpack/train/base.py
tensorpack/train/base.py
+7
-0
No files found.
examples/Inception/inception-bn.py
View file @
1c3d8741
...
...
@@ -35,7 +35,7 @@ class Model(ModelDesc):
def
_get_cost
(
self
,
input_vars
,
is_training
):
image
,
label
=
input_vars
image
=
image
/
128.0
-
1
image
=
image
/
128.0
def
inception
(
name
,
x
,
nr1x1
,
nr3x3r
,
nr3x3
,
nr233r
,
nr233
,
nrpool
,
pooltype
):
stride
=
2
if
nr1x1
==
0
else
1
...
...
@@ -46,7 +46,6 @@ class Model(ModelDesc):
x2
=
Conv2D
(
'conv3x3r'
,
x
,
nr3x3r
,
1
)
outs
.
append
(
Conv2D
(
'conv3x3'
,
x2
,
nr3x3
,
3
,
stride
=
stride
))
x3
=
Conv2D
(
'conv233r'
,
x
,
nr233r
,
1
)
x3
=
Conv2D
(
'conv233a'
,
x3
,
nr233
,
3
)
outs
.
append
(
Conv2D
(
'conv233b'
,
x3
,
nr233
,
3
,
stride
=
stride
))
...
...
@@ -133,6 +132,8 @@ def get_data(train_or_test):
if
isTrain
:
augmentors
=
[
imgaug
.
Resize
((
256
,
256
)),
imgaug
.
Brightness
(
30
,
False
),
imgaug
.
Contrast
((
0.8
,
1.2
),
True
),
imgaug
.
MapImage
(
lambda
x
:
x
-
pp_mean
),
imgaug
.
RandomCrop
((
224
,
224
)),
imgaug
.
Flip
(
horiz
=
True
),
...
...
@@ -172,9 +173,9 @@ def get_config():
ClassificationError
(
'wrong-top5'
,
'val-top5-error'
)]),
#HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt')
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
8
,
0.03
),
(
13
,
0.02
),
(
21
,
5e-3
),
(
28
,
3e-3
),
(
33
,
1e-3
),
(
44
,
5
e-4
),
(
49
,
1e-4
),
(
59
,
2e-5
)
])
[(
8
,
0.03
),
(
13
,
0.02
),
(
16
,
5e-3
),
(
18
,
3e-3
),
(
24
,
1e-3
),
(
26
,
2
e-4
),
(
28
,
5e-5
)
])
]),
session_config
=
sess_config
,
model
=
Model
(),
...
...
tensorpack/callbacks/common.py
View file @
1c3d8741
...
...
@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
os
import
os
,
shutil
import
re
from
.base
import
Callback
...
...
@@ -56,6 +56,24 @@ class ModelSaver(Callback):
class
MinSaver
(
Callback
):
def
__init__
(
self
,
monitor_stat
):
self
.
monitor_stat
=
monitor_stat
self
.
min
=
None
def
_get_stat
(
self
):
return
self
.
trainer
.
stat_holder
.
get_stat_now
(
self
.
monitor_stat
)
def
_trigger_epoch
(
self
):
pass
if
self
.
min
is
None
or
self
.
_get_stat
()
<
self
.
min
:
self
.
min
=
self
.
_get_stat
()
self
.
_save
()
def
_save
(
self
):
ckpt
=
tf
.
train
.
get_checkpoint_state
(
logger
.
LOG_DIR
)
if
ckpt
is
None
:
raise
RuntimeError
(
"Cannot find a checkpoint state. Do you forget to use ModelSaver before MinSaver?"
)
path
=
chpt
.
model_checkpoint_path
newname
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'min_'
+
self
.
monitor_stat
)
shutil
.
copy
(
path
,
newname
)
logger
.
info
(
"Model with minimum {} saved."
.
format
(
self
.
monitor_stat
))
tensorpack/callbacks/summary.py
View file @
1c3d8741
...
...
@@ -6,7 +6,7 @@ import tensorflow as tf
import
re
import
os
import
operator
import
pickle
import
json
from
.base
import
Callback
from
..utils
import
*
...
...
@@ -25,11 +25,11 @@ class StatHolder(object):
self
.
stat_now
=
{}
self
.
log_dir
=
log_dir
self
.
filename
=
os
.
path
.
join
(
log_dir
,
'stat.
pkl
'
)
self
.
filename
=
os
.
path
.
join
(
log_dir
,
'stat.
json
'
)
if
os
.
path
.
isfile
(
self
.
filename
):
logger
.
info
(
"Loading stats from {}..."
.
format
(
self
.
filename
))
with
open
(
self
.
filename
)
as
f
:
self
.
stat_history
=
pickle
.
load
(
f
)
self
.
stat_history
=
json
.
load
(
f
)
else
:
self
.
stat_history
=
[]
...
...
@@ -47,6 +47,12 @@ class StatHolder(object):
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
get_stat_now
(
self
,
k
):
"""
Return the value of a stat in the current epoch.
"""
return
self
.
stat_now
[
k
]
def
finalize
(
self
):
"""
Called after finishing adding stats. Will print and write stats to disk.
...
...
@@ -64,7 +70,7 @@ class StatHolder(object):
def
_write_stat
(
self
):
tmp_filename
=
self
.
filename
+
'.tmp'
with
open
(
tmp_filename
,
'wb'
)
as
f
:
pickle
.
dump
(
self
.
stat_history
,
f
)
json
.
dump
(
self
.
stat_history
,
f
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
class
StatPrinter
(
Callback
):
...
...
tensorpack/train/base.py
View file @
1c3d8741
...
...
@@ -19,6 +19,13 @@ __all__ = ['Trainer']
class
Trainer
(
object
):
"""
Base class for a trainer.
Available Attritbutes:
stat_holder: a `StatHolder` instance
summary_writer: a `tf.SummaryWriter`
config: a `TrainConfig`
model: a `ModelDesc`
global_step: a `int`
"""
__metaclass__
=
ABCMeta
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment