Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
879995d9
Commit
879995d9
authored
Jan 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix rnn sample(). fix gym breaking changes.
parent
14b3578a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
39 deletions
+46
-39
docs/user/models.md
docs/user/models.md
+4
-4
examples/Char-RNN/char-rnn.py
examples/Char-RNN/char-rnn.py
+40
-34
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+2
-1
No files found.
docs/user/models.md
View file @
879995d9
...
...
@@ -6,18 +6,18 @@ you'll need to subclass `ModelDesc` and implement several methods:
```
python
class
MyModel
(
ModelDesc
):
def
_get_input
_var
s
(
self
):
def
_get_inputs
(
self
):
return
[
InputVar
(
...
),
InputVar
(
...
)]
def
_build_graph
(
self
,
input
_tensor
s
):
def
_build_graph
(
self
,
inputs
):
# build the graph
```
Basically,
`_get_input
_var
s`
should define the metainfo of the input
Basically,
`_get_inputs`
should define the metainfo of the input
of the model. It should match what is produced by the data you're training with.
`_build_graph`
should add tensors/operations to the graph, where
the argument
`input_tensors`
is the list of input tensors matching the return value of
`_get_input
_var
s`
.
`_get_inputs`
.
You can use any symbolic functions in
`_build_graph`
, including TensorFlow core library
functions, TensorFlow slim layers, or functions in other packages such as tflean, tensorlayer.
...
...
examples/Char-RNN/char-rnn.py
View file @
879995d9
...
...
@@ -17,6 +17,7 @@ from tensorpack import *
from
tensorpack.tfutils.gradproc
import
*
from
tensorpack.utils.lut
import
LookUpTable
from
tensorpack.utils.globvars
import
globalns
as
param
rnn
=
tf
.
contrib
.
rnn
# some model hyperparams to set
param
.
batch_size
=
128
...
...
@@ -67,10 +68,18 @@ class Model(ModelDesc):
def
_build_graph
(
self
,
inputs
):
input
,
nextinput
=
inputs
cell
=
tf
.
contrib
.
rnn
.
BasicLSTMCell
(
num_units
=
param
.
rnn_size
)
cell
=
tf
.
contrib
.
rnn
.
MultiRNNCell
([
cell
]
*
param
.
num_rnn_layer
)
cell
=
rnn
.
BasicLSTMCell
(
num_units
=
param
.
rnn_size
)
cell
=
rnn
.
MultiRNNCell
([
cell
]
*
param
.
num_rnn_layer
)
self
.
initial
=
initial
=
cell
.
zero_state
(
tf
.
shape
(
input
)[
0
],
tf
.
float32
)
def
get_v
(
n
):
ret
=
tf
.
get_variable
(
n
+
'_unused'
,
[
param
.
batch_size
,
param
.
rnn_size
],
trainable
=
False
,
initializer
=
tf
.
constant_initializer
())
ret
=
symbolic_functions
.
shapeless_placeholder
(
ret
,
0
,
name
=
n
)
return
ret
self
.
initial
=
initial
=
\
(
rnn
.
LSTMStateTuple
(
get_v
(
'c0'
),
get_v
(
'h0'
)),
rnn
.
LSTMStateTuple
(
get_v
(
'c1'
),
get_v
(
'h1'
)))
embeddingW
=
tf
.
get_variable
(
'embedding'
,
[
param
.
vocab_size
,
param
.
rnn_size
])
input_feature
=
tf
.
nn
.
embedding_lookup
(
embeddingW
,
input
)
# B x seqlen x rnnsize
...
...
@@ -78,13 +87,13 @@ class Model(ModelDesc):
input_list
=
tf
.
unstack
(
input_feature
,
axis
=
1
)
# seqlen x (Bxrnnsize)
# seqlen is 1 in inference. don't need loop_function
outputs
,
last_state
=
tf
.
contrib
.
rnn
.
static_rnn
(
cell
,
input_list
,
initial
,
scope
=
'rnnlm'
)
outputs
,
last_state
=
rnn
.
static_rnn
(
cell
,
input_list
,
initial
,
scope
=
'rnnlm'
)
self
.
last_state
=
tf
.
identity
(
last_state
,
'last_state'
)
# seqlen x (Bxrnnsize)
output
=
tf
.
reshape
(
tf
.
concat
(
outputs
,
1
),
[
-
1
,
param
.
rnn_size
])
# (Bxseqlen) x rnnsize
logits
=
FullyConnected
(
'fc'
,
output
,
param
.
vocab_size
,
nl
=
tf
.
identity
)
self
.
prob
=
tf
.
nn
.
softmax
(
logits
/
param
.
softmax_temprature
)
self
.
prob
=
tf
.
nn
.
softmax
(
logits
/
param
.
softmax_temprature
,
name
=
'prob'
)
xent_loss
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
symbolic_functions
.
flatten
(
nextinput
))
...
...
@@ -130,19 +139,17 @@ def sample(path, start, length):
param
.
seq_len
=
1
ds
=
CharRNNData
(
param
.
corpus
,
100000
)
model
=
Model
()
inputs
=
model
.
get_reuse_placehdrs
()
model
.
build_graph
(
inputs
,
False
)
sess
=
tf
.
Session
()
tfutils
.
SaverRestore
(
path
)
.
init
(
sess
)
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
session_init
=
SaverRestore
(
path
),
input_names
=
[
'input'
,
'c0'
,
'h0'
,
'c1'
,
'h1'
],
output_names
=
[
'prob'
,
'last_state'
])
)
dummy_input
=
np
.
zeros
((
1
,
1
),
dtype
=
'int32'
)
with
sess
.
as_default
():
# feed the starting sentence
state
=
model
.
initial
.
eval
({
inputs
[
0
]:
dummy_input
}
)
initial
=
np
.
zeros
((
1
,
param
.
rnn_size
)
)
for
c
in
start
[:
-
1
]:
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)]],
dtype
=
'int32'
)
state
=
model
.
last_state
.
eval
({
inputs
[
0
]:
x
,
model
.
initial
:
state
}
)
_
,
state
=
pred
(
x
,
initial
,
initial
,
initial
,
initial
)
def
pick
(
prob
):
t
=
np
.
cumsum
(
prob
)
...
...
@@ -154,8 +161,7 @@ def sample(path, start, length):
c
=
start
[
-
1
]
for
k
in
range
(
length
):
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)]],
dtype
=
'int32'
)
[
prob
,
state
]
=
sess
.
run
([
model
.
prob
,
model
.
last_state
],
{
inputs
[
0
]:
x
,
model
.
initial
:
state
})
prob
,
state
=
pred
(
x
,
state
[
0
,
0
],
state
[
0
,
1
],
state
[
1
,
0
],
state
[
1
,
1
])
c
=
ds
.
lut
.
get_obj
(
pick
(
prob
[
0
]))
ret
+=
c
print
(
ret
)
...
...
tensorpack/RL/gymenv.py
View file @
879995d9
...
...
@@ -34,7 +34,7 @@ class GymEnv(RLEnvironment):
self
.
gymenv
=
gym
.
make
(
name
)
if
dumpdir
:
mkdir_p
(
dumpdir
)
self
.
gymenv
.
monitor
.
start
(
dumpdir
)
self
.
gymenv
=
gym
.
wrappers
.
Monitor
(
self
.
gymenv
,
dumpdir
)
self
.
use_dir
=
dumpdir
self
.
reset_stat
()
...
...
@@ -75,6 +75,7 @@ class GymEnv(RLEnvironment):
try
:
import
gym
import
gym.wrappers
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment