Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dc709e94
Commit
dc709e94
authored
Oct 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update some examples
parent
0c65c338
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
43 additions
and
27 deletions
+43
-27
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+3
-2
examples/CTC-TIMIT/train-timit.py
examples/CTC-TIMIT/train-timit.py
+3
-2
examples/Char-RNN/char-rnn.py
examples/Char-RNN/char-rnn.py
+3
-2
examples/DeepQNetwork/DQN.py
examples/DeepQNetwork/DQN.py
+3
-2
examples/DisturbLabel/mnist-disturb.py
examples/DisturbLabel/mnist-disturb.py
+2
-1
examples/DisturbLabel/svhn-disturb.py
examples/DisturbLabel/svhn-disturb.py
+2
-1
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+2
-1
examples/DoReFa-Net/svhn-digit-dorefa.py
examples/DoReFa-Net/svhn-digit-dorefa.py
+3
-2
examples/DynamicFilterNetwork/steering-filter.py
examples/DynamicFilterNetwork/steering-filter.py
+3
-2
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+3
-2
examples/HED/hed.py
examples/HED/hed.py
+4
-2
examples/boilerplate.py
examples/boilerplate.py
+3
-2
examples/mnist-tfslim.py
examples/mnist-tfslim.py
+2
-1
examples/mnist-visualizations.py
examples/mnist-visualizations.py
+2
-1
examples/svhn-digit-convnet.py
examples/svhn-digit-convnet.py
+3
-2
tensorpack/train/interface.py
tensorpack/train/interface.py
+2
-2
No files found.
examples/A3C-Gym/train-atari.py
View file @
dc709e94
...
...
@@ -18,6 +18,7 @@ import tensorflow as tf
import
six
from
six.moves
import
queue
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.concurrency
import
*
from
tensorpack.utils.serialize
import
*
...
...
@@ -303,5 +304,5 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
get_model_loader
(
args
.
load
)
trainer
=
QueueInputTrainer
if
config
.
nr_tower
==
1
else
AsyncMultiGPUTrainer
trainer
(
config
)
.
train
(
)
trainer
=
SimpleTrainer
()
if
config
.
nr_tower
==
1
else
AsyncMultiGPUTrainer
(
config
.
tower
)
launch_train_with_config
(
config
,
trainer
)
examples/CTC-TIMIT/train-timit.py
View file @
dc709e94
...
...
@@ -12,6 +12,7 @@ import operator
import
six
from
six.moves
import
map
,
range
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils.gradproc
import
SummaryGradient
,
GlobalNormClip
from
tensorpack.utils.globvars
import
globalns
as
param
...
...
@@ -94,7 +95,7 @@ def get_data(path, isTrain, stat_file):
def
get_config
(
ds_train
,
ds_test
):
return
TrainConfig
(
data
flow
=
ds_train
,
data
=
QueueInput
(
ds_train
)
,
callbacks
=
[
ModelSaver
(),
StatMonitorParamSetter
(
'learning_rate'
,
'error'
,
...
...
@@ -128,4 +129,4 @@ if __name__ == '__main__':
config
=
get_config
(
ds_train
,
ds_test
)
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/Char-RNN/char-rnn.py
View file @
dc709e94
...
...
@@ -12,6 +12,7 @@ import operator
import
six
from
six.moves
import
map
,
range
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils
import
symbolic_functions
,
summary
,
optimizer
from
tensorpack.tfutils.gradproc
import
GlobalNormClip
...
...
@@ -116,7 +117,7 @@ def get_config():
ds
=
BatchData
(
ds
,
param
.
batch_size
)
return
TrainConfig
(
data
flow
=
ds
,
data
=
QueueInput
(
ds
)
,
callbacks
=
[
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
25
,
2e-4
)])
...
...
@@ -190,4 +191,4 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/DeepQNetwork/DQN.py
View file @
dc709e94
...
...
@@ -16,6 +16,7 @@ import multiprocessing
import
threading
from
collections
import
deque
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils.concurrency
import
*
import
tensorflow
as
tf
...
...
@@ -105,7 +106,7 @@ def get_config():
)
return
TrainConfig
(
data
flow
=
expreplay
,
data
=
QueueInput
(
expreplay
)
,
model
=
Model
(),
callbacks
=
[
ModelSaver
(),
...
...
@@ -166,4 +167,4 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
get_model_loader
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/DisturbLabel/mnist-disturb.py
View file @
dc709e94
...
...
@@ -8,6 +8,7 @@ import os
import
sys
import
argparse
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
import
tensorflow
as
tf
...
...
@@ -65,4 +66,4 @@ if __name__ == '__main__':
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
config
=
get_config
()
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/DisturbLabel/svhn-disturb.py
View file @
dc709e94
...
...
@@ -8,6 +8,7 @@ import numpy as np
import
os
import
imp
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
...
...
@@ -56,4 +57,4 @@ if __name__ == '__main__':
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
config
=
get_config
(
args
.
prob
)
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/DoReFa-Net/alexnet-dorefa.py
View file @
dc709e94
...
...
@@ -11,6 +11,7 @@ import multiprocessing
import
os
import
sys
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
...
...
@@ -322,4 +323,4 @@ if __name__ == '__main__':
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
nr_tower
=
nr_tower
SyncMultiGPUTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
configi
,
SyncMultiGPUTrainer
(
list
(
range
(
nr_tower
)))
)
examples/DoReFa-Net/svhn-digit-dorefa.py
View file @
dc709e94
...
...
@@ -7,6 +7,7 @@ import argparse
import
numpy
as
np
import
os
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
...
...
@@ -163,7 +164,7 @@ def get_config():
data_test
=
BatchData
(
data_test
,
128
,
remainder
=
True
)
return
TrainConfig
(
data
flow
=
data_train
,
data
=
QueueInput
(
data_train
)
,
callbacks
=
[
ModelSaver
(),
InferenceRunner
(
data_test
,
...
...
@@ -183,4 +184,4 @@ if __name__ == '__main__':
BITW
,
BITA
,
BITG
=
map
(
int
,
args
.
dorefa
.
split
(
','
))
config
=
get_config
()
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/DynamicFilterNetwork/steering-filter.py
View file @
dc709e94
...
...
@@ -6,10 +6,12 @@ import argparse
import
numpy
as
np
import
tensorflow
as
tf
import
cv2
import
os
from
scipy.signal
import
convolve2d
from
six.moves
import
range
,
zip
import
multiprocessing
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.utils
import
logger
from
tensorpack.utils.viz
import
*
...
...
@@ -262,5 +264,4 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
nr_tower
=
NR_GPU
SyncMultiGPUTrainer
(
config
)
.
train
()
launch_train_with_config
(
config
,
SyncMultiGPUTrainer
(
list
(
range
(
NR_GPU
))))
examples/FasterRCNN/train.py
View file @
dc709e94
...
...
@@ -13,6 +13,7 @@ import numpy as np
import
json
import
tensorflow
as
tf
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
tensorpack.tfutils.summary
import
add_moving_summary
...
...
@@ -300,6 +301,6 @@ if __name__ == '__main__':
steps_per_epoch
=
stepnum
,
max_epoch
=
205000
*
factor
//
stepnum
,
session_init
=
get_model_loader
(
args
.
load
)
if
args
.
load
else
None
,
nr_tower
=
get_nr_gpu
()
)
SyncMultiGPUTrainerReplicated
(
cfg
,
gpu_prefetch
=
False
)
.
train
()
trainer
=
SyncMultiGPUTrainerReplicated
(
range
(
len
(
get_nr_gpu
())))
launch_train_with_config
(
cfg
,
trainer
)
examples/HED/hed.py
View file @
dc709e94
...
...
@@ -11,6 +11,7 @@ from six.moves import zip
import
os
import
sys
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
import
tensorpack.tfutils.symbolic_functions
as
symbf
from
tensorpack.dataflow
import
dataset
...
...
@@ -231,5 +232,6 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
get_model_loader
(
args
.
load
)
config
.
nr_tower
=
max
(
get_nr_gpu
(),
1
)
SyncMultiGPUTrainer
(
config
)
.
train
()
launch_train_with_config
(
config
,
SyncMultiGPUTrainer
(
range
(
max
(
get_nr_gpu
(),
1
))))
examples/boilerplate.py
View file @
dc709e94
...
...
@@ -5,6 +5,7 @@
import
os
import
argparse
import
tensorflow
as
tf
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
"""
...
...
@@ -51,7 +52,7 @@ def get_config():
return
TrainConfig
(
model
=
Model
(),
data
flow
=
ds_train
,
data
=
QueueInput
(
ds_train
)
,
callbacks
=
[
ModelSaver
(),
InferenceRunner
(
ds_test
,
[
ScalarStats
(
'total_costs'
)]),
...
...
@@ -77,4 +78,4 @@ if __name__ == '__main__':
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
SyncMultiGPUTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/mnist-tfslim.py
View file @
dc709e94
...
...
@@ -14,6 +14,7 @@ the only differences are:
2. use slim names to summarize weights
"""
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
import
tensorflow
as
tf
...
...
@@ -101,4 +102,4 @@ if __name__ == '__main__':
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
config
=
get_config
()
SimpleTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/mnist-visualizations.py
View file @
dc709e94
...
...
@@ -11,6 +11,7 @@ import argparse
MNIST ConvNet example with weights/activations visualization.
"""
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
import
tensorflow
as
tf
...
...
@@ -161,4 +162,4 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
SimpleTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
examples/svhn-digit-convnet.py
View file @
dc709e94
...
...
@@ -7,6 +7,7 @@ import argparse
import
numpy
as
np
import
os
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
*
from
tensorpack.tfutils.symbolic_functions
import
prediction_incorrect
from
tensorpack.dataflow
import
dataset
...
...
@@ -99,7 +100,7 @@ def get_config():
return
TrainConfig
(
model
=
Model
(),
data
flow
=
data_train
,
data
=
QueueInput
(
data_train
)
,
callbacks
=
[
ModelSaver
(),
InferenceRunner
(
data_test
,
...
...
@@ -125,4 +126,4 @@ if __name__ == '__main__':
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
QueueInputTrainer
(
config
)
.
train
(
)
launch_train_with_config
(
config
,
SimpleTrainer
()
)
tensorpack/train/interface.py
View file @
dc709e94
...
...
@@ -43,12 +43,12 @@ def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
def
launch_train_with_config
(
config
,
trainer
):
"""
Train with a :class:`TrainConfig` and a
new version of
:class:`Trainer`, to
Train with a :class:`TrainConfig` and a :class:`Trainer`, to
mimic the old training interface.
Args:
config (TrainConfig):
trainer (Trainer): an instance of
the new t
rainer
trainer (Trainer): an instance of
a SingleCostT
rainer
Examples:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment