Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
97f47539
Commit
97f47539
authored
Dec 30, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
speed up GAN
parent
bcee048d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
19 deletions
+27
-19
examples/GAN/GAN.py
examples/GAN/GAN.py
+9
-6
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+2
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+16
-11
No files found.
examples/GAN/GAN.py
View file @
97f47539
...
...
@@ -5,6 +5,7 @@
import
tensorflow
as
tf
import
numpy
as
np
import
time
from
tensorpack
import
(
FeedfreeTrainer
,
TowerContext
,
get_global_step_var
,
QueueInput
)
from
tensorpack.tfutils.summary
import
summary_moving_average
,
add_moving_summary
...
...
@@ -21,17 +22,19 @@ class GANTrainer(FeedfreeTrainer):
actual_inputs
=
self
.
_get_input_tensors
()
self
.
model
.
build_graph
(
actual_inputs
)
self
.
g_min
=
self
.
config
.
optimizer
.
minimize
(
self
.
model
.
g_loss
,
var_list
=
self
.
model
.
g_vars
,
name
=
'g_op'
)
var_list
=
self
.
model
.
g_vars
,
name
=
'g_op'
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
)
self
.
d_min
=
self
.
config
.
optimizer
.
minimize
(
self
.
model
.
d_loss
,
var_list
=
self
.
model
.
d_vars
,
name
=
'd_op'
)
var_list
=
self
.
model
.
d_vars
,
name
=
'd_op'
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
)
self
.
gs_incr
=
tf
.
assign_add
(
get_global_step_var
(),
1
,
name
=
'global_step_incr'
)
self
.
summary_op
=
summary_moving_average
()
self
.
d_min
=
tf
.
group
(
self
.
d_min
,
self
.
summary_op
,
self
.
gs_incr
)
#self.train_op = tf.group(self.g_min, self.d_min)
with
tf
.
control_dependencies
([
self
.
g_min
]):
self
.
d_min
=
tf
.
group
(
self
.
d_min
,
self
.
summary_op
,
self
.
gs_incr
)
self
.
train_op
=
self
.
d_min
def
run_step
(
self
):
self
.
sess
.
run
(
self
.
g_min
)
self
.
sess
.
run
(
self
.
d_min
)
self
.
sess
.
run
(
self
.
train_op
)
class
RandomZData
(
DataFlow
):
def
__init__
(
self
,
shape
):
...
...
examples/GAN/Image2Image.py
View file @
97f47539
...
...
@@ -28,7 +28,7 @@ To visualize on test set:
"""
SHAPE
=
256
BATCH
=
4
BATCH
=
1
IN_CH
=
3
OUT_CH
=
3
LAMBDA
=
100
...
...
@@ -159,7 +159,7 @@ def get_config():
dataset
=
dataset
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
beta1
=
0.5
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(
),
StatPrinter
(),
PeriodicCallback
(
ModelSaver
(),
3
),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
200
,
1e-4
)])
]),
model
=
Model
(),
...
...
tensorpack/train/feedfree.py
View file @
97f47539
...
...
@@ -50,17 +50,22 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
def
run_step
(
self
):
""" Simply run self.train_op"""
self
.
sess
.
run
(
self
.
train_op
)
# debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
#if not hasattr(self, 'cnt'):
#self.cnt = 0
#else:
#self.cnt += 1
#if self.cnt % 10 == 0:
## debug-benchmark code:
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
#run_metadata=run_metadata
#)
#from tensorflow.python.client import timeline
#trace = timeline.Timeline(step_stats=run_metadata.step_stats)
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
class
SimpleFeedfreeTrainer
(
MultiPredictorTowerTrainer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment