Commit f1a8acf4 authored by Yuxin Wu's avatar Yuxin Wu

make Image2Image support windows (fix #1412)

parent e50423db
...@@ -146,11 +146,12 @@ class Model(GANModelDesc): ...@@ -146,11 +146,12 @@ class Model(GANModelDesc):
return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3) return tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3)
def split_input(img): def split_input(dp):
""" """
img: an RGB image of shape (s, 2s, 3). dp: the datapoint. first component is an RGB image of shape (s, 2s, 3).
:return: [input, output] :return: [input, output]
""" """
img = dp[0]
# split the image into left + right pairs # split the image into left + right pairs
s = img.shape[0] s = img.shape[0]
assert img.shape[1] == 2 * s assert img.shape[1] == 2 * s
...@@ -169,7 +170,7 @@ def get_data(): ...@@ -169,7 +170,7 @@ def get_data():
imgs = glob.glob(os.path.join(datadir, '*.jpg')) imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0])) ds = MapData(ds, split_input)
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)] augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1)) ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
...@@ -186,7 +187,7 @@ def sample(datadir, model_path): ...@@ -186,7 +187,7 @@ def sample(datadir, model_path):
imgs = glob.glob(os.path.join(datadir, '*.jpg')) imgs = glob.glob(os.path.join(datadir, '*.jpg'))
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = MapData(ds, lambda dp: split_input(dp[0])) ds = MapData(ds, split_input)
ds = AugmentImageComponents(ds, [imgaug.Resize(256)], (0, 1)) ds = AugmentImageComponents(ds, [imgaug.Resize(256)], (0, 1))
ds = BatchData(ds, 6) ds = BatchData(ds, 6)
......
# Copyright (c) Tensorpack Contributors. All Rights Reserved
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: batch_norm.py # File: batch_norm.py
......
...@@ -201,7 +201,7 @@ def collect_env_info(): ...@@ -201,7 +201,7 @@ def collect_env_info():
# List devices with NVML # List devices with NVML
data.append( data.append(
("CUDA_VISIBLE_DEVICES", ("CUDA_VISIBLE_DEVICES",
os.environ.get("CUDA_VISIBLE_DEVICES", str(None)))) os.environ.get("CUDA_VISIBLE_DEVICES", "Unspecified")))
try: try:
devs = defaultdict(list) devs = defaultdict(list)
with NVMLContext() as ctx: with NVMLContext() as ctx:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment