Commit b8c7b6f4 authored by Yuxin Wu's avatar Yuxin Wu

fix wrong use of len(data) in examples

parent 193f6056
...@@ -223,6 +223,6 @@ if __name__ == '__main__': ...@@ -223,6 +223,6 @@ if __name__ == '__main__':
PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3), PeriodicTrigger(VisualizeTestSet(), every_k_epochs=3),
], ],
max_epoch=195, max_epoch=195,
steps_per_epoch=len(data), steps_per_epoch=data.size(),
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
...@@ -223,7 +223,7 @@ if __name__ == '__main__': ...@@ -223,7 +223,7 @@ if __name__ == '__main__':
PeriodicTrigger(ModelSaver(), every_k_epochs=3), PeriodicTrigger(ModelSaver(), every_k_epochs=3),
ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)]) ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
], ],
steps_per_epoch=len(data), steps_per_epoch=data.size(),
max_epoch=300, max_epoch=300,
session_init=SaverRestore(args.load) if args.load else None session_init=SaverRestore(args.load) if args.load else None
) )
...@@ -298,6 +298,6 @@ if __name__ == '__main__': ...@@ -298,6 +298,6 @@ if __name__ == '__main__':
ModelSaver(keep_checkpoint_every_n_hours=2) ModelSaver(keep_checkpoint_every_n_hours=2)
], ],
session_init=session_init, session_init=session_init,
steps_per_epoch=len(data) // 4, steps_per_epoch=data.size() // 4,
max_epoch=300 max_epoch=300
) )
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment