Initial commit

This commit is contained in:
fang 2025-03-31 16:32:17 +08:00
commit 0d05740a06
5 changed files with 493 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
*.pt

82
inference.py Normal file
View File

@ -0,0 +1,82 @@
import torch
import time
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
import argparse
import os
from glob import glob
from utils import LoadData
from model import UNet
def get_filenames(path):
valid_rgbs = sorted(glob(f'{path}/*'))
print(f'Validation samples: {len(valid_rgbs)}')
return valid_rgbs
def parse_args():
parser = argparse.ArgumentParser(description='RGB to RAW Inference')
parser.add_argument('--folder', type=str, default='/workspace/yx/rgb2RAW',
help='Input folder containing RGB images')
parser.add_argument('--output', type=str, default='./submission/',
help='Output folder for RAW predictions')
parser.add_argument('--model', type=str, default='model.pt',
help='Path to the model weights')
parser.add_argument('--batch_size', type=int, default=1,
help='Batch size for inference')
return parser.parse_args()
def inference():
# Parse command line arguments
args = parse_args()
# Create output directory if it doesn't exist
os.makedirs(args.output, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
model = UNet(enc_chs=(3, 64, 128, 256),
dec_chs=(256, 128, 64),
out_ch=4,
out_sz=(512, 512))
model.load_state_dict(torch.load(args.model))
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params_in_k = total_params / 1000
print(f"Parameter quantity{total_params_in_k:.2f}K")
model = model.to(device)
# Get test data
valid_rgbs = get_filenames(args.folder)
test_dataset = LoadData(root=args.folder, rgb_files=valid_rgbs, test=True)
test_loader = DataLoader(dataset=test_dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
drop_last=False)
runtime = []
model.eval()
with torch.no_grad():
for (rgb_batch, rgb_name) in tqdm(test_loader):
rgb_batch = rgb_batch.to(device)
rgb_name = rgb_name[0].split('/')[-1].replace('.png', '')
st = time.time()
recon_raw = model(rgb_batch)
tt = time.time() - st
runtime.append(tt)
recon_raw = recon_raw[0].detach().cpu().permute(1, 2, 0).numpy()
# Save as np.uint16
assert recon_raw.shape[-1] == 4
recon_raw = (recon_raw * (2**12-1)).astype(np.uint16)
np.save(os.path.join(args.output, f'{rgb_name}.npy'), recon_raw)
print(f"Average inference time: {np.mean(runtime):.3f}s")
print(f"Results saved to: {args.output}")
if __name__ == "__main__":
inference()

154
model.py Normal file
View File

@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class LayerNormFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, eps):
ctx.eps = eps
N, C, H, W = x.size()
mu = x.mean(1, keepdim=True)
var = (x - mu).pow(2).mean(1, keepdim=True)
y = (x - mu) / (var + eps).sqrt()
ctx.save_for_backward(y, var, weight)
y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
N, C, H, W = grad_output.size()
y, var, weight = ctx.saved_variables
g = grad_output * weight.view(1, C, 1, 1)
mean_g = g.mean(dim=1, keepdim=True)
mean_gy = (g * y).mean(dim=1, keepdim=True)
gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None
class LayerNorm2d(nn.Module):
def __init__(self, channels, eps=1e-6):
super(LayerNorm2d, self).__init__()
self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
self.eps = eps
def forward(self, x):
return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
class SimpleGate(nn.Module):
def forward(self, x):
x1, x2 = x.chunk(2, dim=1)
return x1 * x2
class NAFBlock(nn.Module):
def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
super().__init__()
dw_channel = c * DW_Expand
self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True)
self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.sca = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
)
self.sg = SimpleGate()
ffn_channel = FFN_Expand * c
self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
self.norm1 = LayerNorm2d(c)
self.norm2 = LayerNorm2d(c)
self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
def forward(self, inp):
x = inp
x = self.norm1(x)
x = self.conv1(x)
x = self.conv2(x)
x = self.sg(x)
x = x * self.sca(x)
x = self.conv3(x)
x = self.dropout1(x)
y = inp + x * self.beta
x = self.conv4(self.norm2(y))
x = self.sg(x)
x = self.conv5(x)
x = self.dropout2(x)
return y + x * self.gamma
class Block(nn.Module):
def __init__(self, in_ch, out_ch, dw_expand=2, ffn_expand=2, dropout=0.):
super().__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
self.naf_block = NAFBlock(
c=out_ch,
DW_Expand=dw_expand,
FFN_Expand=ffn_expand,
drop_out_rate=dropout
)
def forward(self, x):
x = self.conv(x)
x = self.naf_block(x)
return x
class Encoder(nn.Module):
def __init__(self, chs=(3,64,128,256)):
super().__init__()
self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
self.pool = nn.MaxPool2d(2)
def forward(self, x):
ftrs = []
for block in self.enc_blocks:
x = block(x)
ftrs.append(x)
x = self.pool(x)
return ftrs
class Decoder(nn.Module):
def __init__(self, chs=(256, 128, 64)):
super().__init__()
self.chs = chs
self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
def forward(self, x, encoder_features):
for i in range(len(self.chs)-1):
x = self.upconvs[i](x)
enc_ftrs = encoder_features[i]
x = torch.cat([x, enc_ftrs], dim=1)
x = self.dec_blocks[i](x)
return x
class UNet(nn.Module):
def __init__(self, enc_chs=(3, 32, 64, 128), dec_chs=(128, 64, 32), out_ch=4, out_sz=(252, 252)):
super().__init__()
self.encoder = Encoder(enc_chs)
self.decoder = Decoder(dec_chs)
self.head = nn.Conv2d(dec_chs[-1], out_ch, 1)
self.out_sz = out_sz
def forward(self, x):
enc_ftrs = self.encoder(x)
out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
out = self.head(out)
out = F.interpolate(out, self.out_sz)
out = torch.clamp(out, min=0., max=1.)
return out
class hard_log_loss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, y):
loss = (-1 * torch.log(1 - torch.clamp(torch.abs(x - y),0,1) + 1e-6)).mean()
return loss

114
train.py Normal file
View File

@ -0,0 +1,114 @@
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import defaultdict
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt
from torchmetrics import StructuralSimilarityIndexMeasure
from utils import LoadData, get_filenames, postprocess_raw, demosaic
from model import UNet, hard_log_loss
class CFG:
encoder = (3, 64, 128, 256)
decoder = (256, 128, 64)
out_ch = 4
out_sz = (512, 512)
lr = 1e-4
lr_decay = 1e-6
epochs = 40
loss = nn.MSELoss()
name = 'unet-rev-isp.pt'
out_dir = '../'
save_freq = 5
def train():
# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATH = '/workspace/yx/rgb2RAW' # Update this path
BATCH_TRAIN = 2
BATCH_TEST = 1
DEBUG = False
# Get data
train_raws, train_rgbs, valid_rgbs = get_filenames(PATH)
train_dataset = LoadData(root=PATH, rgb_files=train_rgbs, raw_files=train_raws, debug=DEBUG, test=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_TRAIN, shuffle=True, num_workers=0,
pin_memory=True, drop_last=True)
# Initialize model and training components
model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder, out_ch=CFG.out_ch, out_sz=CFG.out_sz).to(device)
opt = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.lr_decay)
scheduler = CosineAnnealingLR(opt, T_max=CFG.epochs)
criterion = CFG.loss
ssim_loss = StructuralSimilarityIndexMeasure().to(device)
hard_log_loss_fn = hard_log_loss().to(device)
metrics = defaultdict(list)
# Training loop
for epoch in range(CFG.epochs):
torch.cuda.empty_cache()
start_time = time.time()
train_loss = []
model.train()
for rgb_batch, raw_batch in tqdm(train_loader):
opt.zero_grad()
rgb_batch = rgb_batch.to(device)
raw_batch = raw_batch.to(device)
recon_raw = model(rgb_batch)
mse_loss = criterion(raw_batch, recon_raw)
ssim = ssim_loss(recon_raw, raw_batch)
ssim_loss_value = 1 - ssim
hard_log = hard_log_loss_fn(recon_raw, raw_batch)
combined_loss = mse_loss + 0.05*ssim_loss_value + 0.1*hard_log
combined_loss.backward()
opt.step()
train_loss.append(combined_loss.item())
metrics['train_loss'].append(np.mean(train_loss))
scheduler.step()
# Visualization
display.clear_output()
plt.figure(figsize=(20, 7))
ax1 = plt.subplot(2, 2, 1)
reconst_raw = postprocess_raw(demosaic(recon_raw[-1].detach().cpu().permute(1, 2, 0).numpy()))
gt_raw = postprocess_raw(demosaic(raw_batch[-1].detach().cpu().permute(1, 2, 0).numpy()))
cmp_raw_gt = np.concatenate([gt_raw, reconst_raw], axis=1)
ax1.imshow(cmp_raw_gt)
ax1.set_title('(left) GT / (right) Reconst. RAW')
ax2 = plt.subplot(2, 2, 3)
ax2.imshow(rgb_batch[-1].detach().cpu().permute(1, 2, 0).numpy())
ax2.set_title('RGB')
ax3 = plt.subplot(1, 2, 2)
ax3.plot(metrics['train_loss'], label='train')
ax3.set_xlabel('Epochs', fontsize=18)
ax3.set_ylabel('Loss', fontsize=18)
ax3.grid()
save_path = os.path.join(CFG.out_dir, f'epoch_{epoch + 1}_plot.png')
plt.savefig(save_path)
plt.show()
print(f"Epoch {epoch + 1} of {CFG.epochs} took {time.time() - start_time:.3f}s\n")
if ((epoch + 1) % CFG.save_freq == 0):
torch.save(model.state_dict(), os.path.join(CFG.out_dir, f'{epoch + 1}.pt'))
torch.save(model.state_dict(), os.path.join(CFG.out_dir, CFG.name))
if __name__ == "__main__":
train()

142
utils.py Normal file
View File

@ -0,0 +1,142 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from glob import glob
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as FF
from torch.utils.data import Dataset, DataLoader
import random
def load_img(filename, debug=False, norm=True, resize=None):
img = cv2.imread(filename)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if norm:
img = img / 255.
img = img.astype(np.float32)
if debug:
print(img.shape, img.dtype, img.min(), img.max())
if resize:
img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA)
return img
def save_rgb(img, filename):
if np.max(img) <= 1:
img = img * 255
img = img.astype(np.float32)
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imwrite(filename, img)
def load_raw(raw, max_val=2**12-1):
raw = np.load(raw) / max_val
return raw.astype(np.float32)
def demosaic(raw):
"""Simple demosaicing to visualize RAW images"""
assert raw.shape[-1] == 4
shape = raw.shape
red = raw[:,:,0]
green_red = raw[:,:,1]
green_blue = raw[:,:,2]
blue = raw[:,:,3]
avg_green = (green_red + green_blue) / 2
image = np.stack((red, avg_green, blue), axis=-1)
image = cv2.resize(image, (shape[1]*2, shape[0]*2))
return image
def mosaic(rgb):
"""Extracts RGGB Bayer planes from an RGB image."""
assert rgb.shape[-1] == 3
shape = rgb.shape
red = rgb[0::2, 0::2, 0]
green_red = rgb[0::2, 1::2, 1]
green_blue = rgb[1::2, 0::2, 1]
blue = rgb[1::2, 1::2, 2]
image = np.stack((red, green_red, green_blue, blue), axis=-1)
return image
def gamma_compression(image):
"""Converts from linear to gamma space."""
return np.maximum(image, 1e-8) ** (1.0 / 2.2)
def tonemap(image):
"""Simple S-curved global tonemap"""
return (3*(image**2)) - (2*(image**3))
def postprocess_raw(raw):
"""Simple post-processing to visualize demosaic RAW images"""
raw = gamma_compression(raw)
raw = tonemap(raw)
raw = np.clip(raw, 0, 1)
return raw
def plot_pair(rgb, raw, t1='RGB', t2='RAW', axis='off'):
fig = plt.figure(figsize=(12, 6), dpi=80)
plt.subplot(1,2,1)
plt.title(t1)
plt.axis(axis)
plt.imshow(rgb)
plt.subplot(1,2,2)
plt.title(t2)
plt.axis(axis)
plt.imshow(raw)
plt.show()
def PSNR(y_true, y_pred):
mse = np.mean((y_true - y_pred) ** 2)
if(mse == 0):
return np.inf
max_pixel = np.max(y_true)
psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
return psnr
def get_filenames(path):
devices = ['samsung-s9', 'iphone-x', 'lq-iphone']
train_raws, train_rgbs = [], []
for device in devices:
train_raw = sorted(glob(f'{path}/train/{device}/*.npy'))
train_rgb = sorted(glob(f'{path}/train/{device}/*.png'))
train_raws.extend(train_raw)
train_rgbs.extend(train_rgb)
valid_rgbs = sorted(glob(f'{path}/rgbs/*'))
assert len(train_raws) == len(train_rgbs)
print(f'Training samples: {len(train_raws)} \t Validation samples: {len(valid_rgbs)}')
return train_raws, train_rgbs, valid_rgbs
class LoadData(Dataset):
def __init__(self, root, rgb_files, raw_files=None, debug=False, test=None):
self.root = root
self.test = test
self.rgbs = sorted(rgb_files)
if self.test:
self.raws = None
else:
self.raws = sorted(raw_files)
self.debug = debug
if self.debug:
self.rgbs = self.rgbs[:100]
self.raws = self.raws[:100]
def __len__(self):
return len(self.rgbs)
def __getitem__(self, idx):
rgb = load_img(self.rgbs[idx], norm=True)
rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))
if self.test:
return rgb, self.rgbs[idx]
else:
raw = load_raw(self.raws[idx])
raw = torch.from_numpy(raw.transpose((2, 0, 1)))
return rgb, raw