Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3754faac
Commit
3754faac
authored
Oct 31, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
cifar with plateau detection
parent
5aab2d2d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
16 deletions
+21
-16
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-1
examples/cifar-convnet.py
examples/cifar-convnet.py
+13
-13
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+6
-1
No files found.
examples/Atari2600/DQN.py
View file @
3754faac
...
@@ -40,7 +40,7 @@ EXPLORATION_EPOCH_ANNEAL = 0.01
...
@@ -40,7 +40,7 @@ EXPLORATION_EPOCH_ANNEAL = 0.01
END_EXPLORATION
=
0.1
END_EXPLORATION
=
0.1
MEMORY_SIZE
=
1e6
MEMORY_SIZE
=
1e6
# NOTE: will consume at least 1e6 * 84 * 84 * 4 bytes = 26G memory.
# NOTE: will consume at least 1e6 * 84 * 84 bytes == 6.6G memory.
# Suggest using tcmalloc to manage memory space better.
# Suggest using tcmalloc to manage memory space better.
INIT_MEMORY_SIZE
=
5e4
INIT_MEMORY_SIZE
=
5e4
STEP_PER_EPOCH
=
10000
STEP_PER_EPOCH
=
10000
...
...
examples/cifar-convnet.py
View file @
3754faac
...
@@ -15,8 +15,7 @@ from tensorpack.tfutils.summary import *
...
@@ -15,8 +15,7 @@ from tensorpack.tfutils.summary import *
A small convnet model for Cifar10 or Cifar100 dataset.
A small convnet model for Cifar10 or Cifar100 dataset.
Cifar10:
Cifar10:
90% validation accuracy after 40k step.
91% accuracy after 50k step.
91% accuracy after 80k step.
19.3 step/s on Tesla M40
19.3 step/s on Tesla M40
Not a good model for Cifar100, just for demonstration.
Not a good model for Cifar100, just for demonstration.
...
@@ -66,7 +65,7 @@ class Model(ModelDesc):
...
@@ -66,7 +65,7 @@ class Model(ModelDesc):
add_moving_summary
(
tf
.
reduce_mean
(
wrong
,
name
=
'train_error'
))
add_moving_summary
(
tf
.
reduce_mean
(
wrong
,
name
=
'train_error'
))
# weight decay on all W of fc layers
# weight decay on all W of fc layers
wd_cost
=
tf
.
mul
(
0.004
,
wd_cost
=
tf
.
mul
(
0.00
0
4
,
regularize_cost
(
'fc.*/W'
,
tf
.
nn
.
l2_loss
),
regularize_cost
(
'fc.*/W'
,
tf
.
nn
.
l2_loss
),
name
=
'regularize_loss'
)
name
=
'regularize_loss'
)
add_moving_summary
(
cost
,
wd_cost
)
add_moving_summary
(
cost
,
wd_cost
)
...
@@ -112,26 +111,27 @@ def get_config(cifar_classnum):
...
@@ -112,26 +111,27 @@ def get_config(cifar_classnum):
sess_config
=
get_default_sess_config
(
0.5
)
sess_config
=
get_default_sess_config
(
0.5
)
nr_gpu
=
get_nr_gpu
()
lr
=
tf
.
Variable
(
1e-2
,
name
=
'learning_rate'
,
lr
=
tf
.
train
.
exponential_decay
(
dtype
=
tf
.
float32
,
trainable
=
False
)
learning_rate
=
1e-2
,
global_step
=
get_global_step_var
(),
decay_steps
=
step_per_epoch
*
(
30
if
nr_gpu
==
1
else
20
),
decay_rate
=
0.5
,
staircase
=
True
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
def
lr_func
(
lr
):
if
lr
<
3e-5
:
raise
StopTraining
()
return
lr
*
0.31
return
TrainConfig
(
return
TrainConfig
(
dataset
=
dataset_train
,
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
ModelSaver
(),
ModelSaver
(),
InferenceRunner
(
dataset_test
,
ClassificationError
()),
InferenceRunner
(
dataset_test
,
ClassificationError
())
StatMonitorParamSetter
(
'learning_rate'
,
'val_error'
,
lr_func
,
threshold
=
0.001
,
last_k
=
10
),
]),
]),
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(
cifar_classnum
),
model
=
Model
(
cifar_classnum
),
step_per_epoch
=
step_per_epoch
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
2
50
,
max_epoch
=
1
50
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/callbacks/base.py
View file @
3754faac
...
@@ -8,7 +8,7 @@ import os
...
@@ -8,7 +8,7 @@ import os
import
time
import
time
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
__all__
=
[
'Callback'
,
'PeriodicCallback'
]
__all__
=
[
'Callback'
,
'PeriodicCallback'
,
'ProxyCallback'
]
class
Callback
(
object
):
class
Callback
(
object
):
""" Base class for all callbacks """
""" Base class for all callbacks """
...
...
tensorpack/train/base.py
View file @
3754faac
...
@@ -18,7 +18,10 @@ from ..callbacks import StatHolder
...
@@ -18,7 +18,10 @@ from ..callbacks import StatHolder
from
..tfutils
import
get_global_step
,
get_global_step_var
from
..tfutils
import
get_global_step
,
get_global_step_var
from
..tfutils.summary
import
create_summary
from
..tfutils.summary
import
create_summary
__all__
=
[
'Trainer'
]
__all__
=
[
'Trainer'
,
'StopTraining'
]
class
StopTraining
(
BaseException
):
pass
class
Trainer
(
object
):
class
Trainer
(
object
):
"""
"""
...
@@ -138,6 +141,8 @@ class Trainer(object):
...
@@ -138,6 +141,8 @@ class Trainer(object):
#callbacks.trigger_step() # not useful?
#callbacks.trigger_step() # not useful?
self
.
global_step
+=
1
self
.
global_step
+=
1
self
.
trigger_epoch
()
self
.
trigger_epoch
()
except
StopTraining
:
logger
.
info
(
"Training was stopped."
)
except
(
KeyboardInterrupt
,
Exception
):
except
(
KeyboardInterrupt
,
Exception
):
raise
raise
finally
:
finally
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment