Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1898cd3c
Commit
1898cd3c
authored
Nov 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix bug
parent
6306da7e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
48 additions
and
16 deletions
+48
-16
examples/OpenAIGym/train-atari.py
examples/OpenAIGym/train-atari.py
+1
-1
examples/cifar-convnet.py
examples/cifar-convnet.py
+8
-4
tensorpack/callbacks/concurrency.py
tensorpack/callbacks/concurrency.py
+25
-0
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+4
-0
tensorpack/train/base.py
tensorpack/train/base.py
+4
-6
tensorpack/train/config.py
tensorpack/train/config.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+4
-3
No files found.
examples/OpenAIGym/train-atari.py
View file @
1898cd3c
...
...
@@ -204,7 +204,7 @@ def get_config():
HumanHyperParamSetter
(
'entropy_beta'
),
HumanHyperParamSetter
(
'explore_factor'
),
master
,
StartProcOrThread
(
master
)
StartProcOrThread
(
master
)
,
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'logits'
]),
2
),
]),
session_config
=
get_default_sess_config
(
0.5
),
...
...
examples/cifar-convnet.py
View file @
1898cd3c
...
...
@@ -10,6 +10,7 @@ import os
from
tensorpack
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
tensorpack.tfutils.summary
import
*
from
tensorpack.utils.gpu
import
get_nr_gpu
"""
A small convnet model for Cifar10 or Cifar100 dataset.
...
...
@@ -152,7 +153,10 @@ if __name__ == '__main__':
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
()
#if args.gpu:
#config.nr_tower = len(args.gpu.split(','))
#AsyncMultiGPUTrainer(config).train()
if
args
.
gpu
:
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
nr_gpu
=
get_nr_gpu
()
if
nr_gpu
==
1
:
QueueInputTrainer
(
config
)
.
train
()
else
:
SyncMultiGPUTrainer
(
config
)
.
train
()
tensorpack/callbacks/concurrency.py
0 → 100644
View file @
1898cd3c
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
.base
import
Callback
from
..utils.concurrency
import
start_proc_mask_signal
from
..utils
import
logger
__all__
=
[
'StartProcOrThread'
]
class
StartProcOrThread
(
Callback
):
def
__init__
(
self
,
procs_threads
):
"""
Start extra threads and processes before training
:param procs_threads: list of processes or threads
"""
if
not
isinstance
(
procs_threads
,
list
):
procs_threads
=
[
procs_threads
]
self
.
_procs_threads
=
procs_threads
def
_before_train
(
self
):
logger
.
info
(
"Starting all threads & procs ..."
)
# avoid sigint get handled by other processes
start_proc_mask_signal
(
self
.
_procs_threads
)
tensorpack/callbacks/group.py
View file @
1898cd3c
...
...
@@ -86,3 +86,7 @@ class Callbacks(Callback):
with
tm
.
timed_callback
(
display_name
):
cb
.
trigger_epoch
()
tm
.
log
()
def
append
(
self
,
cb
):
assert
isinstance
(
cb
,
Callback
)
self
.
cbs
.
append
(
cb
)
tensorpack/train/base.py
View file @
1898cd3c
...
...
@@ -68,7 +68,7 @@ class Trainer(object):
def
trigger_epoch
(
self
):
# by default, add this two stat
self
.
stat_holder
.
add_stat
(
'global_step'
,
self
.
global_step
)
self
.
stat_holder
.
add_stat
(
'global_step'
,
get_global_step
()
)
self
.
stat_holder
.
add_stat
(
'epoch_num'
,
self
.
epoch_num
)
# trigger subclass
...
...
@@ -88,7 +88,7 @@ class Trainer(object):
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
self
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
self
.
summary_writer
.
add_summary
(
summary
,
self
.
global_step
)
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step
()
)
def
write_scalar_summary
(
self
,
name
,
val
):
self
.
summary_writer
.
add_summary
(
...
...
@@ -98,10 +98,8 @@ class Trainer(object):
def
finalize_graph
(
self
):
# some final operations that might modify the graph
get_global_step_var
()
# ensure there is such var, before finalizing the graph
logger
.
info
(
"Setup callbacks ..."
)
callbacks
=
self
.
config
.
callbacks
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
if
not
hasattr
(
logger
,
'LOG_DIR'
):
raise
RuntimeError
(
"logger directory wasn't set!"
)
...
...
@@ -122,8 +120,8 @@ class Trainer(object):
callbacks
=
self
.
config
.
callbacks
with
self
.
sess
.
as_default
():
try
:
logger
.
info
(
"Start training with global_step={}"
.
format
(
get_global_step
()))
callbacks
.
before_train
()
logger
.
info
(
"Start training with global_step={}"
.
format
(
get_global_step
()))
for
self
.
epoch_num
in
range
(
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
with
timed_operation
(
...
...
tensorpack/train/config.py
View file @
1898cd3c
...
...
@@ -74,7 +74,7 @@ class TrainConfig(object):
if
self
.
extra_threads_procs
:
logger
.
warn
(
"[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs"
)
from
..callbacks.concurrency
import
StartProcOrThread
self
.
callbacks
.
cbs
.
append
(
StartProcOrThread
(
self
.
extra_threads_procs
))
self
.
callbacks
.
append
(
StartProcOrThread
(
self
.
extra_threads_procs
))
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
set_tower
(
self
,
nr_tower
=
None
,
tower
=
None
):
...
...
tensorpack/train/multigpu.py
View file @
1898cd3c
...
...
@@ -54,7 +54,7 @@ class MultiGPUTrainer(QueueInputTrainer):
tf
.
variable_scope
(
global_scope
,
reuse
=
idx
>
0
),
\
TowerContext
(
'tower{}'
.
format
(
idx
))
as
scope
:
logger
.
info
(
"Building graph for training tower {}..."
.
format
(
idx
))
model_inputs
=
self
.
_get_
model
_inputs
()
# each tower dequeue from input queue
model_inputs
=
self
.
_get_
dequeued
_inputs
()
# each tower dequeue from input queue
self
.
dequed_inputs
.
append
(
model_inputs
)
self
.
model
.
build_graph
(
model_inputs
)
...
...
tensorpack/train/trainer.py
View file @
1898cd3c
...
...
@@ -17,6 +17,7 @@ from ..tfutils import (get_vars_by_names, freeze_collection,
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.modelutils
import
describe_model
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
]
...
...
@@ -160,7 +161,7 @@ class QueueInputTrainer(Trainer):
self
.
predict_tower
=
predict_tower
or
[
0
]
self
.
dequed_inputs
=
None
def
_get_
model
_inputs
(
self
):
def
_get_
dequeued
_inputs
(
self
):
""" Dequeue a datapoint from input_queue and return"""
ret
=
self
.
input_queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
...
...
@@ -172,7 +173,7 @@ class QueueInputTrainer(Trainer):
def
_single_tower_grad
(
self
):
""" Get grad and cost for single-tower"""
self
.
dequed_inputs
=
model_inputs
=
self
.
_get_
model
_inputs
()
self
.
dequed_inputs
=
model_inputs
=
self
.
_get_
dequeued
_inputs
()
# test the overhead of queue
#with tf.device('/gpu:0'):
...
...
@@ -190,7 +191,7 @@ class QueueInputTrainer(Trainer):
def
_build_enque_thread
(
self
):
""" create a thread that keeps filling the queue """
self
.
input_th
=
EnqueueThread
(
self
)
self
.
_extra_threads_procs
.
append
(
self
.
input_th
)
self
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
input_th
)
)
def
train
(
self
):
assert
len
(
self
.
config
.
tower
)
==
1
,
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment