Commit 70802354 authored by Yuxin Wu's avatar Yuxin Wu

misc. update about docs

parent 5140610d
......@@ -62,7 +62,7 @@ def nms_fastrcnn_results(boxes, probs):
continue
probs_k = probs[ids, klass].flatten()
boxes_k = boxes[ids, :]
selected_ids = nms_func(boxes_k[:, [1, 0, 3, 2]], probs_k)
selected_ids = nms_func(boxes_k, probs_k)
selected_boxes = boxes_k[selected_ids, :].copy()
ret.append(DetectionResult(klass, selected_boxes, probs_k[selected_ids]))
......
......@@ -71,8 +71,9 @@ class GANTrainer(TowerTrainer):
inputs_desc = model.get_inputs_desc()
cbs = input.setup(inputs_desc)
tower_func = TowerFuncWrapper(
model.build_graph, inputs_desc)
# we need to set towerfunc because it's a TowerTrainer,
# and only TowerTrainer supports automatic graph creation for inference during training.
tower_func = TowerFuncWrapper(model.build_graph, inputs_desc)
with TowerContext('', is_training=True):
tower_func(*input.get_input_tensors())
opt = model.get_optimizer()
......@@ -138,6 +139,7 @@ class MultiGPUGANTrainer(TowerTrainer):
input = StagingInput(input, list(range(nr_gpu)))
cbs = input.setup(model.get_inputs_desc())
# build the graph
def get_cost(*inputs):
model.build_graph(inputs)
return [model.d_loss, model.g_loss]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment