Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e791b9a5
Commit
e791b9a5
authored
Jul 31, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
check config.data/config.model in trainers
parent
784e2b7b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
13 additions
and
4 deletions
+13
-4
docs/tutorial/model.md
docs/tutorial/model.md
+2
-1
tensorpack/train/config.py
tensorpack/train/config.py
+2
-1
tensorpack/train/distributed.py
tensorpack/train/distributed.py
+6
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+1
-0
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-0
tensorpack/train/simple.py
tensorpack/train/simple.py
+1
-0
No files found.
docs/tutorial/model.md
View file @
e791b9a5
...
@@ -10,13 +10,14 @@ class MyModel(ModelDesc):
...
@@ -10,13 +10,14 @@ class MyModel(ModelDesc):
return
[
InputDesc
(
...
),
InputDesc
(
...
)]
return
[
InputDesc
(
...
),
InputDesc
(
...
)]
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
tensorA
,
tensorB
=
inputs
# build the graph
# build the graph
def
_get_optimizer
(
self
):
def
_get_optimizer
(
self
):
return
tf
.
train
.
GradientDescentOptimizer
(
0.1
)
return
tf
.
train
.
GradientDescentOptimizer
(
0.1
)
```
```
Basically,
`_get_inputs`
should define the metainfo of all the possible placeholder
s your graph may need.
`_get_inputs`
should define the metainfo of all the input
s your graph may need.
`_build_graph`
should add tensors/operations to the graph, where
`_build_graph`
should add tensors/operations to the graph, where
the argument
`inputs`
is the list of input tensors matching
`_get_inputs`
.
the argument
`inputs`
is the list of input tensors matching
`_get_inputs`
.
...
...
tensorpack/train/config.py
View file @
e791b9a5
...
@@ -95,8 +95,9 @@ class TrainConfig(object):
...
@@ -95,8 +95,9 @@ class TrainConfig(object):
monitors
=
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
monitors
=
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
self
.
monitors
=
monitors
self
.
monitors
=
monitors
if
model
is
not
None
:
assert_type
(
model
,
ModelDesc
)
self
.
model
=
model
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
if
session_init
is
None
:
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
session_init
=
JustCurrentSession
()
...
...
tensorpack/train/distributed.py
View file @
e791b9a5
...
@@ -52,7 +52,7 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
...
@@ -52,7 +52,7 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
config (TrainConfig): the train config.
config (TrainConfig): the train config.
server (tf.train.Server): the server object with ps and workers
server (tf.train.Server): the server object with ps and workers
"""
"""
assert
config
.
data
is
not
None
and
config
.
model
is
not
None
self
.
server
=
server
self
.
server
=
server
server_def
=
server
.
server_def
server_def
=
server
.
server_def
self
.
cluster
=
tf
.
train
.
ClusterSpec
(
server_def
.
cluster
)
self
.
cluster
=
tf
.
train
.
ClusterSpec
(
server_def
.
cluster
)
...
@@ -83,7 +83,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
...
@@ -83,7 +83,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
@
staticmethod
@
staticmethod
def
_average_grads
(
tower_grads
,
devices
):
def
_average_grads
(
tower_grads
,
devices
):
"""
"""
Average grad with round-robin device selection.
Average grads from towers.
The device where the average happens is chosen with round-robin.
Args:
Args:
tower_grads: Ngpu x Nvar x 2
tower_grads: Ngpu x Nvar x 2
...
@@ -111,6 +112,9 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
...
@@ -111,6 +112,9 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
def
_apply_shadow_vars
(
avg_grads
):
def
_apply_shadow_vars
(
avg_grads
):
"""
"""
Replace variables in avg_grads by shadow variables.
Replace variables in avg_grads by shadow variables.
Args:
avg_grads: list of (grad, var) tuples
"""
"""
ps_var_grads
=
[]
ps_var_grads
=
[]
for
grad
,
var
in
avg_grads
:
for
grad
,
var
in
avg_grads
:
...
...
tensorpack/train/feedfree.py
View file @
e791b9a5
...
@@ -58,6 +58,7 @@ def QueueInputTrainer(config, input_queue=None):
...
@@ -58,6 +58,7 @@ def QueueInputTrainer(config, input_queue=None):
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
"""
assert
(
config
.
data
is
not
None
or
config
.
dataflow
is
not
None
)
and
config
.
model
is
not
None
if
config
.
data
is
not
None
:
if
config
.
data
is
not
None
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
else
:
else
:
...
...
tensorpack/train/multigpu.py
View file @
e791b9a5
...
@@ -30,6 +30,7 @@ def _check_tf_version():
...
@@ -30,6 +30,7 @@ def _check_tf_version():
def
apply_prefetch_policy
(
config
,
gpu_prefetch
=
True
):
def
apply_prefetch_policy
(
config
,
gpu_prefetch
=
True
):
assert
(
config
.
data
is
not
None
or
config
.
dataflow
is
not
None
)
and
config
.
model
is
not
None
if
config
.
data
is
None
and
config
.
dataflow
is
not
None
:
if
config
.
data
is
None
and
config
.
dataflow
is
not
None
:
# always use Queue prefetch
# always use Queue prefetch
config
.
data
=
QueueInput
(
config
.
dataflow
)
config
.
data
=
QueueInput
(
config
.
dataflow
)
...
...
tensorpack/train/simple.py
View file @
e791b9a5
...
@@ -27,6 +27,7 @@ class SimpleTrainer(Trainer):
...
@@ -27,6 +27,7 @@ class SimpleTrainer(Trainer):
"Got nr_tower={}, but doesn't support multigpu!"
\
"Got nr_tower={}, but doesn't support multigpu!"
\
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
config
.
tower
))
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
config
.
tower
))
assert
(
config
.
data
is
not
None
or
config
.
dataflow
is
not
None
)
and
config
.
model
is
not
None
if
config
.
dataflow
is
None
:
if
config
.
dataflow
is
None
:
self
.
_input_source
=
config
.
data
self
.
_input_source
=
config
.
data
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment