Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
01cab873
Commit
01cab873
authored
Jul 03, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bugfix in char-rnn example (fix #323)
parent
8faea40d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
5 deletions
+5
-5
examples/Char-RNN/char-rnn.py
examples/Char-RNN/char-rnn.py
+5
-5
No files found.
examples/Char-RNN/char-rnn.py
View file @
01cab873
...
@@ -49,8 +49,8 @@ class CharRNNData(RNGDataFlow):
...
@@ -49,8 +49,8 @@ class CharRNNData(RNGDataFlow):
print
(
sorted
(
self
.
chars
))
print
(
sorted
(
self
.
chars
))
self
.
vocab_size
=
len
(
self
.
chars
)
self
.
vocab_size
=
len
(
self
.
chars
)
param
.
vocab_size
=
self
.
vocab_size
param
.
vocab_size
=
self
.
vocab_size
char2idx
=
{
c
:
i
for
i
,
c
in
enumerate
(
self
.
chars
)}
self
.
char2idx
=
{
c
:
i
for
i
,
c
in
enumerate
(
self
.
chars
)}
self
.
whole_seq
=
np
.
array
([
char2idx
[
c
]
for
c
in
data
],
dtype
=
'int32'
)
self
.
whole_seq
=
np
.
array
([
self
.
char2idx
[
c
]
for
c
in
data
],
dtype
=
'int32'
)
logger
.
info
(
"Corpus loaded. Vocab size: {}"
.
format
(
self
.
vocab_size
))
logger
.
info
(
"Corpus loaded. Vocab size: {}"
.
format
(
self
.
vocab_size
))
def
size
(
self
):
def
size
(
self
):
...
@@ -146,7 +146,7 @@ def sample(path, start, length):
...
@@ -146,7 +146,7 @@ def sample(path, start, length):
# feed the starting sentence
# feed the starting sentence
initial
=
np
.
zeros
((
1
,
param
.
rnn_size
))
initial
=
np
.
zeros
((
1
,
param
.
rnn_size
))
for
c
in
start
[:
-
1
]:
for
c
in
start
[:
-
1
]:
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)
]],
dtype
=
'int32'
)
x
=
np
.
array
([[
ds
.
char2idx
[
c
]
]],
dtype
=
'int32'
)
_
,
state
=
pred
(
x
,
initial
,
initial
,
initial
,
initial
)
_
,
state
=
pred
(
x
,
initial
,
initial
,
initial
,
initial
)
def
pick
(
prob
):
def
pick
(
prob
):
...
@@ -158,9 +158,9 @@ def sample(path, start, length):
...
@@ -158,9 +158,9 @@ def sample(path, start, length):
ret
=
start
ret
=
start
c
=
start
[
-
1
]
c
=
start
[
-
1
]
for
k
in
range
(
length
):
for
k
in
range
(
length
):
x
=
np
.
array
([[
ds
.
lut
.
get_idx
(
c
)
]],
dtype
=
'int32'
)
x
=
np
.
array
([[
ds
.
char2idx
[
c
]
]],
dtype
=
'int32'
)
prob
,
state
=
pred
(
x
,
state
[
0
,
0
],
state
[
0
,
1
],
state
[
1
,
0
],
state
[
1
,
1
])
prob
,
state
=
pred
(
x
,
state
[
0
,
0
],
state
[
0
,
1
],
state
[
1
,
0
],
state
[
1
,
1
])
c
=
ds
.
lut
.
get_obj
(
pick
(
prob
[
0
]))
c
=
ds
.
chars
[
pick
(
prob
[
0
])]
ret
+=
c
ret
+=
c
print
(
ret
)
print
(
ret
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment