Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
aed3438b
Commit
aed3438b
authored
May 01, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use rnn.rnn
parent
a3d6c93d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
5 deletions
+7
-5
examples/char-rnn.py
examples/char-rnn.py
+7
-5
No files found.
examples/char-rnn.py
View file @
aed3438b
...
...
@@ -19,8 +19,8 @@ from tensorpack.tfutils.gradproc import *
from
tensorpack.utils.lut
import
LookUpTable
from
tensorpack.callbacks
import
*
from
tensorflow.
models.rnn
import
rnn_cell
from
tensorflow.
models.rnn
import
seq2seq
from
tensorflow.
python.ops
import
rnn_cell
from
tensorflow.
python.ops
import
rnn
if
six
.
PY2
:
class
NS
:
pass
# this is a hack
...
...
@@ -30,7 +30,7 @@ else:
param
=
NS
()
# some model hyperparams to set
param
.
batch_size
=
128
param
.
rnn_size
=
128
param
.
rnn_size
=
256
param
.
num_rnn_layer
=
2
param
.
seq_len
=
50
param
.
grad_clip
=
5.
...
...
@@ -90,7 +90,7 @@ class Model(ModelDesc):
input_list
=
[
tf
.
squeeze
(
x
,
[
1
])
for
x
in
input_list
]
# seqlen is 1 in inference. don't need loop_function
outputs
,
last_state
=
seq2seq
.
rnn_decoder
(
input_list
,
initial
,
cel
l
,
scope
=
'rnnlm'
)
outputs
,
last_state
=
rnn
.
rnn
(
cell
,
input_list
,
initia
l
,
scope
=
'rnnlm'
)
self
.
last_state
=
tf
.
identity
(
last_state
,
'last_state'
)
# seqlen x (Bxrnnsize)
output
=
tf
.
reshape
(
tf
.
concat
(
1
,
outputs
),
[
-
1
,
param
.
rnn_size
])
# (seqlenxB) x rnnsize
...
...
@@ -125,7 +125,8 @@ def get_config():
callbacks
=
Callbacks
([
StatPrinter
(),
ModelSaver
(),
HumanHyperParamSetter
(
'learning_rate'
,
'hyper.txt'
)
#HumanHyperParamSetter('learning_rate', 'hyper.txt')
SeduledHyperParamSetter
(
'learning_rate'
,
[(
25
,
2e-4
)])
]),
model
=
Model
(),
step_per_epoch
=
step_per_epoch
,
...
...
@@ -184,6 +185,7 @@ if __name__ == '__main__':
default
=
'The '
,
help
=
'initial text sequence'
)
parser_sample
.
add_argument
(
'-t'
,
'--temperature'
,
type
=
float
,
default
=
1
,
help
=
'softmax temperature'
)
parser_train
=
subparsers
.
add_parser
(
'train'
,
help
=
'train'
)
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment