Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d382ea9d
Commit
d382ea9d
authored
Mar 03, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix cifar bug
parent
28e42e11
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
5 deletions
+7
-5
example_cifar10.py
example_cifar10.py
+0
-2
example_mnist.py
example_mnist.py
+2
-2
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+5
-1
No files found.
example_cifar10.py
View file @
d382ea9d
...
...
@@ -157,8 +157,6 @@ if __name__ == '__main__':
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
else
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
with
tf
.
Graph
()
.
as_default
():
with
tf
.
device
(
'/cpu:0'
):
...
...
example_mnist.py
View file @
d382ea9d
...
...
@@ -96,8 +96,8 @@ def get_config():
lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
1e-3
,
global_step
=
get_global_step_var
(),
decay_steps
=
dataset_train
.
size
()
*
2
0
,
decay_rate
=
0.
1
,
staircase
=
True
,
name
=
'learning_rate'
)
decay_steps
=
dataset_train
.
size
()
*
1
0
,
decay_rate
=
0.
5
,
staircase
=
True
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
return
TrainConfig
(
...
...
tensorpack/dataflow/dataset/cifar10.py
View file @
d382ea9d
...
...
@@ -6,6 +6,7 @@ import os, sys
import
pickle
import
numpy
as
np
import
random
import
six
from
six.moves
import
urllib
,
range
import
copy
import
tarfile
...
...
@@ -44,7 +45,10 @@ def read_cifar10(filenames):
ret
=
[]
for
fname
in
filenames
:
fo
=
open
(
fname
,
'rb'
)
if
six
.
PY3
:
dic
=
pickle
.
load
(
fo
,
encoding
=
'bytes'
)
else
:
dic
=
pickle
.
load
(
fo
)
data
=
dic
[
b
'data'
]
label
=
dic
[
b
'labels'
]
fo
.
close
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment