142 lines
4.0 KiB
Python
142 lines
4.0 KiB
Python
import os
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import cv2
|
|
from glob import glob
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torchvision.transforms.functional as FF
|
|
from torch.utils.data import Dataset, DataLoader
|
|
import random
|
|
|
|
def load_img(filename, debug=False, norm=True, resize=None):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        filename: Path of the image file to read.
        debug: When true, print the array's shape, dtype, min and max.
        norm: When true, divide the pixel values by 255 and cast to float32.
            When false, the array keeps the dtype produced by ``cv2.imread``.
        resize: Optional two-element sequence. ``(resize[0], resize[1])`` is
            handed to ``cv2.resize`` as the target size (OpenCV reads it as
            width, height) using ``cv2.INTER_AREA`` interpolation.

    Returns:
        The loaded (optionally normalized and resized) image array.

    Raises:
        OSError: If OpenCV could not read ``filename``.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque OpenCV
        # assertion inside cvtColor.
        raise OSError(
            f"cv2.imread could not read {filename!r} "
            "(missing file, unreadable file or unsupported format)"
        )

    # OpenCV decodes to BGR; the rest of this module works in RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    if norm:
        img = img / 255.
        img = img.astype(np.float32)

    if debug:
        print(img.shape, img.dtype, img.min(), img.max())

    if resize:
        img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA)

    return img
|
|
|
|
def save_rgb(img, filename):
    """Write an RGB image to ``filename`` through OpenCV.

    If the largest value in ``img`` is at most 1 the image is multiplied by
    255 first. The data is then cast to float32, converted from RGB to BGR
    channel order and passed to ``cv2.imwrite``.

    Args:
        img: Image array in RGB channel order.
        filename: Destination path handed to ``cv2.imwrite``.
    """
    looks_normalized = np.max(img) <= 1
    if looks_normalized:
        img = img * 255

    # OpenCV expects BGR channel order when writing.
    bgr = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, bgr)
|
|
|
|
def load_raw(raw, max_val=2**12-1):
    """Load a ``.npy`` array and scale it by ``max_val``.

    Args:
        raw: File path (or file object) accepted by ``np.load``.
        max_val: Divisor applied to every element; defaults to 4095.

    Returns:
        The loaded array divided by ``max_val``, as float32.
    """
    scaled = np.load(raw) / max_val
    return scaled.astype(np.float32)
|
|
|
|
def demosaic(raw):
    """Build a quick RGB preview from a 4-plane RAW array.

    Planes 0..3 along the last axis are read as red, green (red row),
    green (blue row) and blue. The two green planes are averaged, the three
    resulting planes are stacked channels-last, and the result is upscaled
    to twice the input height and width with ``cv2.resize`` (default
    interpolation).

    Args:
        raw: Array whose last axis has exactly four entries.

    Returns:
        The upscaled three-channel image.
    """
    assert raw.shape[-1] == 4
    rows, cols = raw.shape[0], raw.shape[1]

    r_plane, g_on_red, g_on_blue, b_plane = (raw[:, :, k] for k in range(4))
    g_plane = (g_on_red + g_on_blue) / 2

    half_res = np.stack((r_plane, g_plane, b_plane), axis=-1)
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(half_res, (cols * 2, rows * 2))
|
|
|
|
def mosaic(rgb):
    """Extract the four RGGB Bayer planes from an RGB image.

    Each output plane samples every second row and column of one colour
    channel:

        plane 0: even rows, even columns of channel 0 (red)
        plane 1: even rows, odd columns of channel 1 (green)
        plane 2: odd rows, even columns of channel 1 (green)
        plane 3: odd rows, odd columns of channel 2 (blue)

    Args:
        rgb: Array whose last axis has exactly three entries.

    Returns:
        The four planes stacked along a new last axis.
    """
    assert rgb.shape[-1] == 3

    # (row offset, column offset, source channel) for R, G_r, G_b, B.
    bayer_layout = ((0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 2))
    planes = tuple(
        rgb[row_off::2, col_off::2, channel]
        for row_off, col_off, channel in bayer_layout
    )
    return np.stack(planes, axis=-1)
|
|
|
|
def gamma_compression(image):
    """Convert from linear to gamma space.

    Values are floored at 1e-8 and then raised to the power 1 / 2.2.
    """
    floored = np.maximum(image, 1e-8)
    exponent = 1.0 / 2.2
    return np.power(floored, exponent)
|
|
|
|
def tonemap(image):
    """Apply a simple S-shaped global tone curve: 3*x**2 - 2*x**3."""
    squared = image ** 2
    cubed = image ** 3
    return (3 * squared) - (2 * cubed)
|
|
|
|
def postprocess_raw(raw):
    """Post-process a demosaiced RAW image for visualization.

    Applies ``gamma_compression``, then ``tonemap``, and finally clips the
    result to the range [0, 1].
    """
    toned = tonemap(gamma_compression(raw))
    return np.clip(toned, 0, 1)
|
|
|
|
def plot_pair(rgb, raw, t1='RGB', t2='RAW', axis='off'):
    """Show two images side by side in one matplotlib figure.

    Args:
        rgb: Image drawn in the left panel.
        raw: Image drawn in the right panel.
        t1: Title of the left panel.
        t2: Title of the right panel.
        axis: Value forwarded to ``plt.axis`` for both panels.
    """
    plt.figure(figsize=(12, 6), dpi=80)

    panels = ((rgb, t1), (raw, t2))
    for slot, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.title(caption)
        plt.axis(axis)
        plt.imshow(picture)

    plt.show()
|
|
|
|
def PSNR(y_true, y_pred):
    """Compute the peak signal-to-noise ratio between two arrays, in dB.

    The peak value is taken as the maximum of ``y_true``.

    Args:
        y_true: Reference array (or array-like).
        y_pred: Array (or array-like) compared against the reference.

    Returns:
        ``20 * log10(max(y_true) / sqrt(mse))``, or ``np.inf`` when the two
        inputs are identical (mean squared error of zero).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Integer and bool arrays (e.g. uint8 images) wrap around on subtraction
    # and squaring, which silently corrupts the error; promote them first.
    # Floating inputs are left in their own dtype.
    if y_true.dtype.kind in "biu":
        y_true = y_true.astype(np.float64)
    if y_pred.dtype.kind in "biu":
        y_pred = y_pred.astype(np.float64)

    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return np.inf

    max_pixel = np.max(y_true)
    return 20 * np.log10(max_pixel / np.sqrt(mse))
|
|
|
|
def get_filenames(path):
    """Collect training and validation file lists found under ``path``.

    For each of the device folders ``samsung-s9``, ``iphone-x`` and
    ``lq-iphone`` below ``{path}/train``, the ``*.npy`` and ``*.png`` files
    are globbed and sorted per device, then concatenated in device order.
    Validation files are everything matching ``{path}/rgbs/*``, sorted.

    Prints the number of training and validation samples.

    Returns:
        A tuple ``(train_raws, train_rgbs, valid_rgbs)`` of path lists.

    Raises:
        AssertionError: If the number of ``.npy`` and ``.png`` training
            files differs.
    """
    train_raws = []
    train_rgbs = []
    for device in ('samsung-s9', 'iphone-x', 'lq-iphone'):
        train_raws += sorted(glob(f'{path}/train/{device}/*.npy'))
        train_rgbs += sorted(glob(f'{path}/train/{device}/*.png'))

    valid_rgbs = sorted(glob(f'{path}/rgbs/*'))

    assert len(train_raws) == len(train_rgbs)
    print(f'Training samples: {len(train_raws)} \t Validation samples: {len(valid_rgbs)}')
    return train_raws, train_rgbs, valid_rgbs
|
|
|
|
class LoadData(Dataset):
    """Dataset of RGB images, paired with RAW arrays unless in test mode.

    Both file lists are sorted independently and paired by index, so the
    pairing presumably relies on matching RGB and RAW files sorting into the
    same order (verify against the data layout).

    Args:
        root: Stored on the instance as ``self.root``; not read by any method
            of this class.
        rgb_files: Iterable of RGB image paths.
        raw_files: Iterable of RAW ``.npy`` paths. Required unless ``test``
            is truthy.
        debug: When true, keep only the first 100 files of each list.
        test: When truthy, no RAW files are used and ``__getitem__`` returns
            the image together with its path.

    Raises:
        TypeError: If ``raw_files`` is ``None`` while ``test`` is falsy.
    """

    def __init__(self, root, rgb_files, raw_files=None, debug=False, test=None):
        self.root = root
        self.test = test
        self.rgbs = sorted(rgb_files)
        if self.test:
            self.raws = None
        else:
            if raw_files is None:
                # sorted(None) would raise a TypeError anyway; keep the type
                # but say what is actually wrong.
                raise TypeError("raw_files is required when test is not set")
            self.raws = sorted(raw_files)

        self.debug = debug
        if self.debug:
            self.rgbs = self.rgbs[:100]
            # In test mode there are no RAW files to truncate; slicing None
            # used to crash when debug and test were both enabled.
            if self.raws is not None:
                self.raws = self.raws[:100]

    def __len__(self):
        """Return the number of RGB files."""
        return len(self.rgbs)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(rgb, path)`` in test mode, otherwise ``(rgb, raw)``. Both
            arrays are converted to tensors after moving their last axis to
            the front (channels-last to channels-first).
        """
        rgb = load_img(self.rgbs[idx], norm=True)
        rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))

        if self.test:
            return rgb, self.rgbs[idx]

        raw = load_raw(self.raws[idx])
        raw = torch.from_numpy(raw.transpose((2, 0, 1)))
        return rgb, raw