Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
879995d9
Commit
879995d9
authored
Jan 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix rnn sample(). fix gym breaking changes.
parent
14b3578a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
39 deletions
+46
-39
docs/user/models.md
docs/user/models.md
+4
-4
examples/Char-RNN/char-rnn.py
examples/Char-RNN/char-rnn.py
+40
-34
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+2
-1
No files found.
docs/user/models.md
View file @
879995d9
...
@@ -6,18 +6,18 @@ you'll need to subclass `ModelDesc` and implement several methods:
...
@@ -6,18 +6,18 @@ you'll need to subclass `ModelDesc` and implement several methods:
```
python
```
python
class
MyModel
(
ModelDesc
):
class
MyModel
(
ModelDesc
):
def
_get_input
_var
s
(
self
):
def
_get_inputs
(
self
):
return
[
InputVar
(
...
),
InputVar
(
...
)]
return
[
InputVar
(
...
),
InputVar
(
...
)]
def
_build_graph
(
self
,
input
_tensor
s
):
def
_build_graph
(
self
,
inputs
):
# build the graph
# build the graph
```
```
Basically,
`_get_input
_var
s`
should define the metainfo of the input
Basically,
`_get_inputs`
should define the metainfo of the input
of the model. It should match what is produced by the data you're training with.
of the model. It should match what is produced by the data you're training with.
`_build_graph`
should add tensors/operations to the graph, where
`_build_graph`
should add tensors/operations to the graph, where
the argument
`input_tensors`
is the list of input tensors matching the return value of
the argument
`inputs`
is the list of input tensors matching the return value of
`_get_input
_var
s`
.
`_get_inputs`
.
You can use any symbolic functions in
`_build_graph`
, including TensorFlow core library
You can use any symbolic functions in
`_build_graph`
, including TensorFlow core library
functions, TensorFlow slim layers, or functions in other packages such as tflearn, tensorlayer.
functions, TensorFlow slim layers, or functions in other packages such as tflearn, tensorlayer.
...
...
examples/Char-RNN/char-rnn.py
View file @
879995d9
...
@@ -17,6 +17,7 @@ from tensorpack import *
...
@@ -17,6 +17,7 @@ from tensorpack import *
from
tensorpack.tfutils.gradproc
import
*
from
tensorpack.tfutils.gradproc
import
*
from
tensorpack.utils.lut
import
LookUpTable
from
tensorpack.utils.lut
import
LookUpTable
from
tensorpack.utils.globvars
import
globalns
as
param
from
tensorpack.utils.globvars
import
globalns
as
param
rnn
=
tf
.
contrib
.
rnn
# some model hyperparams to set
# some model hyperparams to set
param
.
batch_size
=
128
param
.
batch_size
=
128
...
@@ -67,10 +68,18 @@ class Model(ModelDesc):
...
@@ -67,10 +68,18 @@ class Model(ModelDesc):
def
_build_graph
(
self
,
inputs
):
def
_build_graph
(
self
,
inputs
):
input
,
nextinput
=
inputs
input
,
nextinput
=
inputs
cell
=
tf
.
contrib
.
rnn
.
BasicLSTMCell
(
num_units
=
param
.
rnn_size
)
cell
=
rnn
.
BasicLSTMCell
(
num_units
=
param
.
rnn_size
)
cell
=
tf
.
contrib
.
rnn
.
MultiRNNCell
([
cell
]
*
param
.
num_rnn_layer
)
cell
=
rnn
.
MultiRNNCell
([
cell
]
*
param
.
num_rnn_layer
)
self
.
initial
=
initial
=
cell
.
zero_state
(
tf
.
shape
(
input
)[
0
],
tf
.
float32
)
def
get_v
(
n
):
ret
=
tf
.
get_variable
(
n
+
'_unused'
,
[
param
.
batch_size
,
param
.
rnn_size
],
trainable
=
False
,
initializer
=
tf
.
constant_initializer
())
ret
=
symbolic_functions
.
shapeless_placeholder
(
ret
,
0
,
name
=
n
)
return
ret
self
.
initial
=
initial
=
\
(
rnn
.
LSTMStateTuple
(
get_v
(
'c0'
),
get_v
(
'h0'
)),
rnn
.
LSTMStateTuple
(
get_v
(
'c1'
),
get_v
(
'h1'
)))
embeddingW
=
tf
.
get_variable
(
'embedding'
,
[
param
.
vocab_size
,
param
.
rnn_size
])
embeddingW
=
tf
.
get_variable
(
'embedding'
,
[
param
.
vocab_size
,
param
.
rnn_size
])
input_feature
=
tf
.
nn
.
embedding_lookup
(
embeddingW
,
input
)
# B x seqlen x rnnsize
input_feature
=
tf
.
nn
.
embedding_lookup
(
embeddingW
,
input
)
# B x seqlen x rnnsize
...
@@ -78,13 +87,13 @@ class Model(ModelDesc):
...
@@ -78,13 +87,13 @@ class Model(ModelDesc):
input_list
=
tf
.
unstack
(
input_feature
,
axis
=
1
)
# seqlen x (Bxrnnsize)
input_list
=
tf
.
unstack
(
input_feature
,
axis
=
1
)
# seqlen x (Bxrnnsize)
# seqlen is 1 in inference. don't need loop_function
# seqlen is 1 in inference. don't need loop_function
outputs
,
last_state
=
tf
.
contrib
.
rnn
.
static_rnn
(
cell
,
input_list
,
initial
,
scope
=
'rnnlm'
)
outputs
,
last_state
=
rnn
.
static_rnn
(
cell
,
input_list
,
initial
,
scope
=
'rnnlm'
)
self
.
last_state
=
tf
.
identity
(
last_state
,
'last_state'
)
self
.
last_state
=
tf
.
identity
(
last_state
,
'last_state'
)
# seqlen x (Bxrnnsize)
# seqlen x (Bxrnnsize)
output
=
tf
.
reshape
(
tf
.
concat
(
outputs
,
1
),
[
-
1
,
param
.
rnn_size
])
# (Bxseqlen) x rnnsize
output
=
tf
.
reshape
(
tf
.
concat
(
outputs
,
1
),
[
-
1
,
param
.
rnn_size
])
# (Bxseqlen) x rnnsize
logits
=
FullyConnected
(
'fc'
,
output
,
param
.
vocab_size
,
nl
=
tf
.
identity
)
logits
=
FullyConnected
(
'fc'
,
output
,
param
.
vocab_size
,
nl
=
tf
.
identity
)
self
.
prob
=
tf
.
nn
.
softmax
(
logits
/
param
.
softmax_temprature
)
self
.
prob
=
tf
.
nn
.
softmax
(
logits
/
param
.
softmax_temprature
,
name
=
'prob'
)
xent_loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
xent_loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
symbolic_functions
.
flatten
(
nextinput
))
logits
=
logits
,
labels
=
symbolic_functions
.
flatten
(
nextinput
))
...
@@ -130,35 +139,32 @@ def sample(path, start, length):
...
@@ -130,35 +139,32 @@ def sample(path, start, length):
param
.
seq_len
=
1
param
.
seq_len
=
1
ds
=
CharRNNData
(
param
.
corpus
,
100000
)
ds
=
CharRNNData
(
param
.
corpus
,
100000
)
model
=
Model
()
pred
=
OfflinePredictor
(
PredictConfig
(
inputs
=
model
.
get_reuse_placehdrs
()
model
=
Model
(),
model
.
build_graph
(
inputs
,
False
)
session_init
=
SaverRestore
(
path
),
sess
=
tf
.
Session
()
input_names
=
[
'input'
,
'c0'
,
'h0'
,
'c1'
,
'h1'
],
tfutils
.
SaverRestore
(
path
)
.
init
(
sess
)
output_names
=
[
'prob'
,
'last_state'
]))
dummy_input
=
np
.
zeros
((
1
,
1
),
dtype
=
'int32'
)
# feed the starting sentence
with
sess
.
as_default
():
initial
=
np
.
zeros
((
1
,
param
.
rnn_size
))
# feed the starting sentence
for
c
in
start
[:
-
1
]:
state
=
model
.
initial
.
eval
({
inputs
[
0
]:
dummy_input
})
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)]],
dtype
=
'int32'
)
for
c
in
start
[:
-
1
]:
_
,
state
=
pred
(
x
,
initial
,
initial
,
initial
,
initial
)
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)]],
dtype
=
'int32'
)
state
=
model
.
last_state
.
eval
({
inputs
[
0
]:
x
,
model
.
initial
:
state
})
def
pick
(
prob
):
t
=
np
.
cumsum
(
prob
)
def
pick
(
prob
):
s
=
np
.
sum
(
prob
)
t
=
np
.
cumsum
(
prob
)
return
(
int
(
np
.
searchsorted
(
t
,
np
.
random
.
rand
(
1
)
*
s
)))
s
=
np
.
sum
(
prob
)
return
(
int
(
np
.
searchsorted
(
t
,
np
.
random
.
rand
(
1
)
*
s
)))
# generate more
ret
=
start
# generate more
c
=
start
[
-
1
]
ret
=
start
for
k
in
range
(
length
):
c
=
start
[
-
1
]
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)]],
dtype
=
'int32'
)
for
k
in
range
(
length
):
prob
,
state
=
pred
(
x
,
state
[
0
,
0
],
state
[
0
,
1
],
state
[
1
,
0
],
state
[
1
,
1
])
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)]],
dtype
=
'int32'
)
c
=
ds
.
lut
.
get_obj
(
pick
(
prob
[
0
]))
[
prob
,
state
]
=
sess
.
run
([
model
.
prob
,
model
.
last_state
],
ret
+=
c
{
inputs
[
0
]:
x
,
model
.
initial
:
state
})
print
(
ret
)
c
=
ds
.
lut
.
get_obj
(
pick
(
prob
[
0
]))
ret
+=
c
print
(
ret
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/RL/gymenv.py
View file @
879995d9
...
@@ -34,7 +34,7 @@ class GymEnv(RLEnvironment):
...
@@ -34,7 +34,7 @@ class GymEnv(RLEnvironment):
self
.
gymenv
=
gym
.
make
(
name
)
self
.
gymenv
=
gym
.
make
(
name
)
if
dumpdir
:
if
dumpdir
:
mkdir_p
(
dumpdir
)
mkdir_p
(
dumpdir
)
self
.
gymenv
.
monitor
.
start
(
dumpdir
)
self
.
gymenv
=
gym
.
wrappers
.
Monitor
(
self
.
gymenv
,
dumpdir
)
self
.
use_dir
=
dumpdir
self
.
use_dir
=
dumpdir
self
.
reset_stat
()
self
.
reset_stat
()
...
@@ -75,6 +75,7 @@ class GymEnv(RLEnvironment):
...
@@ -75,6 +75,7 @@ class GymEnv(RLEnvironment):
try
:
try
:
import
gym
import
gym
import
gym.wrappers
# TODO
# TODO
# gym.undo_logger_setup()
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# https://github.com/openai/gym/pull/199
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment