Commit f0273bee
authored Dec 01, 2016 by Yuxin Wu
use pickle to dump inputvars
parent 6b85a1f1
Showing 12 changed files with 51 additions and 36 deletions (+51 / -36)
examples/GAN/Image2Image.py  +1 / -1
examples/HED/hed.py  +0 / -1
tensorpack/callbacks/inference.py  +1 / -1
tensorpack/dataflow/image.py  +1 / -1
tensorpack/dataflow/imgaug/crop.py  +2 / -1
tensorpack/models/model_desc.py  +18 / -12
tensorpack/tfutils/summary.py  +1 / -1
tensorpack/train/__init__.py  +1 / -1
tensorpack/train/base.py  +17 / -12
tensorpack/train/multigpu.py  +5 / -4
tensorpack/train/queue.py  +3 / -1
tensorpack/train/trainer.py  +1 / -0
examples/GAN/Image2Image.py
@@ -160,7 +160,7 @@ def get_config():
             ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
         ]),
         model=Model(),
-        step_per_epoch=300,
+        step_per_epoch=dataset.size(),
         max_epoch=300,
     )
examples/HED/hed.py
@@ -66,7 +66,6 @@ class Model(ModelDesc):
                            tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1,
                            W_init=tf.constant_initializer(0.2),
                            use_bias=False, nl=tf.identity)
-        final_map = tf.squeeze(final_map, [3], name='predmap')
         costs = []
         for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
             output = tf.nn.sigmoid(b, name='output{}'.format(idx + 1))
tensorpack/callbacks/inference.py
@@ -93,7 +93,7 @@ class InferenceRunner(Callback):
     def _find_input_tensors(self):
         if self.input_tensors is None:
-            input_vars = self.trainer.model.reuse_input_vars()
+            input_vars = self.trainer.model.get_input_vars()
             self.input_tensors = [x.name for x in input_vars]

     def _find_output_tensors(self):
tensorpack/dataflow/image.py
@@ -19,7 +19,7 @@ class ImageFromFile(RNGDataFlow):
         :param channel: 1 or 3 channel
         :param resize: a (h, w) tuple. If given, will force a resize
         """
-        assert len(files)
+        assert len(files), "No Image Files!"
         self.files = files
         self.channel = int(channel)
         self.imread_mode = cv2.IMREAD_GRAYSCALE if self.channel == 1 else cv2.IMREAD_COLOR
tensorpack/dataflow/imgaug/crop.py
@@ -154,8 +154,9 @@ class RandomCropRandomShape(ImageAugmentor):
         h = self.rng.randint(self.hmin, hmax + 1)
         w = self.rng.randint(self.wmin, wmax + 1)
         diffh = img.shape[0] - h
-        y0 = 0 if diffh == 0 else self.rng.randint(diffh)
         diffw = img.shape[1] - w
+        assert diffh >= 0 and diffw >= 0
+        y0 = 0 if diffh == 0 else self.rng.randint(diffh)
         x0 = 0 if diffw == 0 else self.rng.randint(diffw)
         return (y0, x0, h, w)
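The hunk above only reorders the crop-box computation so that both diffh and diffw exist before the new assert. A minimal standalone sketch of that computation, using a hypothetical image shape and crop size and assuming numpy as the RNG backend (the augmentor's self.rng is stood in by a local RandomState):

import numpy as np

rng = np.random.RandomState(0)            # stand-in for the augmentor's self.rng
img_shape = (100, 80)                     # hypothetical (height, width)
h, w = 60, 50                             # hypothetical sampled crop size
diffh = img_shape[0] - h
diffw = img_shape[1] - w
assert diffh >= 0 and diffw >= 0          # the guard added in this commit
y0 = 0 if diffh == 0 else rng.randint(diffh)
x0 = 0 if diffw == 0 else rng.randint(diffw)
print((y0, x0, h, w))                     # a crop box that stays inside the image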
tensorpack/models/model_desc.py
@@ -8,6 +8,7 @@ import re
 import tensorflow as tf
 from collections import namedtuple
 import inspect
+import pickle

 from ..utils import logger, INPUT_VARS_KEY
 from ..tfutils.common import get_tensors_by_names
@@ -16,7 +17,13 @@ from ..tfutils.tower import get_current_tower_context

 __all__ = ['ModelDesc', 'InputVar', 'ModelFromMetaGraph']

-InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
+_InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])
+class InputVar(_InputVar):
+    def dumps(self):
+        return pickle.dumps(self)
+
+    @staticmethod
+    def loads(buf):
+        return pickle.loads(buf)

 class ModelDesc(object):
     """ Base class for a model description """
@@ -29,17 +36,17 @@ class ModelDesc(object):
         :returns: the list of raw input vars in the graph
         """
         try:
-            return self.reuse_input_vars()
+            return self._reuse_input_vars()
         except KeyError:
             pass
-        ret = self.get_placeholders()
-        for v in ret:
-            tf.add_to_collection(INPUT_VARS_KEY, v)
-        return ret
+        return self.get_placeholders()

     def get_placeholders(self, prefix=''):
-        """ build placeholders with optional prefix, for each InputVar"""
+        """ build placeholders with optional prefix, for each InputVar
+        """
         input_vars = self._get_input_vars()
+        for v in input_vars:
+            tf.add_to_collection(INPUT_VARS_KEY, v.dumps())
         ret = []
         for v in input_vars:
             ret.append(tf.placeholder(
@@ -47,7 +54,7 @@ class ModelDesc(object):
                 name=prefix + v.name))
         return ret

-    def reuse_input_vars(self):
+    def _reuse_input_vars(self):
         """ Find and return already-defined input_vars in default graph"""
         input_var_names = [k.name for k in self._get_input_vars()]
         return get_tensors_by_names(input_var_names)
@@ -104,11 +111,10 @@ class ModelFromMetaGraph(ModelDesc):
         assert k in all_coll, \
             "Collection {} not found in metagraph!".format(k)

-    def get_input_vars(self):
-        return tf.get_collection(INPUT_VARS_KEY)
-
     def _get_input_vars(self):
-        raise NotImplementedError("Shouldn't call here")
+        col = tf.get_collection(INPUT_VARS_KEY)
+        col = [InputVar.loads(v) for v in col]
+        return col

     def _build_graph(self, _, __):
         """ Do nothing. Graph was imported already """
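This file is the core of the commit: InputVar gains pickle-based dumps()/loads(), get_placeholders() stores the pickled specs in the INPUT_VARS_KEY collection, and ModelFromMetaGraph reads them back. A minimal sketch of that round-trip, with the TF graph collection replaced by a plain list and a string dtype so it runs without TensorFlow:

import pickle
from collections import namedtuple

_InputVar = namedtuple('InputVar', ['type', 'shape', 'name'])

class InputVar(_InputVar):
    def dumps(self):
        return pickle.dumps(self)

    @staticmethod
    def loads(buf):
        return pickle.loads(buf)

collection = []                                   # stand-in for the INPUT_VARS_KEY graph collection
spec = InputVar('float32', (None, 28, 28), 'input')
collection.append(spec.dumps())                   # what get_placeholders() now stores

restored = [InputVar.loads(buf) for buf in collection]   # what ModelFromMetaGraph._get_input_vars() reads back
assert restored[0].name == 'input' and restored[0].shape == (None, 28, 28)

The point of storing pickled descriptions rather than placeholder tensors appears to be that a model restored from a metagraph can recover the input specification (dtype, shape, name) itself, not just raw tensors.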
tensorpack/tfutils/summary.py
@@ -105,7 +105,7 @@ def add_moving_summary(v, *args):
 @memoized
 def summary_moving_average(tensors=None):
     """
-    Create a MovingAverage op and summary for tensors
+    Create a MovingAverage op and add summary for tensors
     :param tensors: list of tf.Tensor to summary. default to the collection MOVING_SUMMARY_VARS_KEY
     :returns: a op to maintain these average.
     """
tensorpack/train/__init__.py
@@ -8,7 +8,7 @@ import os.path
 def global_import(name):
     p = __import__(name, globals(), locals(), level=1)
-    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    lst = p.__all__ if '__all__' in dir(p) else []
     for k in lst:
         globals()[k] = p.__dict__[k]
     del globals()[name]
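The one-line change above alters the fallback when a submodule defines no __all__: previously every name from dir(p) was re-exported into the package namespace, now nothing is. A small illustration of that difference, using a throwaway module object rather than a real tensorpack submodule:

import types

p = types.ModuleType('fake_submodule')    # hypothetical module without __all__
p.public_fn = lambda: 'ok'

lst_old = p.__all__ if '__all__' in dir(p) else dir(p)   # old fallback: everything, dunders included
lst_new = p.__all__ if '__all__' in dir(p) else []       # new fallback: nothing is re-exported

print(len(lst_old) > 0)   # True
print(lst_new)            # []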
tensorpack/train/base.py
@@ -24,20 +24,25 @@ class StopTraining(BaseException):
     pass

 class Trainer(object):
-    """
-    Base class for a trainer.
-    Available Attritbutes:
-        stat_holder: a `StatHolder` instance
-        summary_writer: a `tf.SummaryWriter`
-        summary_op: a `tf.Operation` which returns summary string
-        config: a `TrainConfig`
-        model: a `ModelDesc`
-        sess: a `tf.Session`
-        coord: a `tf.train.Coordinator`
-    """
+    """ Base class for a trainer."""
     __metaclass__ = ABCMeta

+    """a `StatHolder` instance"""
+    stat_holder = None
+    """`tf.SummaryWriter`"""
+    summary_writer = None
+    """a tf.Tensor which returns summary string"""
+    summary_op = None
+    """ TrainConfig """
+    config = None
+    """ a ModelDesc"""
+    model = None
+    """ the current session"""
+    sess = None
+    """ the `tf.train.Coordinator` """
+    coord = None
+
     def __init__(self, config):
         """
         :param config: a `TrainConfig` instance
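The rewritten Trainer body moves the attribute list out of the class docstring and into class-level defaults, each preceded by a bare string literal. An aside, under the assumption that those literals are meant purely as documentation: Python only keeps the first string statement of a class body as __doc__ and discards the rest at runtime, so they serve readers and doc tools rather than introspection:

class Example(object):
    """Class docstring."""
    """a note describing the attribute below (evaluated and discarded at runtime)"""
    value = None

print(Example.__doc__)   # 'Class docstring.' -- the second literal is not attached to anything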
tensorpack/train/multigpu.py
@@ -147,6 +147,7 @@ class AsyncMultiGPUTrainer(QueueInputTrainerBase,
         for th in self.training_threads:
             th.pause()
         try:
+            if self.config.tower > 1:
                 async_step_total_cnt = int(re.findall(
                     '[0-9]+', self.async_step_counter.__str__())[0])
                 self.write_scalar_summary(
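The added guard appears to skip the async-step summary unless more than one tower is configured. For context, the surrounding lines recover the counter's integer value by regexing its string form; a tiny stand-in illustration of that parsing trick (FakeCounter is hypothetical, not a tensorpack class):

import re

class FakeCounter(object):
    def __str__(self):
        return '<AsyncStepCounter value=12345>'

async_step_total_cnt = int(re.findall('[0-9]+', FakeCounter().__str__())[0])
print(async_step_total_cnt)   # 12345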
tensorpack/train/queue.py
@@ -63,9 +63,11 @@ class QueueInputTrainerBase(FeedlessTrainer):
     def _build_enque_thread(self, input_queue=None):
         """ create a thread that keeps filling the queue """
         self.input_vars = self.model.get_input_vars()
+        assert len(self.input_vars) > 0, \
+            "QueueInput can only be used with input placeholders!"
         if input_queue is None:
             self.input_queue = tf.FIFOQueue(
                 50, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
         input_th = EnqueueThread(self)
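The new assert makes the queue-based trainer fail early when the model exposes no input placeholders. The method itself wires up a capacity-50 FIFO queue plus an enqueue thread; a framework-free sketch of that producer/consumer pattern (names such as producer and QUEUE_SIZE are illustrative, not tensorpack API):

import itertools
import queue
import threading

QUEUE_SIZE = 50                          # mirrors the capacity-50 FIFOQueue above
q = queue.Queue(maxsize=QUEUE_SIZE)

def producer():
    for datapoint in itertools.count():  # stand-in for the DataFlow feeding EnqueueThread
        q.put(datapoint)                 # blocks when the queue is full

threading.Thread(target=producer, daemon=True).start()

for _ in range(3):                       # stand-in for a few training steps
    print(q.get())                       # 0, 1, 2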
tensorpack/train/trainer.py
@@ -125,6 +125,7 @@ class FeedlessTrainer(Trainer):
         """ return a list of actual input tensors.
             Always return new tensors (for multi tower) if called mutliple times.
         """
+        pass

 class SingleCostFeedlessTrainer(FeedlessTrainer):
     def _get_cost_and_grad(self):