Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e7ede3eb
Commit
e7ede3eb
authored
Jun 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
dataset dir
parent
4a59173c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
26 additions
and
31 deletions
+26
-31
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+5
-1
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+2
-2
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+1
-2
tensorpack/dataflow/dataset/common.py
tensorpack/dataflow/dataset/common.py
+0
-17
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+1
-2
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+1
-2
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+1
-2
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+15
-3
No files found.
tensorpack/RL/atari.py
View file @
e7ede3eb
...
...
@@ -9,7 +9,7 @@ import cv2
from
collections
import
deque
import
six
from
six.moves
import
range
from
..utils
import
get_rng
,
logger
,
memoized
from
..utils
import
get_rng
,
logger
,
memoized
,
get_dataset_dir
from
..utils.stat
import
StatCounter
from
.envbase
import
RLEnvironment
,
DiscreteActionSpace
...
...
@@ -46,6 +46,10 @@ class AtariPlayer(RLEnvironment):
:param live_losts_as_eoe: consider lost of lives as end of episode. useful for training.
"""
super
(
AtariPlayer
,
self
)
.
__init__
()
if
not
os
.
path
.
isfile
(
rom_file
)
and
'/'
not
in
rom_file
:
rom_file
=
os
.
path
.
join
(
get_dataset_dir
(
'atari_rom'
),
rom_file
)
assert
os
.
path
.
isfile
(
rom_file
),
"rom {} not found"
.
format
(
rom_file
)
self
.
ale
=
ALEInterface
()
self
.
rng
=
get_rng
(
self
)
...
...
tensorpack/dataflow/dataset/bsds500.py
View file @
e7ede3eb
...
...
@@ -7,10 +7,10 @@ import os, glob
import
cv2
import
numpy
as
np
from
scipy.io
import
loadmat
from
...utils
import
logger
,
get_rng
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
...utils.fs
import
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'BSDS500'
]
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
e7ede3eb
...
...
@@ -13,10 +13,9 @@ from six.moves import urllib, range
import
copy
import
logging
from
...utils
import
logger
,
get_rng
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
...utils.fs
import
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'Cifar10'
,
'Cifar100'
]
...
...
tensorpack/dataflow/dataset/common.py
deleted
100644 → 0
View file @
4a59173c
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
__all__
=
[
'get_dataset_dir'
]
def
get_dataset_dir
(
name
):
d
=
os
.
environ
[
'TENSORPACK_DATASET'
]:
if
d
:
assert
os
.
path
.
isdir
(
d
)
else
:
d
=
os
.
path
.
dirname
(
__file__
)
return
os
.
path
.
join
(
d
,
name
)
tensorpack/dataflow/dataset/ilsvrc.py
View file @
e7ede3eb
...
...
@@ -7,10 +7,9 @@ import tarfile
import
cv2
import
numpy
as
np
from
...utils
import
logger
,
get_rng
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
...utils.fs
import
mkdir_p
,
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'ILSVRCMeta'
,
'ILSVRC12'
]
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
e7ede3eb
...
...
@@ -9,10 +9,9 @@ import random
import
numpy
from
six.moves
import
urllib
,
range
from
...utils
import
logger
from
...utils
import
logger
,
get_dataset_dir
from
...utils.fs
import
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'Mnist'
]
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
e7ede3eb
...
...
@@ -10,9 +10,8 @@ import scipy
import
scipy.io
from
six.moves
import
range
from
...utils
import
logger
,
get_rng
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'SVHNDigit'
]
...
...
tensorpack/utils/utils.py
View file @
e7ede3eb
...
...
@@ -12,7 +12,10 @@ import numpy as np
from
.
import
logger
__all__
=
[
'change_env'
,
'get_rng'
,
'memoized'
,
'get_nr_gpu'
,
'get_gpus'
]
'get_rng'
,
'memoized'
,
'get_nr_gpu'
,
'get_gpus'
,
'get_dataset_dir'
]
#def expand_dim_if_necessary(var, dp):
# """
...
...
@@ -73,11 +76,20 @@ def get_rng(self):
return
np
.
random
.
RandomState
(
seed
)
def
get_nr_gpu
():
env
=
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
env
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
None
)
assert
env
is
not
None
# TODO
return
len
(
env
.
split
(
','
))
def
get_gpus
():
env
=
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
env
=
os
.
environ
.
get
(
'CUDA_VISIBLE_DEVICES'
,
None
)
assert
env
is
not
None
# TODO
return
map
(
int
,
env
.
strip
()
.
split
(
','
))
def
get_dataset_dir
(
name
):
d
=
os
.
environ
.
get
(
'TENSORPACK_DATASET'
,
None
)
if
d
:
assert
os
.
path
.
isdir
(
d
)
else
:
d
=
os
.
path
.
dirname
(
__file__
)
return
os
.
path
.
join
(
d
,
name
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment