Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
55f2f5da
Commit
55f2f5da
authored
Feb 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
trainer.get_predictor support tower-tensor as input
parent
6640f9bb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
8 deletions
+24
-8
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+5
-5
tensorpack/callbacks/concurrency.py
tensorpack/callbacks/concurrency.py
+2
-0
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+16
-2
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+1
-1
No files found.
examples/GAN/InfoGAN-mnist.py
View file @
55f2f5da
...
...
@@ -49,12 +49,12 @@ class Model(GANModelDesc):
l
=
tf
.
reshape
(
l
,
[
-
1
,
7
,
7
,
128
])
l
=
Deconv2D
(
'deconv1'
,
l
,
[
14
,
14
,
64
],
4
,
2
,
nl
=
BNReLU
)
l
=
Deconv2D
(
'deconv2'
,
l
,
[
28
,
28
,
1
],
4
,
2
,
nl
=
tf
.
identity
)
l
=
tf
.
tanh
(
l
,
name
=
'gen'
)
l
=
tf
.
sigmoid
(
l
,
name
=
'gen'
)
return
l
def
discriminator
(
self
,
imgs
):
with
argscope
(
Conv2D
,
nl
=
tf
.
identity
,
kernel_shape
=
4
,
stride
=
2
),
\
argscope
(
LeakyReLU
,
alpha
=
0.
1
):
argscope
(
LeakyReLU
,
alpha
=
0.
2
):
l
=
(
LinearWrap
(
imgs
)
.
Conv2D
(
'conv0'
,
64
)
.
LeakyReLU
()
...
...
@@ -72,7 +72,7 @@ class Model(GANModelDesc):
def
_build_graph
(
self
,
inputs
):
real_sample
=
inputs
[
0
]
real_sample
=
tf
.
expand_dims
(
real_sample
*
2.0
-
1
,
-
1
)
real_sample
=
tf
.
expand_dims
(
real_sample
,
-
1
)
# latent space is cat(10) x uni(1) x uni(1) x noise(NOISE_DIM)
self
.
factors
=
ProductDistribution
(
"factors"
,
[
CategoricalDistribution
(
"cat"
,
10
),
...
...
@@ -93,7 +93,7 @@ class Model(GANModelDesc):
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
0.02
)):
with
tf
.
variable_scope
(
'gen'
):
fake_sample
=
self
.
generator
(
z
)
fake_sample_viz
=
tf
.
cast
((
fake_sample
+
1
)
*
128
.0
,
tf
.
uint8
,
name
=
'viz'
)
fake_sample_viz
=
tf
.
cast
((
fake_sample
)
*
255
.0
,
tf
.
uint8
,
name
=
'viz'
)
tf
.
summary
.
image
(
'gen'
,
fake_sample_viz
,
max_outputs
=
30
)
# may need to investigate how bn stats should be updated across two discrim
...
...
@@ -164,7 +164,7 @@ def get_config():
logger
.
auto_set_dir
()
return
TrainConfig
(
dataflow
=
get_data
(),
callbacks
=
[
ModelSaver
()],
callbacks
=
[
ModelSaver
(
keep_freq
=
0.1
)],
session_config
=
get_default_sess_config
(
0.5
),
model
=
Model
(),
steps_per_epoch
=
500
,
...
...
tensorpack/callbacks/concurrency.py
View file @
55f2f5da
...
...
@@ -41,6 +41,8 @@ class StartProcOrThread(Callback):
if
not
self
.
_stop_at_last
:
return
for
k
in
self
.
_procs_threads
:
if
not
k
.
is_alive
():
continue
if
isinstance
(
k
,
mp
.
Process
):
logger
.
info
(
"Stopping {} ..."
.
format
(
k
.
name
))
k
.
terminate
()
...
...
tensorpack/train/trainer.py
View file @
55f2f5da
...
...
@@ -7,7 +7,7 @@ import tensorflow as tf
from
.base
import
Trainer
from
..utils
import
SUMMARY_BACKUP_KEYS
,
PREDICT_TOWER
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
..tfutils
import
get_tensors_by_names
,
TowerContext
,
get_op_tensor_name
from
..tfutils.collection
import
freeze_collection
from
..predict
import
OnlinePredictor
,
build_prediction_graph
from
.input_data
import
FeedInput
...
...
@@ -35,8 +35,22 @@ class PredictorFactory(object):
if
not
self
.
tower_built
:
self
.
_build_predict_tower
()
tower
=
self
.
towers
[
tower
%
len
(
self
.
towers
)]
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
model
.
get_inputs_desc
()])
def
get_name_in_tower
(
name
):
return
PREDICT_TOWER
+
str
(
tower
)
+
'/'
+
name
def
maybe_inside_tower
(
name
):
name
=
get_op_tensor_name
(
name
)[
0
]
if
name
in
placeholder_names
:
return
name
else
:
return
get_name_in_tower
(
name
)
input_names
=
map
(
maybe_inside_tower
,
input_names
)
raw_input_vars
=
get_tensors_by_names
(
input_names
)
output_names
=
[
'{}{}/'
.
format
(
PREDICT_TOWER
,
tower
)
+
n
for
n
in
output_names
]
output_names
=
map
(
get_name_in_tower
,
output_names
)
output_vars
=
get_tensors_by_names
(
output_names
)
return
OnlinePredictor
(
self
.
sess
,
raw_input_vars
,
output_vars
)
...
...
tensorpack/utils/timer.py
View file @
55f2f5da
...
...
@@ -42,7 +42,7 @@ def timed_operation(msg, log_start=False):
logger
.
info
(
'Start {} ...'
.
format
(
msg
))
start
=
time
.
time
()
yield
logger
.
info
(
'{} finished, time:{:.
2
f}sec.'
.
format
(
logger
.
info
(
'{} finished, time:{:.
4
f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment