Shashank Suhas / seminar-breakout / Commits

Commit f80843dc, authored Feb 20, 2017 by Yuxin Wu
distinguish between sess.run call and run_step call. fix WGAN examples. (#147)
Parent: eee05770
Changes: 10

Showing 10 changed files with 60 additions and 45 deletions (+60 / -45).
Changed files:

  examples/GAN/GAN.py              +0   -4
  examples/GAN/WGAN-CelebA.py      +2   -3
  tensorpack/callbacks/base.py     +0   -5
  tensorpack/callbacks/group.py    +0   -1
  tensorpack/callbacks/steps.py    +36  -13
  tensorpack/callbacks/trigger.py  +3   -1
  tensorpack/models/conv2d.py      +1   -1
  tensorpack/tfutils/common.py     +2   -1
  tensorpack/train/base.py         +1   -1
  tensorpack/train/feedfree.py     +15  -15
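The change all of the diffs below converge on is the one named in the commit message: run_step is what a trainer does once per training step, while a session call (now issued through self.monitored_sess.run, the hooked session) is a lower-level event of which one step may contain several. A minimal sketch of the new convention, using a hypothetical trainer rather than code from this commit:

    class MultiOpTrainer(FeedfreeTrainerBase):
        def run_step(self):
            # one training step may contain several session calls; each
            # monitored_sess.run() also drives the callbacks' per-run hooks
            self.monitored_sess.run(self.first_op)
            self.monitored_sess.run(self.second_op)
            # nothing is returned: callbacks fetch what they need via hooks

run_step therefore no longer returns extra fetch results; the callbacks that used to rely on that (StepTensorPrinter, MaintainStepCounter, ProgressBar) are reworked in tensorpack/callbacks/steps.py below.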
examples/GAN/GAN.py  (+0 -4)

@@ -77,10 +77,6 @@ class GANTrainer(FeedfreeTrainerBase):
         self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
         self.train_op = self.d_min
 
-    def run_step(self):
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches())
-        return ret[1:]
-
 
 class RandomZData(DataFlow):
     def __init__(self, shape):
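With its own run_step deleted, GANTrainer falls back to the base-class implementation shown in the tensorpack/train/feedfree.py hunk further down:

    def run_step(self):
        """ Simply run ``self.train_op``, which minimizes the cost."""
        self.monitored_sess.run(self.train_op)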
examples/GAN/WGAN-CelebA.py  (+2 -3)

@@ -88,9 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
     def run_step(self):
         for k in range(5):
-            self.sess.run(self.d_min)
-        ret = self.sess.run([self.g_min] + self.get_extra_fetches())
-        return ret[1:]
+            self.monitored_sess.run(self.d_min)
+        self.monitored_sess.run(self.g_min)
 
 
 if __name__ == '__main__':
tensorpack/callbacks/base.py  (+0 -5)

@@ -114,11 +114,6 @@ class Callback(object):
     def epoch_num(self):
         return self.trainer.epoch_num
 
-    @property
-    def local_step(self):
-        # inside trainer, we're still in the 'local_step' loop, so the number is off by 1
-        return self.trainer.local_step + 1
-
     @property
     def global_step(self):
         return self.trainer.global_step
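Callbacks that still need a step count now read it from the trainer directly; the off-by-one correction is no longer applied here, which matches the local_step = -1 initialization in tensorpack/train/base.py below. A short hypothetical callback as a sketch (the import paths are assumptions, not part of this diff):

    from tensorpack.callbacks.base import Callback
    from tensorpack.utils import logger

    class EveryHundredSteps(Callback):
        def _trigger_step(self, *args):
            # trigger_step runs after run_step, so local_step + 1 steps have finished
            finished = self.trainer.local_step + 1
            if finished % 100 == 0:
                logger.info("{} steps finished in this epoch".format(finished))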
tensorpack/callbacks/group.py  (+0 -1)

@@ -102,7 +102,6 @@ class Callbacks(Callback):
             traceback.print_exc()
 
     def get_hooks(self):
-        # TODO skip
         return [CallbackHook(cb) for cb in self.cbs]
 
     def _trigger_epoch(self):
tensorpack/callbacks/steps.py  (+36 -13)

@@ -12,7 +12,9 @@ import tqdm
 from ..utils import logger, get_tqdm_kwargs
 from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME,
                             LOCAL_STEP_OP_NAME)
-from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
+from ..tfutils.common import (get_op_tensor_name, get_global_step_var,
+                              get_global_step_value, get_op_or_tensor_by_name)
 from .base import Callback
 
 __all__ = ['StepTensorPrinter', 'MaintainStepCounter',

@@ -33,8 +35,11 @@ class StepTensorPrinter(Callback):
         logger.warn("Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!")
         self._names = names
 
+    def _before_train(self):
+        self._fetches = get_op_or_tensor_by_name(self._names)
+
     def _extra_fetches(self):
-        return self._names
+        return self._fetches
 
     def _trigger_step(self, *args):
         assert len(args) == len(self._names), len(args)

@@ -63,9 +68,15 @@ class MaintainStepCounter(Callback):
         gs_val = get_global_step_value()
         if gs_val != 0:
             logger.info("Start training with global_step={}".format(gs_val))
+        self._last_updated = self.trainer.local_step
 
     def _extra_fetches(self):
-        return [self.gs_incr_var.op]
+        # increase global_step, when trainer.local_step changed
+        if self.trainer.local_step != self._last_updated:
+            self._last_updated = self.trainer.local_step
+            return [self.gs_incr_var.op]
+        else:
+            return []
 
 
 class ProgressBar(Callback):

@@ -80,21 +91,33 @@ class ProgressBar(Callback):
         self._names = [get_op_tensor_name(n)[1] for n in names]
         self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]
 
-    def _extra_fetches(self):
-        return self._names
-
     def _before_train(self):
+        self._fetches = get_op_or_tensor_by_name(self._names)
+        self._last_updated = self.trainer.local_step
+
         self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
         if len(self._names):
             self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
 
+    def _extra_fetches(self):
+        if self.trainer.local_step != self._last_updated:
+            # local_step == number of steps that have finished in this epoch
+            self._last_updated = self.trainer.local_step
+            if self.trainer.local_step == 0:
+                self._bar = tqdm.trange(self._total, **self._tqdm_args)
+            else:
+                self._bar.update()
+                # XXX TODO move this to trigger_step after rename
+                if self.trainer.local_step == self._total - 1:
+                    self._bar.close()
+            return self._fetches
+        else:
+            return []
+
     def _trigger_step(self, *args):
-        if self.local_step == 1:
-            self._bar = tqdm.trange(self._total, **self._tqdm_args)
-        if len(self._names):
+        if len(args):
             self._bar.set_postfix(zip(self._tags, args))
-        self._bar.update()
-        if self.local_step == self._total:
-            self._bar.close()
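Because one run_step may now issue several hooked session calls (the WGAN example above makes six), a per-step callback cannot assume _extra_fetches is invoked exactly once per step. Both MaintainStepCounter and ProgressBar therefore compare trainer.local_step against a _last_updated field and only act on the first call of each step. A tiny standalone simulation of that guard (an illustration, not library code):

    class StepGuard(object):
        def __init__(self):
            self.last_updated = -1            # matches local_step's new initial value

        def extra_fetches(self, local_step):
            if local_step != self.last_updated:
                self.last_updated = local_step
                return ['global_step_incr']   # only on the first run() of a step
            return []

    g = StepGuard()
    print(g.extra_fetches(0))   # ['global_step_incr'] -- first session call of step 0
    print(g.extra_fetches(0))   # []                   -- later calls in the same step
    print(g.extra_fetches(1))   # ['global_step_incr'] -- next step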
tensorpack/callbacks/trigger.py  (+3 -1)

@@ -34,7 +34,9 @@ class PeriodicTrigger(ProxyCallback):
     def _trigger_step(self, *args):
         if self._step_k is None:
             return
-        if self.local_step % self._step_k == 0:
+        # trigger_step is triggered after run_step, so
+        # local_step + 1 is the number of step that have finished
+        if (self.trainer.local_step + 1) % self._step_k == 0:
             self.cb.trigger()
 
     def _trigger_epoch(self, *args):
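Since _trigger_step fires after run_step, (trainer.local_step + 1) is the number of steps already completed in the epoch, so the wrapped callback now triggers on completed steps instead of the removed, off-by-one self.local_step property. A quick check of the new condition in plain Python:

    step_k = 5
    fired_after = [s + 1 for s in range(20) if (s + 1) % step_k == 0]
    print(fired_after)   # [5, 10, 15, 20] -- triggers after every 5 finished steps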
tensorpack/models/conv2d.py  (+1 -1)

@@ -141,7 +141,7 @@ def Deconv2D(x, out_shape, kernel_shape,
         for k in out_shape:
             if not isinstance(k, int):
                 raise ValueError("[Deconv2D] out_shape {} is invalid!".format(k))
-        out_channel = out_shape[channel_axis]
+        out_channel = out_shape[channel_axis - 1]   # out_shape doesn't have batch
         shp3_static = shp3_dyn = out_shape
     filter_shape = kernel_shape + [out_channel, in_channel]
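out_shape is given without a batch dimension, while channel_axis indexes the full 4-D tensor layout (assuming the usual convention of 3 for NHWC and 1 for NCHW), so the channel entry sits one position earlier. A worked example of the fix:

    channel_axis = 3                            # NHWC layout of the full tensor
    out_shape = [64, 64, 3]                     # h, w, c -- no batch dimension
    out_channel = out_shape[channel_axis - 1]   # 3; the old out_shape[channel_axis] would be out of range

    channel_axis = 1                            # NCHW
    out_shape = [3, 64, 64]                     # c, h, w
    out_channel = out_shape[channel_axis - 1]   # 3; the old indexing returned the height instead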
tensorpack/tfutils/common.py  (+2 -1)

@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import tensorflow as tf
+from six.moves import map
 
 from ..utils.naming import (
     GLOBAL_STEP_VAR_NAME,

@@ -133,7 +134,7 @@ def get_op_or_tensor_by_name(name):
     if not isinstance(name, list):
         return f(name)
     else:
-        return map(f, name)
+        return list(map(f, name))
 
 
 def get_name_scope_name():
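On Python 3 the built-in map returns a lazy iterator (and six.moves.map makes Python 2 behave the same way), so without list() the result could only be consumed once and not indexed. A small illustration with a stand-in function:

    names = ['tower0/cost:0', 'tower0/accuracy:0']
    lazy = map(len, names)          # iterator on Python 3: no len(), no indexing, single pass
    eager = list(map(len, names))   # a real list: indexing and repeated iteration both work
    print(eager[0])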
tensorpack/train/base.py  (+1 -1)

@@ -54,7 +54,7 @@ class Trainer(object):
         self.model = config.model
 
         self.epoch_num = self.config.starting_epoch - 1
-        self.local_step = 0
+        self.local_step = -1
 
     def train(self):
         """ Start training """
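Starting local_step at -1 means "no step has run yet"; once the step loop assigns 0, 1, ..., callbacks can read trainer.local_step directly (the +1 property removed from callbacks/base.py above is no longer needed). A minimal sketch of the loop shape this assumes, not the actual Trainer code:

    def main_loop(trainer, steps_per_epoch, max_epoch):
        trainer.local_step = -1                     # as set in __init__: nothing has run yet
        for epoch in range(1, max_epoch + 1):
            for trainer.local_step in range(steps_per_epoch):
                trainer.run_step()                  # local_step == 0-based index of this step
                # local_step + 1 steps of this epoch have now finished (cf. trigger.py above)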
tensorpack/train/feedfree.py  (+15 -15)

@@ -46,21 +46,6 @@ class FeedfreeTrainerBase(Trainer):
         assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
         self._input_method._setup(self)
 
-
-class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
-    """ A feedfree Trainer which assumes a single cost. """
-    def _get_cost_and_grad(self):
-        """ get the cost and gradient"""
-        self.build_train_tower()
-        cost = self.model.get_cost()
-        opt = self.config.optimizer
-        # GATE_NONE faster?
-        grads = opt.compute_gradients(
-            cost,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
-            colocate_gradients_with_ops=True)
-        return cost, grads
-
     def run_step(self):
         """ Simply run ``self.train_op``, which minimizes the cost."""
         self.monitored_sess.run(self.train_op)

@@ -82,6 +67,21 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         # import sys; sys.exit()
 
 
+class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
+    """ A feedfree Trainer which assumes a single cost. """
+    def _get_cost_and_grad(self):
+        """ get the cost and gradient"""
+        self.build_train_tower()
+        cost = self.model.get_cost()
+        opt = self.config.optimizer
+        # GATE_NONE faster?
+        grads = opt.compute_gradients(
+            cost,
+            gate_gradients=tf.train.Optimizer.GATE_NONE,
+            colocate_gradients_with_ops=True)
+        return cost, grads
+
+
 class SimpleFeedfreeTrainer(
         SingleCostFeedfreeTrainer,
         MultiPredictorTowerTrainer):
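The net effect of moving run_step into FeedfreeTrainerBase and relocating SingleCostFeedfreeTrainer below it is that any feedfree trainer gets the hooked-session run_step for free, and single-cost trainers only add the gradient computation. A hypothetical subclass as a sketch (using apply_gradients to build train_op is an assumption, not shown in this diff):

    class MyTrainer(SingleCostFeedfreeTrainer):
        def _setup(self):
            super(MyTrainer, self)._setup()
            cost, grads = self._get_cost_and_grad()
            self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
        # run_step is inherited from FeedfreeTrainerBase:
        #     self.monitored_sess.run(self.train_op)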