Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b059ce49
Commit
b059ce49
authored
May 01, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update
parent
aed3438b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
16 additions
and
2 deletions
+16
-2
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+3
-1
tensorpack/models/_common.py
tensorpack/models/_common.py
+7
-0
tensorpack/predict.py
tensorpack/predict.py
+5
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-0
No files found.
tensorpack/dataflow/common.py
View file @
b059ce49
...
...
@@ -154,6 +154,7 @@ class MapData(ProxyDataFlow):
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes an original datapoint, returns a new
datapoint. return None to skip this data point.
Note that if you use filter, ds.size() won't be correct.
"""
super
(
MapData
,
self
)
.
__init__
(
ds
)
self
.
func
=
func
...
...
@@ -171,6 +172,7 @@ class MapDataComponent(ProxyDataFlow):
:param ds: a :mod:`DataFlow` instance.
:param func: a function that takes a datapoint component dp[index], returns a
new value of dp[index]. return None to skip this datapoint.
Note that if you use filter, ds.size() won't be correct.
"""
super
(
MapDataComponent
,
self
)
.
__init__
(
ds
)
self
.
func
=
func
...
...
tensorpack/models/_common.py
View file @
b059ce49
...
...
@@ -16,6 +16,13 @@ from ..utils import logger
# Tracks which layers have already been logged, so each layer's
# shape/summary message is emitted at most once.
_layer_logged = set()


def disable_layer_logging():
    """Permanently suppress per-layer logging.

    Swaps the module-level ``_layer_logged`` registry for an object whose
    membership test succeeds for every key, so every layer looks as if it
    has already been logged and nothing further is printed.
    """

    class _AlwaysContains:
        # Claim to contain any item whatsoever.
        def __contains__(self, item):
            return True

    # Rebind through globals() so the module-level name is replaced; a
    # bare assignment here would only create a function-local binding.
    # (``nonlocal`` exists only in Python 3 — original note: "but how".)
    globals()['_layer_logged'] = _AlwaysContains()
def
layer_register
(
summary_activation
=
False
,
log_shape
=
True
):
"""
Register a layer.
...
...
tensorpack/predict.py
View file @
b059ce49
...
...
@@ -115,9 +115,13 @@ class PredictWorker(multiprocessing.Process):
self
.
config
=
config
def
run
(
self
):
logger
.
info
(
"Worker {} use GPU {}"
.
format
(
self
.
idx
,
self
.
gpuid
))
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
self
.
gpuid
G
=
tf
.
Graph
()
# build a graph for each process, because they don't need to share anything
with
G
.
as_default
(),
tf
.
device
(
'/gpu:0'
):
if
self
.
idx
!=
0
:
from
tensorpack.models._common
import
disable_layer_logging
disable_layer_logging
()
self
.
func
=
get_predict_func
(
self
.
config
)
if
self
.
idx
==
0
:
describe_model
()
...
...
@@ -173,13 +177,13 @@ class DatasetPredictor(object):
die_cnt
=
0
while
True
:
res
=
self
.
result_queue
.
get
()
pbar
.
update
()
if
res
[
0
]
!=
DIE
:
yield
res
[
1
]
else
:
die_cnt
+=
1
if
die_cnt
==
self
.
nr_gpu
:
break
pbar
.
update
()
self
.
inqueue_proc
.
join
()
self
.
inqueue_proc
.
terminate
()
for
p
in
self
.
workers
:
...
...
tensorpack/train/trainer.py
View file @
b059ce49
...
...
@@ -4,6 +4,7 @@
import
tensorflow
as
tf
import
threading
import
time
import
copy
import
re
import
functools
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment