Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
88af1f1d
Commit
88af1f1d
authored
Jan 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
a better handling of optional import.
parent
6e24b953
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
115 additions
and
97 deletions
+115
-97
examples/A3C-Gym/simulator.py
examples/A3C-Gym/simulator.py
+1
-6
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+12
-12
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+8
-8
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+7
-6
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+9
-7
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+20
-22
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+2
-7
tensorpack/dataflow/tf_func.py
tensorpack/dataflow/tf_func.py
+1
-9
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+12
-13
tensorpack/utils/dependency.py
tensorpack/utils/dependency.py
+42
-0
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+1
-7
No files found.
examples/A3C-Gym/simulator.py
View file @
88af1f1d
...
...
@@ -12,6 +12,7 @@ from collections import defaultdict
import
six
from
six.moves
import
queue
import
zmq
from
tensorpack.models.common
import
disable_layer_logging
from
tensorpack.callbacks
import
Callback
...
...
@@ -25,12 +26,6 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange'
,
'SimulatorProcessSharedWeight'
,
'TransitionExperience'
,
'WeightSync'
]
try
:
import
zmq
except
ImportError
:
logger
.
warn_dependency
(
'Simulator'
,
'zmq'
)
__all__
=
[]
class
TransitionExperience
(
object
):
""" A transition of state, or experience"""
...
...
tensorpack/RL/gymenv.py
View file @
88af1f1d
...
...
@@ -5,18 +5,6 @@
import
time
from
..utils
import
logger
try
:
import
gym
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
__all__
=
[
'GymEnv'
]
except
ImportError
:
logger
.
warn_dependency
(
'GymEnv'
,
'gym'
)
__all__
=
[]
import
threading
from
..utils.fs
import
mkdir_p
...
...
@@ -24,6 +12,7 @@ from ..utils.stats import StatCounter
from
.envbase
import
RLEnvironment
,
DiscreteActionSpace
__all__
=
[
'GymEnv'
]
_ENV_LOCK
=
threading
.
Lock
()
...
...
@@ -84,6 +73,17 @@ class GymEnv(RLEnvironment):
return
DiscreteActionSpace
(
spc
.
n
)
try
:
import
gym
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
except
ImportError
:
from
..utils.dependency
import
create_dummy_class
GymEnv
=
create_dummy_class
(
'GymEnv'
,
'gym'
)
# noqa
if
__name__
==
'__main__'
:
env
=
GymEnv
(
'Breakout-v0'
,
viz
=
0.1
)
num
=
env
.
get_action_space
()
.
num_actions
()
...
...
tensorpack/dataflow/dataset/bsds500.py
View file @
88af1f1d
...
...
@@ -8,17 +8,11 @@ import glob
import
cv2
import
numpy
as
np
from
...utils
import
logger
,
get_dataset_path
from
...utils
import
get_dataset_path
from
...utils.fs
import
download
from
..base
import
RNGDataFlow
try
:
from
scipy.io
import
loadmat
__all__
=
[
'BSDS500'
]
except
ImportError
:
logger
.
warn_dependency
(
'BSDS500'
,
'scipy.io'
)
__all__
=
[]
__all__
=
[
'BSDS500'
]
DATA_URL
=
"http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W
,
IMG_H
=
481
,
321
...
...
@@ -95,6 +89,12 @@ class BSDS500(RNGDataFlow):
yield
[
self
.
data
[
k
],
self
.
label
[
k
]]
try
:
from
scipy.io
import
loadmat
except
ImportError
:
from
...utils.dependency
import
create_dummy_class
BSDS500
=
create_dummy_class
(
'BSDS500'
,
'scipy.io'
)
# noqa
if
__name__
==
'__main__'
:
a
=
BSDS500
(
'val'
)
for
k
in
a
.
get_data
():
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
88af1f1d
...
...
@@ -9,12 +9,7 @@ import numpy as np
from
...utils
import
logger
,
get_dataset_path
from
..base
import
RNGDataFlow
try
:
import
scipy.io
__all__
=
[
'SVHNDigit'
]
except
ImportError
:
logger
.
warn_dependency
(
'SVHNDigit'
,
'scipy.io'
)
__all__
=
[]
__all__
=
[
'SVHNDigit'
]
SVHN_URL
=
"http://ufldl.stanford.edu/housenumbers/"
...
...
@@ -73,6 +68,12 @@ class SVHNDigit(RNGDataFlow):
return
np
.
concatenate
((
a
.
X
,
b
.
X
,
c
.
X
))
.
mean
(
axis
=
0
)
try
:
import
scipy.io
except
ImportError
:
from
...utils.dependency
import
create_dummy_class
SVHNDigit
=
create_dummy_class
(
'SVHNDigit'
,
'scipy.io'
)
# noqa
if
__name__
==
'__main__'
:
a
=
SVHNDigit
(
'train'
)
b
=
SVHNDigit
.
get_per_pixel_mean
()
tensorpack/dataflow/dftools.py
View file @
88af1f1d
...
...
@@ -15,13 +15,8 @@ from ..utils.concurrency import DIE
from
..utils.serialize
import
dumps
from
..utils.fs
import
mkdir_p
__all__
=
[
'dump_dataset_images'
,
'dataflow_to_process_queue'
]
try
:
import
lmdb
except
ImportError
:
logger
.
warn_dependency
(
"dump_dataflow_to_lmdb"
,
'lmdb'
)
else
:
__all__
.
extend
([
'dump_dataflow_to_lmdb'
])
__all__
=
[
'dump_dataset_images'
,
'dataflow_to_process_queue'
,
'dump_dataflow_to_lmdb'
]
def
dump_dataset_images
(
ds
,
dirname
,
max_count
=
None
,
index
=
0
):
...
...
@@ -84,6 +79,13 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
db
.
sync
()
try
:
import
lmdb
except
ImportError
:
from
..utils.dependency
import
create_dummy_func
dump_dataflow_to_lmdb
=
create_dummy_func
(
'dump_dataflow_to_lmdb'
,
'lmdb'
)
# noqa
def
dataflow_to_process_queue
(
ds
,
size
,
nr_consumer
):
"""
Convert a DataFlow to a :class:`multiprocessing.Queue`.
...
...
tensorpack/dataflow/format.py
View file @
88af1f1d
...
...
@@ -13,28 +13,8 @@ from ..utils.serialize import loads
from
..utils.argtools
import
log_once
from
.base
import
RNGDataFlow
try
:
import
h5py
except
ImportError
:
logger
.
warn_dependency
(
"HDF5Data"
,
'h5py'
)
__all__
=
[]
else
:
__all__
=
[
'HDF5Data'
]
try
:
import
lmdb
except
ImportError
:
logger
.
warn_dependency
(
"LMDBData"
,
'lmdb'
)
else
:
__all__
.
extend
([
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB'
])
try
:
import
sklearn.datasets
except
ImportError
:
logger
.
warn_dependency
(
'SVMLightData'
,
'sklearn'
)
else
:
__all__
.
extend
([
'SVMLightData'
])
__all__
=
[
'HDF5Data'
,
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB'
,
'SVMLightData'
]
"""
Adapters for different data format.
...
...
@@ -214,3 +194,21 @@ class SVMLightData(RNGDataFlow):
self
.
rng
.
shuffle
(
idxs
)
for
id
in
idxs
:
yield
[
self
.
X
[
id
,
:],
self
.
y
[
id
]]
from
..utils.dependency
import
create_dummy_class
# noqa
try
:
import
h5py
except
ImportError
:
HDF5Data
=
create_dummy_class
(
'HDF5Data'
,
'h5py'
)
# noqa
try
:
import
lmdb
except
ImportError
:
for
klass
in
[
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB'
]:
globals
()[
klass
]
=
create_dummy_class
(
klass
,
'lmdb'
)
try
:
import
sklearn.datasets
except
ImportError
:
SVMLightData
=
create_dummy_class
(
'SVMLightData'
,
'sklearn'
)
# noqa
tensorpack/dataflow/prefetch.py
View file @
88af1f1d
...
...
@@ -8,6 +8,7 @@ import itertools
from
six.moves
import
range
,
zip
import
uuid
import
os
import
zmq
from
.base
import
ProxyDataFlow
from
..utils.concurrency
import
(
ensure_proc_terminate
,
...
...
@@ -16,13 +17,7 @@ from ..utils.serialize import loads, dumps
from
..utils
import
logger
from
..utils.gpu
import
change_gpu
__all__
=
[
'PrefetchData'
,
'BlockParallel'
]
try
:
import
zmq
except
ImportError
:
logger
.
warn_dependency
(
'PrefetchDataZMQ'
,
'zmq'
)
else
:
__all__
.
extend
([
'PrefetchDataZMQ'
,
'PrefetchOnGPUs'
])
__all__
=
[
'PrefetchData'
,
'BlockParallel'
,
'PrefetchDataZMQ'
,
'PrefetchOnGPUs'
]
class
PrefetchProcess
(
mp
.
Process
):
...
...
tensorpack/dataflow/tf_func.py
View file @
88af1f1d
...
...
@@ -3,22 +3,14 @@
# File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
from
.base
import
ProxyDataFlow
from
..utils
import
logger
try
:
import
tensorflow
as
tf
except
ImportError
:
logger
.
warn_dependency
(
'TFFuncMapper'
,
'tensorflow'
)
__all__
=
[]
else
:
__all__
=
[]
""" This file was deprecated """
class
TFFuncMapper
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
get_placeholders
,
symbf
,
apply_symbf_on_dp
,
device
=
'/cpu:0'
):
"""
...
...
tensorpack/predict/concurrency.py
View file @
88af1f1d
...
...
@@ -9,20 +9,9 @@ from six.moves import queue, range
from
..utils.concurrency
import
DIE
,
StoppableThread
from
..tfutils.modelutils
import
describe_model
from
..utils
import
logger
from
.base
import
OfflinePredictor
,
AsyncPredictorBase
try
:
if
six
.
PY2
:
from
tornado.concurrent
import
Future
else
:
from
concurrent.futures
import
Future
except
ImportError
:
logger
.
warn_dependency
(
'MultiThreadAsyncPredictor'
,
'tornado.concurrent'
)
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
]
else
:
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
,
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
,
'MultiThreadAsyncPredictor'
]
...
...
@@ -171,3 +160,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
f
.
add_done_callback
(
callback
)
self
.
input_queue
.
put
((
dp
,
f
))
return
f
try
:
if
six
.
PY2
:
from
tornado.concurrent
import
Future
else
:
from
concurrent.futures
import
Future
except
ImportError
:
from
..utils.dependency
import
create_dummy_class
MultiThreadAsyncPredictor
=
create_dummy_class
(
'MultiThreadAsyncPredictor'
,
'tornado.concurrent'
)
# noqa
tensorpack/utils/dependency.py
0 → 100644
View file @
88af1f1d
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dependency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Utilities to handle dependency """
__all__
=
[
'create_dummy_func'
,
'create_dummy_class'
]
def create_dummy_class(klass, dependency):
    """
    Build a placeholder class standing in for *klass* when *dependency*
    could not be imported; any attempt to instantiate it raises ImportError.

    Args:
        klass (str): name of the class the placeholder stands in for.
        dependency (str): name of the missing dependency.

    Returns:
        class: a class object that raises ImportError on construction.
    """
    # Compose the error text once; every instantiation attempt reuses it.
    message = "Cannot import '{}', therefore '{}' is not available".format(dependency, klass)

    class _Dummy(object):
        # Accept and discard any constructor arguments before failing, so the
        # dummy matches any real constructor signature.
        def __init__(self, *args, **kwargs):
            raise ImportError(message)

    return _Dummy
def create_dummy_func(func, dependency):
    """
    Build a placeholder function standing in for *func* when *dependency*
    could not be imported; any attempt to call it raises ImportError.

    Args:
        func (str): name of the function the placeholder stands in for.
        dependency (str): name of the missing dependency.

    Returns:
        function: a function object that raises ImportError when invoked.
    """
    # Compose the error text once; every call reuses it.
    message = "Cannot import '{}', therefore '{}' is not available".format(dependency, func)

    def _dummy(*args, **kwargs):
        # Accept and discard any arguments before failing, so the dummy
        # matches any real call signature.
        raise ImportError(message)

    return _dummy
tensorpack/utils/logger.py
View file @
88af1f1d
...
...
@@ -11,8 +11,7 @@ from datetime import datetime
from
six.moves
import
input
import
sys
__all__
=
[
'set_logger_dir'
,
'disable_logger'
,
'auto_set_dir'
,
'warn_dependency'
]
__all__
=
[
'set_logger_dir'
,
'disable_logger'
,
'auto_set_dir'
]
class
_MyFormatter
(
logging
.
Formatter
):
...
...
@@ -128,8 +127,3 @@ def auto_set_dir(action=None, overwrite=False):
os
.
path
.
join
(
'train_log'
,
basename
[:
basename
.
rfind
(
'.'
)]),
action
=
action
)
def warn_dependency(name, dependencies):
    """
    Print a warning that *name* won't be available because *dependencies*
    failed to import.

    Args:
        name (str): the feature (class/function) that won't be available.
        dependencies (str): the import that failed.
    """
    # `warn` is the module-level logging shortcut defined elsewhere in logger.py.
    # Fixed: the original format string ended with a stray apostrophe
    # ("... won't be available'").
    warn("Failed to import '{}', {} won't be available".format(dependencies, name))  # noqa: F821
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment