Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8efd12b1
Commit
8efd12b1
authored
Jan 01, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add progress bar
parent
26edfabe
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
11 deletions
+18
-11
example_cifar10.py
example_cifar10.py
+1
-1
requirements.txt
requirements.txt
+1
-0
tensorpack/train.py
tensorpack/train.py
+3
-1
tensorpack/utils/validation_callback.py
tensorpack/utils/validation_callback.py
+13
-9
No files found.
example_cifar10.py
View file @
8efd12b1
...
@@ -140,7 +140,7 @@ def get_config():
...
@@ -140,7 +140,7 @@ def get_config():
get_model_func
=
get_model
,
get_model_func
=
get_model
,
batched_model_input
=
False
,
batched_model_input
=
False
,
step_per_epoch
=
step_per_epoch
,
step_per_epoch
=
step_per_epoch
,
max_epoch
=
1
00
,
max_epoch
=
5
00
,
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
requirements.txt
View file @
8efd12b1
...
@@ -2,3 +2,4 @@ pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-
...
@@ -2,3 +2,4 @@ pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-
termcolor
termcolor
numpy
numpy
protobuf
~=3.0.0a1
protobuf
~=3.0.0a1
tqdm
tensorpack/train.py
View file @
8efd12b1
...
@@ -7,6 +7,7 @@ import tensorflow as tf
...
@@ -7,6 +7,7 @@ import tensorflow as tf
from
itertools
import
count
from
itertools
import
count
import
argparse
import
argparse
import
tqdm
from
utils
import
*
from
utils
import
*
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.callback
import
Callbacks
from
utils.callback
import
Callbacks
...
@@ -134,7 +135,8 @@ def start_train(config):
...
@@ -134,7 +135,8 @@ def start_train(config):
callbacks
.
before_train
()
callbacks
.
before_train
()
for
epoch
in
xrange
(
1
,
config
.
max_epoch
):
for
epoch
in
xrange
(
1
,
config
.
max_epoch
):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
with
timed_operation
(
'epoch {}'
.
format
(
epoch
)):
for
step
in
xrange
(
config
.
step_per_epoch
):
for
step
in
tqdm
.
trange
(
config
.
step_per_epoch
,
leave
=
True
,
mininterval
=
0.2
):
if
coord
.
should_stop
():
if
coord
.
should_stop
():
return
return
# TODO if no one uses trigger_step, train_op can be
# TODO if no one uses trigger_step, train_op can be
...
...
tensorpack/utils/validation_callback.py
View file @
8efd12b1
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tqdm
import
tqdm
from
.stat
import
*
from
.stat
import
*
from
.callback
import
PeriodicCallback
,
Callback
from
.callback
import
PeriodicCallback
,
Callback
from
.naming
import
*
from
.naming
import
*
...
@@ -42,17 +44,19 @@ class ValidationError(PeriodicCallback):
...
@@ -42,17 +44,19 @@ class ValidationError(PeriodicCallback):
cnt
=
0
cnt
=
0
err_stat
=
Accuracy
()
err_stat
=
Accuracy
()
cost_sum
=
0
cost_sum
=
0
for
dp
in
self
.
ds
.
get_data
():
with
tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
for
dp
in
self
.
ds
.
get_data
():
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
batch_size
=
dp
[
0
]
.
shape
[
0
]
# assume batched input
cnt
+=
batch_size
cnt
+=
batch_size
wrong
,
cost
=
self
.
sess
.
run
(
wrong
,
cost
=
self
.
sess
.
run
(
[
self
.
wrong_var
,
self
.
cost_var
],
feed_dict
=
feed
)
[
self
.
wrong_var
,
self
.
cost_var
],
feed_dict
=
feed
)
err_stat
.
feed
(
wrong
,
batch_size
)
err_stat
.
feed
(
wrong
,
batch_size
)
# each batch might not have the same size in validation
# each batch might not have the same size in validation
cost_sum
+=
cost
*
batch_size
cost_sum
+=
cost
*
batch_size
pbar
.
update
()
cost_avg
=
cost_sum
/
cnt
cost_avg
=
cost_sum
/
cnt
self
.
writer
.
add_summary
(
self
.
writer
.
add_summary
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment