Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
321440af
Commit
321440af
authored
Jul 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
misc clean-ups
parent
8ef16f14
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
23 additions
and
155 deletions
+23
-155
examples/A3C-Gym/simulator.py
examples/A3C-Gym/simulator.py
+3
-106
examples/DynamicFilterNetwork/steering-filter.py
examples/DynamicFilterNetwork/steering-filter.py
+2
-3
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+2
-2
examples/Inception/inceptionv3.py
examples/Inception/inceptionv3.py
+1
-1
tensorpack/callbacks/hooks.py
tensorpack/callbacks/hooks.py
+1
-1
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+7
-5
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+1
-32
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+1
-1
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+3
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-3
No files found.
examples/A3C-Gym/simulator.py
View file @
321440af
...
...
@@ -61,7 +61,9 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
def
__init__
(
self
,
idx
,
pipe_c2s
,
pipe_s2c
):
"""
:param idx: idx of this process
Args:
idx: idx of this process
pipe_c2s, pipe_s2c (str): name of the pipe
"""
super
(
SimulatorProcessStateExchange
,
self
)
.
__init__
(
idx
)
self
.
c2s
=
pipe_c2s
...
...
@@ -177,111 +179,6 @@ class SimulatorMaster(threading.Thread):
self
.
context
.
destroy
(
linger
=
0
)
# ------------------- the following code are not used at all. Just experimental
class
SimulatorProcessDF
(
SimulatorProcessBase
):
""" A simulator which contains a forward model itself, allowing
it to produce data points directly """
def
__init__
(
self
,
idx
,
pipe_c2s
):
super
(
SimulatorProcessDF
,
self
)
.
__init__
(
idx
)
self
.
pipe_c2s
=
pipe_c2s
def
run
(
self
):
self
.
player
=
self
.
_build_player
()
self
.
ctx
=
zmq
.
Context
()
self
.
c2s_socket
=
self
.
ctx
.
socket
(
zmq
.
PUSH
)
self
.
c2s_socket
.
setsockopt
(
zmq
.
IDENTITY
,
self
.
identity
)
self
.
c2s_socket
.
set_hwm
(
5
)
self
.
c2s_socket
.
connect
(
self
.
pipe_c2s
)
self
.
_prepare
()
for
dp
in
self
.
get_data
():
self
.
c2s_socket
.
send
(
dumps
(
dp
),
copy
=
False
)
@
abstractmethod
def
_prepare
(
self
):
pass
@
abstractmethod
def
get_data
(
self
):
pass
class
SimulatorProcessSharedWeight
(
SimulatorProcessDF
):
""" A simulator process with an extra thread waiting for event,
and take shared weight from shm.
Start me under some CUDA_VISIBLE_DEVICES set!
"""
def
__init__
(
self
,
idx
,
pipe_c2s
,
condvar
,
shared_dic
,
pred_config
):
super
(
SimulatorProcessSharedWeight
,
self
)
.
__init__
(
idx
,
pipe_c2s
)
self
.
condvar
=
condvar
self
.
shared_dic
=
shared_dic
self
.
pred_config
=
pred_config
def
_prepare
(
self
):
disable_layer_logging
()
self
.
predictor
=
OfflinePredictor
(
self
.
pred_config
)
with
self
.
predictor
.
graph
.
as_default
():
vars_to_update
=
self
.
_params_to_update
()
self
.
sess_updater
=
SessionUpdate
(
self
.
predictor
.
session
,
vars_to_update
)
# TODO setup callback for explore?
self
.
predictor
.
graph
.
finalize
()
self
.
weight_lock
=
threading
.
Lock
()
# start a thread to wait for notification
def
func
():
self
.
condvar
.
acquire
()
while
True
:
self
.
condvar
.
wait
()
self
.
_trigger_evt
()
self
.
evt_th
=
threading
.
Thread
(
target
=
func
)
self
.
evt_th
.
daemon
=
True
self
.
evt_th
.
start
()
def
_trigger_evt
(
self
):
with
self
.
weight_lock
:
self
.
sess_updater
.
update
(
self
.
shared_dic
[
'params'
])
logger
.
info
(
"Updated."
)
def
_params_to_update
(
self
):
# can be overwritten to update more params
return
tf
.
trainable_variables
()
class
WeightSync
(
Callback
):
""" Sync weight from main process to shared_dic and notify"""
def
__init__
(
self
,
condvar
,
shared_dic
):
self
.
condvar
=
condvar
self
.
shared_dic
=
shared_dic
def
_setup_graph
(
self
):
self
.
vars
=
self
.
_params_to_update
()
def
_params_to_update
(
self
):
# can be overwritten to update more params
return
tf
.
trainable_variables
()
def
_before_train
(
self
):
self
.
_sync
()
def
_trigger_epoch
(
self
):
self
.
_sync
()
def
_sync
(
self
):
logger
.
info
(
"Updating weights ..."
)
dic
=
{
v
.
name
:
v
.
eval
()
for
v
in
self
.
vars
}
self
.
shared_dic
[
'params'
]
=
dic
self
.
condvar
.
acquire
()
self
.
condvar
.
notify_all
()
self
.
condvar
.
release
()
if
__name__
==
'__main__'
:
import
random
from
tensorpack.RL
import
NaiveRLEnvironment
...
...
examples/DynamicFilterNetwork/steering-filter.py
View file @
321440af
...
...
@@ -94,7 +94,6 @@ class OnlineTensorboardExport(Callback):
class
Model
(
ModelDesc
):
def
_get_inputs
(
self
):
# TODO: allow arbitrary batch sizes
return
[
InputDesc
(
tf
.
float32
,
(
BATCH
,
),
'theta'
),
InputDesc
(
tf
.
float32
,
(
BATCH
,
SHAPE
,
SHAPE
),
'image'
),
InputDesc
(
tf
.
float32
,
(
BATCH
,
SHAPE
,
SHAPE
),
'gt_image'
),
...
...
@@ -120,9 +119,9 @@ class Model(ModelDesc):
logger
.
info
(
'Parameter net output: {}'
.
format
(
pred_filter
.
get_shape
()
.
as_list
()))
return
pred_filter
def
_build_graph
(
self
,
input
_var
s
):
def
_build_graph
(
self
,
inputs
):
kernel_size
=
9
theta
,
image
,
gt_image
,
gt_filter
=
input
_var
s
theta
,
image
,
gt_image
,
gt_filter
=
inputs
image
=
image
gt_image
=
gt_image
...
...
examples/GAN/ConditionalGAN-mnist.py
View file @
321440af
...
...
@@ -73,8 +73,8 @@ class Model(GANModelDesc):
.
FullyConnected
(
'fct'
,
1
,
nl
=
tf
.
identity
)())
return
l
def
_build_graph
(
self
,
input
_var
s
):
image_pos
,
y
=
input
_var
s
def
_build_graph
(
self
,
inputs
):
image_pos
,
y
=
inputs
image_pos
=
tf
.
expand_dims
(
image_pos
*
2.0
-
1
,
-
1
)
y
=
tf
.
one_hot
(
y
,
10
,
name
=
'label_onehot'
)
...
...
examples/Inception/inceptionv3.py
View file @
321440af
...
...
@@ -138,7 +138,7 @@ class Model(ModelDesc):
br1
=
AvgPooling
(
'avgpool'
,
l
,
5
,
3
,
padding
=
'VALID'
)
br1
=
Conv2D
(
'conv11'
,
br1
,
128
,
1
)
shape
=
br1
.
get_shape
()
.
as_list
()
br1
=
Conv2D
(
'convout'
,
br1
,
768
,
shape
[
1
:
3
],
padding
=
'VALID'
)
# TODO gauss, stddev=0.01
br1
=
Conv2D
(
'convout'
,
br1
,
768
,
shape
[
1
:
3
],
padding
=
'VALID'
)
br1
=
FullyConnected
(
'fc'
,
br1
,
1000
,
nl
=
tf
.
identity
)
with
tf
.
variable_scope
(
'incep-17-1280a'
):
...
...
tensorpack/callbacks/hooks.py
View file @
321440af
...
...
@@ -45,7 +45,7 @@ class HookToCallback(Callback):
def
_before_train
(
self
):
sess
=
tf
.
get_default_session
()
#
TODO fix coord?
#
coord is set to None when converting
self
.
_hook
.
after_create_session
(
sess
,
None
)
def
_before_run
(
self
,
ctx
):
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
321440af
...
...
@@ -151,11 +151,13 @@ class Cifar100(CifarBase):
if
__name__
==
'__main__'
:
ds
=
Cifar10
(
'train'
)
from
tensorpack.dataflow.dftools
import
dump_dataflow_images
mean
=
ds
.
get_per_channel_mean
()
print
(
mean
)
dump_dataflow_images
(
ds
,
'/tmp/cifar'
,
100
)
# for (img, label) in ds.get_data():
# from IPython import embed; embed()
# break
import
cv2
ds
.
reset_state
()
for
i
,
dp
in
enumerate
(
ds
.
get_data
()):
if
i
==
100
:
break
img
=
dp
[
0
]
cv2
.
imwrite
(
"{:04d}.jpg"
.
format
(
i
),
img
)
tensorpack/dataflow/dftools.py
View file @
321440af
...
...
@@ -2,7 +2,6 @@
# File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
sys
import
os
import
multiprocessing
as
mp
from
six.moves
import
range
...
...
@@ -11,35 +10,11 @@ from .base import DataFlow
from
..utils
import
get_tqdm
,
logger
from
..utils.concurrency
import
DIE
from
..utils.serialize
import
dumps
from
..utils.fs
import
mkdir_p
__all__
=
[
'dump_dataflow_
images'
,
'dump_dataflow_
to_process_queue'
,
__all__
=
[
'dump_dataflow_to_process_queue'
,
'dump_dataflow_to_lmdb'
,
'dump_dataflow_to_tfrecord'
]
def
dump_dataflow_images
(
df
,
dirname
,
max_count
=
None
,
index
=
0
):
""" Dump images from a DataFlow to a directory.
Args:
df (DataFlow): the DataFlow to dump.
dirname (str): name of the directory.
max_count (int): limit max number of images to dump. Defaults to unlimited.
index (int): the index of the image component in the data point.
"""
# TODO pass a name_func to write label as filename?
mkdir_p
(
dirname
)
if
max_count
is
None
:
max_count
=
sys
.
maxint
df
.
reset_state
()
for
i
,
dp
in
enumerate
(
df
.
get_data
()):
if
i
%
100
==
0
:
print
(
i
)
if
i
>
max_count
:
return
img
=
dp
[
index
]
cv2
.
imwrite
(
os
.
path
.
join
(
dirname
,
"{}.jpg"
.
format
(
i
)),
img
)
def
dump_dataflow_to_process_queue
(
df
,
size
,
nr_consumer
):
"""
Convert a DataFlow to a :class:`multiprocessing.Queue`.
...
...
@@ -160,9 +135,3 @@ try:
except
ImportError
:
dump_dataflow_to_tfrecord
=
create_dummy_func
(
# noqa
'dump_dataflow_to_tfrecord'
,
'tensorflow'
)
try
:
import
cv2
except
ImportError
:
dump_dataflow_images
=
create_dummy_func
(
# noqa
'dump_dataflow_images'
,
'cv2'
)
tensorpack/graph_builder/model_desc.py
View file @
321440af
...
...
@@ -107,7 +107,7 @@ class ModelDescBase(object):
:returns: a list of InputDesc
"""
# TODO only use InputSource in the future? Now
mainly used in predict/
# TODO only use InputSource in the future? Now
only used in predictor_factory
def
build_graph
(
self
,
inputs
):
"""
Build the whole symbolic graph.
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
321440af
...
...
@@ -67,6 +67,8 @@ class PredictorFactory(object):
input
.
setup
(
self
.
_model
.
get_inputs_desc
())
input
=
input
.
get_input_tensors
()
assert
isinstance
(
input
,
(
list
,
tuple
)),
input
# TODO still using tensors here instead of inputsource
# can be fixed after having towertensorhandle inside modeldesc
self
.
_model
.
build_graph
(
input
)
desc_names
=
[
k
.
name
for
k
in
self
.
_model
.
get_inputs_desc
()]
...
...
@@ -88,7 +90,7 @@ class PredictorFactory(object):
tower
=
self
.
_towers
[
tower
]
device
=
'/gpu:{}'
.
format
(
tower
)
if
tower
>=
0
else
'/cpu:0'
# use a previously-built tower
# TODO conflict with inference runner??
# TODO c
heck c
onflict with inference runner??
if
tower_name
not
in
self
.
_names_built
:
with
tf
.
variable_scope
(
self
.
_vs_name
,
reuse
=
True
):
handle
=
self
.
build
(
tower_name
,
device
)
...
...
tensorpack/tfutils/sessinit.py
View file @
321440af
...
...
@@ -196,15 +196,14 @@ class DictRestore(SessionInit):
self
.
prms
=
{
get_op_tensor_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
def
_run_init
(
self
,
sess
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
# TODO
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
variable_names
=
set
([
k
.
name
for
k
in
variables
])
param_names
=
set
(
six
.
iterkeys
(
self
.
prms
))
intersect
=
variable_names
&
param_names
logger
.
info
(
"Params to restore: {}"
.
format
(
', '
.
join
(
map
(
str
,
intersect
))))
logger
.
info
(
"Params to restore: {}"
.
format
(
', '
.
join
(
map
(
str
,
intersect
))))
mismatch
=
MismatchLogger
(
'graph'
,
'dict'
)
for
k
in
sorted
(
variable_names
-
param_names
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment