Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
972e298a
Commit
972e298a
authored
Nov 06, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix saver v1/v2 issues
parent
b9498a1a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
12 additions
and
23 deletions
+12
-23
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+0
-10
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+0
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+9
-7
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+2
-4
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
No files found.
tensorpack/callbacks/common.py
View file @
972e298a
...
...
@@ -75,16 +75,6 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
self
.
path
,
global_step
=
get_global_step
(),
write_meta_graph
=
False
)
# create a symbolic link for the latest model
latest
=
self
.
saver
.
last_checkpoints
[
-
1
]
basename
=
os
.
path
.
basename
(
latest
)
linkname
=
os
.
path
.
join
(
os
.
path
.
dirname
(
latest
),
'latest'
)
try
:
os
.
unlink
(
linkname
)
except
OSError
:
pass
os
.
symlink
(
basename
,
linkname
)
except
(
OSError
,
IOError
):
# disk error sometimes.. just ignore it
logger
.
exception
(
"Exception in ModelSaver.trigger_epoch!"
)
...
...
tensorpack/models/model_desc.py
View file @
972e298a
...
...
@@ -84,7 +84,6 @@ Use _build_graph(self, input_vars) and get_current_tower_context().is_training i
CheckGradient
()
]
class
ModelFromMetaGraph
(
ModelDesc
):
"""
Load the whole exact TF graph from a saved meta_graph.
...
...
tensorpack/tfutils/sessinit.py
View file @
972e298a
...
...
@@ -54,28 +54,30 @@ class SaverRestore(SessionInit):
"""
def
__init__
(
self
,
model_path
,
prefix
=
None
):
"""
:param model_path: a model
file
or a ``checkpoint`` file.
:param model_path: a model
name (model-xxxx)
or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint
"""
assert
os
.
path
.
isfile
(
model_path
)
if
os
.
path
.
basename
(
model_path
)
==
model_path
:
model_path
=
os
.
path
.
join
(
'.'
,
model_path
)
# avoid #4921
if
os
.
path
.
basename
(
model_path
)
==
'checkpoint'
:
model_path
=
tf
.
train
.
get_checkpoint_state
(
os
.
path
.
dirname
(
model_path
))
.
model_checkpoint_path
assert
os
.
path
.
isfile
(
model_path
)
model_path
=
tf
.
train
.
latest_checkpoint
(
os
.
path
.
dirname
(
model_path
))
# to be consistent with either v1 or v2
assert
os
.
path
.
isfile
(
model_path
)
or
os
.
path
.
isfile
(
model_path
+
'.index'
)
self
.
set_path
(
model_path
)
self
.
prefix
=
prefix
def
_init
(
self
,
sess
):
logger
.
info
(
"Restoring checkpoint from {}."
.
format
(
self
.
path
))
"Restoring checkpoint from {}
..
."
.
format
(
self
.
path
))
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
vars_map
=
self
.
_get_vars_to_restore_multimap
(
chkpt_vars
)
for
dic
in
SaverRestore
.
_produce_restore_dict
(
vars_map
):
# multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
saver
=
tf
.
train
.
Saver
(
var_list
=
dic
,
name
=
str
(
id
(
dic
)))
try
:
saver
=
tf
.
train
.
Saver
(
var_list
=
dic
,
name
=
str
(
id
(
dic
)),
write_version
=
2
)
except
:
saver
=
tf
.
train
.
Saver
(
var_list
=
dic
,
name
=
str
(
id
(
dic
)))
saver
.
restore
(
sess
,
self
.
path
)
def
set_path
(
self
,
model_path
):
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
972e298a
...
...
@@ -10,7 +10,7 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
"""
:param logits: NxC
:param label: N
:returns: a float32 vector of length N with 0/1 values
, 1 meaning
incorrect prediction
:returns: a float32 vector of length N with 0/1 values
. 1 means
incorrect prediction
"""
return
tf
.
cast
(
tf
.
logical_not
(
tf
.
nn
.
in_top_k
(
logits
,
label
,
topk
)),
tf
.
float32
,
name
=
name
)
...
...
@@ -95,9 +95,7 @@ def rms(x, name=None):
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
def
huber_loss
(
x
,
delta
=
1
,
name
=
None
):
if
name
is
None
:
name
=
'huber_loss'
def
huber_loss
(
x
,
delta
=
1
,
name
=
'huber_loss'
):
sqrcost
=
tf
.
square
(
x
)
abscost
=
tf
.
abs
(
x
)
return
tf
.
reduce_sum
(
...
...
tensorpack/train/trainer.py
View file @
972e298a
...
...
@@ -78,7 +78,7 @@ class SimpleTrainer(Trainer):
self
.
train_op
=
tf
.
group
(
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
summary_moving_average
())
summary_moving_average
()
,
name
=
'train_op'
)
# create an infinte data producer
self
.
config
.
dataset
.
reset_state
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment