# sRGB2RAW/utils.py
# 2025-03-31 16:32:17 +08:00
#
# 142 lines
# 4.0 KiB
# Python

import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from glob import glob
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as FF
from torch.utils.data import Dataset, DataLoader
import random
def load_img(filename, debug=False, norm=True, resize=None):
    """Read an image with OpenCV and return it in RGB channel order.

    Args:
        filename: Path passed to ``cv2.imread``.
        debug: If true, print the image's shape, dtype, min and max. This
            happens before any resizing.
        norm: If true, divide pixel values by 255 and cast to float32.
            Otherwise the array keeps the dtype ``cv2.imread`` produced.
        resize: Optional ``(width, height)`` pair forwarded to ``cv2.resize``
            as ``dsize`` (INTER_AREA interpolation).

    Returns:
        An (H, W, 3) numpy array in RGB order.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``filename``.
    """
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        # Without this check the next call fails with an opaque cv2 error.
        raise OSError(f"cv2.imread could not read image: {filename!r}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    if debug:
        print(img.shape, img.dtype, img.min(), img.max())
    if resize:
        img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA)
    return img
def save_rgb(img, filename):
    """Write an RGB image to ``filename`` with OpenCV.

    Args:
        img: RGB numpy array. If its maximum is <= 1 it is treated as
            normalized to [0, 1] and scaled by 255 first.
        filename: Destination path. ``cv2.imwrite`` picks the encoder from
            the extension.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the write failed, for
            example because the target directory does not exist.
    """
    if np.max(img) <= 1:
        img = img * 255
    # cv2.cvtColor accepts 8-bit, 16-bit or float32 input but not float64,
    # so cast before swapping channels.
    img = img.astype(np.float32)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite returns False on failure instead of raising; surface it.
    if not cv2.imwrite(filename, img):
        raise OSError(f"cv2.imwrite failed to write image: {filename!r}")
def load_raw(raw, max_val=2**12-1):
    """Load an array saved with ``np.save`` and rescale it to float32.

    Args:
        raw: Path (or file object) accepted by ``np.load``.
        max_val: Divisor applied to every value. The default, 4095, is the
            largest 12-bit integer. Presumably this is the sensor white
            level; confirm against the dataset.

    Returns:
        ``np.load(raw) / max_val`` as a float32 array.
    """
    data = np.load(raw)
    return np.asarray(data / max_val, dtype=np.float32)
def demosaic(raw):
    """Simple demosaicing to visualize RAW images.

    Args:
        raw: (H, W, 4) array whose last axis holds the planes
            (red, green-on-red-row, green-on-blue-row, blue).

    Returns:
        (2H, 2W, 3) RGB image. Green is the mean of the two green planes,
        and the result is upsampled with ``cv2.resize`` at its default
        interpolation.
    """
    assert raw.shape[-1] == 4
    height, width = raw.shape[0], raw.shape[1]
    r, g_r, g_b, b = (raw[:, :, c] for c in range(4))
    green = (g_r + g_b) / 2
    rgb = np.stack((r, green, b), axis=-1)
    # cv2 takes dsize as (width, height).
    return cv2.resize(rgb, (width * 2, height * 2))
def mosaic(rgb):
    """Extract RGGB Bayer planes from an RGB image.

    The pixel grid is sampled in 2x2 tiles laid out as::

        R G
        G B

    Args:
        rgb: (H, W, 3) array with channels in R, G, B order.

    Returns:
        (H // 2, W // 2, 4) array whose last axis holds
        (red, green-on-red-row, green-on-blue-row, blue). An odd trailing
        row or column cannot form a complete 2x2 tile and is dropped.
    """
    assert rgb.shape[-1] == 3
    # Crop to even dimensions so the four strided slices share one shape.
    # Otherwise np.stack raises on odd-sized inputs.
    height = rgb.shape[0] - rgb.shape[0] % 2
    width = rgb.shape[1] - rgb.shape[1] % 2
    rgb = rgb[:height, :width]
    red = rgb[0::2, 0::2, 0]
    green_red = rgb[0::2, 1::2, 1]
    green_blue = rgb[1::2, 0::2, 1]
    blue = rgb[1::2, 1::2, 2]
    return np.stack((red, green_red, green_blue, blue), axis=-1)
def gamma_compression(image, gamma=2.2):
    """Convert from linear to gamma space.

    Args:
        image: Scalar or numpy array of linear intensities.
        gamma: Gamma exponent. Output is ``image ** (1 / gamma)``. The
            default of 2.2 matches the previously hard-coded value.

    Returns:
        ``max(image, 1e-8) ** (1 / gamma)``, elementwise.
    """
    # Clamp to a tiny positive floor first. A negative base raised to a
    # fractional power yields NaN.
    return np.maximum(image, 1e-8) ** (1.0 / gamma)
def tonemap(image):
    """Simple S-curved global tonemap: ``3*x**2 - 2*x**3``.

    Fixes 0, 0.5 and 1, and flattens the curve near both ends of [0, 1].
    Applied elementwise.
    """
    squared = image ** 2
    cubed = image ** 3
    return 3 * squared - 2 * cubed
def postprocess_raw(raw):
    """Simple post-processing to visualize demosaiced RAW images.

    Applies ``gamma_compression``, then ``tonemap``, then clips the result
    to [0, 1].
    """
    return np.clip(tonemap(gamma_compression(raw)), 0, 1)
def plot_pair(rgb, raw, t1='RGB', t2='RAW', axis='off'):
    """Show two images side by side in a new 12x6-inch, 80-dpi figure.

    Args:
        rgb: Image drawn in the left panel, titled ``t1``.
        raw: Image drawn in the right panel, titled ``t2``.
        t1: Title of the left panel.
        t2: Title of the right panel.
        axis: Argument forwarded to ``plt.axis`` for both panels.
    """
    plt.figure(figsize=(12, 6), dpi=80)
    panels = ((t1, rgb), (t2, raw))
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.axis(axis)
        plt.imshow(picture)
    plt.show()
def PSNR(y_true, y_pred, max_val=None):
    """Peak signal-to-noise ratio between two arrays, in dB.

    Args:
        y_true: Reference array.
        y_pred: Array compared against ``y_true``. Must broadcast with it.
        max_val: Peak signal value. Defaults to ``np.max(y_true)``, as
            before. Pass e.g. 1.0 for [0, 1] images or 255 for 8-bit images
            to get a data-independent peak.

    Returns:
        ``20 * log10(max_val / sqrt(MSE))``, or ``np.inf`` when the arrays
        are identical (MSE == 0).
    """
    # Promote to float64 first. Subtracting unsigned-integer arrays (e.g.
    # uint8 images) would otherwise wrap around instead of going negative.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return np.inf
    peak = np.max(y_true) if max_val is None else max_val
    return 20 * np.log10(peak / np.sqrt(mse))
def get_filenames(path, devices=('samsung-s9', 'iphone-x', 'lq-iphone')):
    """Collect training RAW/RGB file lists and validation RGB files.

    Layout scanned, per the glob patterns below::

        {path}/train/{device}/*.npy   training RAW files
        {path}/train/{device}/*.png   training RGB files
        {path}/rgbs/*                 validation RGB files

    Args:
        path: Dataset root directory.
        devices: Device sub-directories of ``{path}/train`` to scan, in
            order. Defaults to the three devices previously hard-coded.

    Returns:
        ``(train_raws, train_rgbs, valid_rgbs)`` lists of paths. Within each
        device the ``.npy`` and ``.png`` lists are sorted independently, so
        RAW and RGB files are paired by sorted position.

    Raises:
        AssertionError: If the numbers of training ``.npy`` and ``.png``
            files differ.
    """
    train_raws, train_rgbs = [], []
    for device in devices:
        train_raws.extend(sorted(glob(f'{path}/train/{device}/*.npy')))
        train_rgbs.extend(sorted(glob(f'{path}/train/{device}/*.png')))
    valid_rgbs = sorted(glob(f'{path}/rgbs/*'))
    # Raise explicitly rather than via `assert` so the check survives
    # `python -O`. AssertionError is kept for backward compatibility.
    if len(train_raws) != len(train_rgbs):
        raise AssertionError(
            f'Mismatched training files: {len(train_raws)} .npy vs '
            f'{len(train_rgbs)} .png'
        )
    print(f'Training samples: {len(train_raws)} \t Validation samples: {len(valid_rgbs)}')
    return train_raws, train_rgbs, valid_rgbs
class LoadData(Dataset):
    """Dataset yielding (RGB, RAW) tensor pairs, or (RGB, path) in test mode.

    Args:
        root: Stored as ``self.root``. Not otherwise used by this class.
        rgb_files: Paths of RGB images. Sorted before use.
        raw_files: Paths of ``.npy`` RAW files. Sorted before use and paired
            with ``rgb_files`` by sorted position. Ignored when ``test`` is
            truthy.
        debug: If true, keep only the first 100 sorted entries.
        test: If truthy, no RAW files are loaded and items are
            ``(rgb_tensor, rgb_path)``.
    """

    def __init__(self, root, rgb_files, raw_files=None, debug=False, test=None):
        self.root = root
        self.test = test
        self.rgbs = sorted(rgb_files)
        # In test mode there is no RAW ground truth.
        self.raws = None if self.test else sorted(raw_files)
        self.debug = debug
        if self.debug:
            self.rgbs = self.rgbs[:100]
            # raws is None in test mode; slicing None would raise TypeError.
            if self.raws is not None:
                self.raws = self.raws[:100]

    def __len__(self):
        """Number of samples, i.e. the number of RGB files."""
        return len(self.rgbs)

    def __getitem__(self, idx):
        """Load sample ``idx`` and reorder its axes (H, W, C) -> (C, H, W).

        Returns ``(rgb, rgb_path)`` in test mode, else ``(rgb, raw)``.
        """
        rgb = load_img(self.rgbs[idx], norm=True)
        rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))
        if self.test:
            return rgb, self.rgbs[idx]
        # Assumes the stored RAW array is 3-D (H, W, C); transpose requires it.
        raw = load_raw(self.raws[idx])
        raw = torch.from_numpy(raw.transpose((2, 0, 1)))
        return rgb, raw