Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
bf9da6d5
Commit
bf9da6d5
authored
Nov 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix build
parent
ab2cd7e6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
10 deletions
+14
-10
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+1
-1
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+13
-9
No files found.
examples/GAN/InfoGAN-mnist.py
View file @
bf9da6d5
...
...
@@ -247,7 +247,7 @@ if __name__ == '__main__':
logger
.
auto_set_dir
()
GANTrainer
(
QueueInput
(
get_data
()),
Model
())
.
train_with_defaults
(
callbacks
=
[
ModelSaver
(
keep_
freq
=
0.1
)],
callbacks
=
[
ModelSaver
(
keep_
checkpoint_every_n_hours
=
0.1
)],
steps_per_epoch
=
500
,
max_epoch
=
100
,
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
load
else
None
...
...
tensorpack/dataflow/raw.py
View file @
bf9da6d5
...
...
@@ -8,6 +8,7 @@ import copy
import
six
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
from
..utils.develop
import
log_deprecated
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
,
'DataFromGenerator'
]
...
...
@@ -97,18 +98,21 @@ class DataFromList(RNGDataFlow):
class
DataFromGenerator
(
DataFlow
):
"""
Wrap a generator to a DataFlow
Wrap a generator to a DataFlow
.
"""
def
__init__
(
self
,
gen
,
size
=
None
):
self
.
_gen
=
gen
self
.
_size
=
size
def
size
(
self
):
if
self
.
_size
:
return
self
.
_size
return
super
(
DataFromGenerator
,
self
)
.
size
()
"""
Args:
gen: iterable, or a callable that returns an iterable
"""
if
not
callable
(
gen
):
self
.
_gen
=
lambda
:
gen
else
:
self
.
_gen
=
gen
if
size
is
not
None
:
log_deprecated
(
"DataFromGenerator(size=)"
,
"It doesn't make much sense."
)
def
get_data
(
self
):
# yield from
for
dp
in
self
.
_gen
:
for
dp
in
self
.
_gen
()
:
yield
dp
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment