Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
cdd71bfe
Commit
cdd71bfe
authored
Feb 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix get_input_queue
parent
0a012166
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
8 deletions
+10
-8
example_cifar10.py
example_cifar10.py
+8
-6
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
No files found.
example_cifar10.py
View file @
cdd71bfe
...
@@ -33,6 +33,7 @@ class Model(ModelDesc):
...
@@ -33,6 +33,7 @@ class Model(ModelDesc):
def
_get_cost
(
self
,
input_vars
,
is_training
):
def
_get_cost
(
self
,
input_vars
,
is_training
):
image
,
label
=
input_vars
image
,
label
=
input_vars
keep_prob
=
tf
.
constant
(
0.5
if
is_training
else
1.0
)
if
is_training
:
if
is_training
:
image
,
label
=
tf
.
train
.
shuffle_batch
(
image
,
label
=
tf
.
train
.
shuffle_batch
(
...
@@ -40,7 +41,7 @@ class Model(ModelDesc):
...
@@ -40,7 +41,7 @@ class Model(ModelDesc):
num_threads
=
6
,
enqueue_many
=
True
)
num_threads
=
6
,
enqueue_many
=
True
)
tf
.
image_summary
(
"train_image"
,
image
,
10
)
tf
.
image_summary
(
"train_image"
,
image
,
10
)
l
=
Conv2D
(
'conv1.1'
,
image
,
out_channel
=
64
,
kernel_shape
=
3
,
padding
=
'SAME'
)
l
=
Conv2D
(
'conv1.1'
,
image
,
out_channel
=
64
,
kernel_shape
=
3
)
l
=
Conv2D
(
'conv1.2'
,
l
,
out_channel
=
64
,
kernel_shape
=
3
,
nl
=
tf
.
identity
)
l
=
Conv2D
(
'conv1.2'
,
l
,
out_channel
=
64
,
kernel_shape
=
3
,
nl
=
tf
.
identity
)
l
=
BatchNorm
(
'bn1'
,
l
,
is_training
)
l
=
BatchNorm
(
'bn1'
,
l
,
is_training
)
l
=
tf
.
nn
.
relu
(
l
)
l
=
tf
.
nn
.
relu
(
l
)
...
@@ -56,8 +57,9 @@ class Model(ModelDesc):
...
@@ -56,8 +57,9 @@ class Model(ModelDesc):
l
=
Conv2D
(
'conv3.2'
,
l
,
out_channel
=
128
,
kernel_shape
=
3
,
padding
=
'VALID'
,
nl
=
tf
.
identity
)
l
=
Conv2D
(
'conv3.2'
,
l
,
out_channel
=
128
,
kernel_shape
=
3
,
padding
=
'VALID'
,
nl
=
tf
.
identity
)
l
=
BatchNorm
(
'bn3'
,
l
,
is_training
)
l
=
BatchNorm
(
'bn3'
,
l
,
is_training
)
l
=
tf
.
nn
.
relu
(
l
)
l
=
tf
.
nn
.
relu
(
l
)
l
=
FullyConnected
(
'fc0'
,
l
,
512
,
l
=
FullyConnected
(
'fc0'
,
l
,
1024
+
512
,
b_init
=
tf
.
constant_initializer
(
0.1
))
b_init
=
tf
.
constant_initializer
(
0.1
))
l
=
tf
.
nn
.
dropout
(
l
,
keep_prob
)
l
=
FullyConnected
(
'fc1'
,
l
,
out_dim
=
512
,
l
=
FullyConnected
(
'fc1'
,
l
,
out_dim
=
512
,
b_init
=
tf
.
constant_initializer
(
0.1
))
b_init
=
tf
.
constant_initializer
(
0.1
))
# fc will have activation summary by default. disable for the output layer
# fc will have activation summary by default. disable for the output layer
...
@@ -120,13 +122,13 @@ def get_config():
...
@@ -120,13 +122,13 @@ def get_config():
lr
=
tf
.
train
.
exponential_decay
(
lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
1e-2
,
learning_rate
=
1e-2
,
global_step
=
get_global_step_var
(),
global_step
=
get_global_step_var
(),
decay_steps
=
dataset_train
.
size
()
*
3
0
,
decay_steps
=
dataset_train
.
size
()
*
4
0
,
decay_rate
=
0.
5
,
staircase
=
True
,
name
=
'learning_rate'
)
decay_rate
=
0.
4
,
staircase
=
True
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
return
TrainConfig
(
return
TrainConfig
(
dataset
=
dataset_train
,
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
),
callbacks
=
Callbacks
([
callbacks
=
Callbacks
([
StatPrinter
(),
StatPrinter
(),
PeriodicSaver
(),
PeriodicSaver
(),
...
@@ -135,7 +137,7 @@ def get_config():
...
@@ -135,7 +137,7 @@ def get_config():
session_config
=
sess_config
,
session_config
=
sess_config
,
model
=
Model
(),
model
=
Model
(),
step_per_epoch
=
step_per_epoch
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
5
00
,
max_epoch
=
3
00
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/callbacks/summary.py
View file @
cdd71bfe
...
@@ -40,7 +40,7 @@ class StatHolder(object):
...
@@ -40,7 +40,7 @@ class StatHolder(object):
def
_print_stat
(
self
):
def
_print_stat
(
self
):
for
k
,
v
in
sorted
(
self
.
stat_now
.
items
(),
key
=
operator
.
itemgetter
(
0
)):
for
k
,
v
in
sorted
(
self
.
stat_now
.
items
(),
key
=
operator
.
itemgetter
(
0
)):
if
self
.
print_tag
is
None
or
k
in
self
.
print_tag
:
if
self
.
print_tag
is
None
or
k
in
self
.
print_tag
:
logger
.
info
(
'{}: {:.
4
f}'
.
format
(
k
,
v
))
logger
.
info
(
'{}: {:.
5
f}'
.
format
(
k
,
v
))
def
_write_stat
(
self
):
def
_write_stat
(
self
):
tmp_filename
=
self
.
filename
+
'.tmp'
tmp_filename
=
self
.
filename
+
'.tmp'
...
...
tensorpack/train/trainer.py
View file @
cdd71bfe
...
@@ -93,7 +93,7 @@ class QueueInputTrainer(Trainer):
...
@@ -93,7 +93,7 @@ class QueueInputTrainer(Trainer):
def
train
(
self
):
def
train
(
self
):
model
=
self
.
model
model
=
self
.
model
input_vars
=
model
.
get_input_vars
()
input_vars
=
model
.
get_input_vars
()
input_queue
=
model
.
get_input_queue
()
input_queue
=
model
.
get_input_queue
(
input_vars
)
enqueue_op
=
input_queue
.
enqueue
(
input_vars
)
enqueue_op
=
input_queue
.
enqueue
(
input_vars
)
def
get_model_inputs
():
def
get_model_inputs
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment