Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4adbaa94
Commit
4adbaa94
authored
Mar 04, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
improve docs
parent
43f7ca75
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
32 additions
and
39 deletions
+32
-39
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+3
-3
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+4
-2
tensorpack/callbacks/misc.py
tensorpack/callbacks/misc.py
+2
-1
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+1
-1
tensorpack/graph_builder/training.py
tensorpack/graph_builder/training.py
+2
-9
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+4
-3
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+16
-20
No files found.
examples/ResNet/imagenet-resnet.py
View file @
4adbaa94
...
...
@@ -80,7 +80,7 @@ def get_config(model, fake=False):
EstimatedTimeLeft
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
30
,
BASE_LR
*
1e-1
),
(
60
,
BASE_LR
*
1e-2
),
(
85
,
BASE_LR
*
1e-3
),
(
95
,
BASE_LR
*
1e-4
),
(
105
,
BASE_LR
*
1e-5
)]),
(
90
,
BASE_LR
*
1e-3
),
(
100
,
BASE_LR
*
1e-4
)]),
]
if
BASE_LR
>
0.1
:
callbacks
.
append
(
...
...
@@ -102,7 +102,7 @@ def get_config(model, fake=False):
dataflow
=
dataset_train
,
callbacks
=
callbacks
,
steps_per_epoch
=
100
if
args
.
fake
else
1280000
//
args
.
batch
,
max_epoch
=
1
10
,
max_epoch
=
1
05
,
)
...
...
@@ -115,7 +115,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--data_format'
,
help
=
'specify NCHW or NHWC'
,
type
=
str
,
default
=
'NCHW'
)
parser
.
add_argument
(
'-d'
,
'--depth'
,
help
=
'resnet depth'
,
type
=
int
,
default
=
18
,
choices
=
[
18
,
34
,
50
,
101
,
152
])
type
=
int
,
default
=
50
,
choices
=
[
18
,
34
,
50
,
101
,
152
])
parser
.
add_argument
(
'--eval'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--batch'
,
default
=
256
,
type
=
int
,
help
=
'total batch size. 32 per GPU gives best accuracy, higher values should be similarly good'
)
...
...
tensorpack/callbacks/inference_runner.py
View file @
4adbaa94
...
...
@@ -78,10 +78,8 @@ class InferenceRunnerBase(Callback):
try
:
self
.
_size
=
input
.
size
()
logger
.
info
(
"InferenceRunner will eval {} iterations"
.
format
(
input
.
size
()))
except
NotImplementedError
:
self
.
_size
=
0
logger
.
warn
(
"InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!"
)
self
.
_hooks
=
[]
...
...
@@ -95,6 +93,10 @@ class InferenceRunnerBase(Callback):
def
_before_train
(
self
):
self
.
_hooked_sess
=
HookedSession
(
self
.
trainer
.
sess
,
self
.
_hooks
)
self
.
_input_callbacks
.
before_train
()
if
self
.
_size
>
0
:
logger
.
info
(
"InferenceRunner will eval {} iterations"
.
format
(
self
.
_size
))
else
:
logger
.
warn
(
"InferenceRunner got an input with unknown size! It will iterate until OutOfRangeError!"
)
def
_after_train
(
self
):
self
.
_input_callbacks
.
after_train
()
...
...
tensorpack/callbacks/misc.py
View file @
4adbaa94
...
...
@@ -94,4 +94,5 @@ class EstimatedTimeLeft(Callback):
average_epoch_time
=
np
.
mean
(
self
.
_times
)
time_left
=
(
self
.
_max_epoch
-
self
.
epoch_num
)
*
average_epoch_time
if
time_left
>
0
:
logger
.
info
(
"Estimated Time Left: "
+
humanize_time_delta
(
time_left
))
tensorpack/dataflow/dataset/ilsvrc.py
View file @
4adbaa94
...
...
@@ -114,7 +114,7 @@ def _guess_dir_structure(dir):
else
:
dir_structure
=
'original'
logger
.
info
(
"
Assuming directory {} has {}
structure."
.
format
(
"
[ILSVRC12] Assuming directory {} has '{}'
structure."
.
format
(
dir
,
dir_structure
))
return
dir_structure
...
...
tensorpack/graph_builder/training.py
View file @
4adbaa94
...
...
@@ -12,7 +12,6 @@ from six.moves import zip, range
from
..utils
import
logger
from
..tfutils.tower
import
TowerContext
from
..tfutils.common
import
get_tf_version_number
from
..tfutils.gradproc
import
ScaleGradient
from
.utils
import
(
...
...
@@ -39,16 +38,10 @@ class DataParallelBuilder(GraphBuilder):
towers(list[int]): list of GPU ids.
"""
if
len
(
towers
)
>
1
:
logger
.
info
(
"Training a model of {} towers"
.
format
(
len
(
towers
)))
DataParallelBuilder
.
_check_tf_version
()
logger
.
info
(
"[DataParallel] Training a model of {} towers."
.
format
(
len
(
towers
)))
self
.
towers
=
towers
@
staticmethod
def
_check_tf_version
():
assert
get_tf_version_number
()
>=
1.1
,
\
"TF version {} is too old to run multi GPU training!"
.
format
(
tf
.
VERSION
)
@
staticmethod
def
_check_grad_list
(
grad_list
):
"""
...
...
@@ -103,7 +96,7 @@ class DataParallelBuilder(GraphBuilder):
index
=
idx
,
vs_name
=
tower_names
[
idx
]
if
usevs
else
''
):
if
len
(
str
(
device
))
<
10
:
# a device function doesn't have good string description
logger
.
info
(
"Building graph for training tower {} on device {}..."
.
format
(
idx
,
device
))
logger
.
info
(
"Building graph for training tower {} on device {}
..."
.
format
(
idx
,
device
))
else
:
logger
.
info
(
"Building graph for training tower {} ..."
.
format
(
idx
))
...
...
tensorpack/models/regularize.py
View file @
4adbaa94
...
...
@@ -78,8 +78,8 @@ def regularize_cost(regex, func, name='regularize_cost'):
return
name
[
prefixlen
:]
return
name
names
=
list
(
map
(
f
,
names
))
logger
.
info
(
"regularize_cost()
found
{} tensors."
.
format
(
len
(
names
)))
_log_once
(
"
Applying regularizer for
{}"
.
format
(
', '
.
join
(
names
)))
logger
.
info
(
"regularize_cost()
applying regularizers on
{} tensors."
.
format
(
len
(
names
)))
_log_once
(
"
The following tensors will be regularized:
{}"
.
format
(
', '
.
join
(
names
)))
return
tf
.
add_n
(
costs
,
name
=
name
)
...
...
@@ -106,7 +106,8 @@ def regularize_cost_from_collection(name='regularize_cost'):
else
:
losses
=
tf
.
get_collection
(
tf
.
GraphKeys
.
REGULARIZATION_LOSSES
)
if
len
(
losses
)
>
0
:
logger
.
info
(
"regularize_cost_from_collection() found {} tensors in REGULARIZATION_LOSSES."
.
format
(
len
(
losses
)))
logger
.
info
(
"regularize_cost_from_collection() applying regularizers on "
"{} tensors in REGULARIZATION_LOSSES."
.
format
(
len
(
losses
)))
reg_loss
=
tf
.
add_n
(
losses
,
name
=
name
)
return
reg_loss
else
:
...
...
tensorpack/utils/utils.py
View file @
4adbaa94
...
...
@@ -24,42 +24,38 @@ def humanize_time_delta(sec):
"""Humanize timedelta given in seconds
Args:
sec (float): time difference in seconds.
sec (float): time difference in seconds.
Must be positive.
Examples:
Returns:
str - time difference as a readable string
Several time differences as a human readable string
Examples:
.. code-block:: python
print humanize_seconds(1) # 1 second
print humanize_seconds(60 + 1) # 1 minute 1 second
print humanize_seconds(87.6) # 1 minute 27 seconds
print humanize_seconds(0.01) # 0.01 seconds
print humanize_seconds(60 * 60 + 1) # 1 hour 0 minutes 1 second
print humanize_seconds(60 * 60 * 24 + 1) # 1 day 0 hours 0 minutes 1 second
print humanize_seconds(60 * 60 * 24 + 60 * 2 + 60*60*9+ 3) # 1 day 9 hours 2 minutes 3 seconds
Returns:
time difference as a readable string
print(humanize_time_delta(1)) # 1 second
print(humanize_time_delta(60 + 1)) # 1 minute 1 second
print(humanize_time_delta(87.6)) # 1 minute 27 seconds
print(humanize_time_delta(0.01)) # 0.01 seconds
print(humanize_time_delta(60 * 60 + 1)) # 1 hour 1 second
print(humanize_time_delta(60 * 60 * 24 + 1)) # 1 day 1 second
print(humanize_time_delta(60 * 60 * 24 + 60 * 2 + 60*60*9 + 3)) # 1 day 9 hours 2 minutes 3 seconds
"""
assert
sec
>=
0
,
sec
if
sec
==
0
:
return
"0 second"
time
=
datetime
(
2000
,
1
,
1
)
+
timedelta
(
seconds
=
int
(
sec
))
units
=
[
'day'
,
'hour'
,
'minute'
,
'second'
]
vals
=
[
time
.
day
-
1
,
time
.
hour
,
time
.
minute
,
time
.
second
]
vals
=
[
int
(
sec
//
86400
)
,
time
.
hour
,
time
.
minute
,
time
.
second
]
if
sec
<
60
:
vals
[
-
1
]
=
sec
def
_format
(
v
,
u
):
return
"{:.3g} {}{}"
.
format
(
v
,
u
,
"s"
if
v
>
1
else
""
)
required
=
False
ans
=
[]
for
v
,
u
in
zip
(
vals
,
units
):
if
not
required
:
if
v
>
0
:
required
=
True
ans
.
append
(
_format
(
v
,
u
))
else
:
ans
.
append
(
_format
(
v
,
u
))
return
" "
.
join
(
ans
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment