commit 0d05740a0689c9d8121af04f13ee9e8af925b285
Author: fang <2627729817@qq.com>
Date:   Mon Mar 31 16:32:17 2025 +0800

    Initial commit

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..eede66d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+*.pt
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..d25efcb
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,82 @@
+import torch
+import time
+from tqdm import tqdm
+import numpy as np
+from torch.utils.data import DataLoader
+import argparse
+import os
+from glob import glob
+from utils import LoadData
+from model import UNet
+
+
+def get_filenames(path):
+    valid_rgbs = sorted(glob(f'{path}/*'))
+    print(f'Validation samples: {len(valid_rgbs)}')
+    return valid_rgbs
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='RGB to RAW Inference')
+    parser.add_argument('--folder', type=str, default='/workspace/yx/rgb2RAW',
+                        help='Input folder containing RGB images')
+    parser.add_argument('--output', type=str, default='./submission/',
+                        help='Output folder for RAW predictions')
+    parser.add_argument('--model', type=str, default='model.pt',
+                        help='Path to the model weights')
+    parser.add_argument('--batch_size', type=int, default=1,
+                        help='Batch size for inference')
+    return parser.parse_args()
+
+
+def inference():
+    # Parse command line arguments
+    args = parse_args()
+
+    # Create output directory if it doesn't exist
+    os.makedirs(args.output, exist_ok=True)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load model (map_location keeps CPU-only machines from failing on CUDA checkpoints)
+    model = UNet(enc_chs=(3, 64, 128, 256),
+                 dec_chs=(256, 128, 64),
+                 out_ch=4,
+                 out_sz=(512, 512))
+    model.load_state_dict(torch.load(args.model, map_location=device))
+    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    total_params_in_k = total_params / 1000
+    print(f"Parameter count: {total_params_in_k:.2f}K")
+    model = model.to(device)
+
+    # Get test data
+    valid_rgbs = get_filenames(args.folder)
+    test_dataset = LoadData(root=args.folder, rgb_files=valid_rgbs, test=True)
+    test_loader = DataLoader(dataset=test_dataset,
+                             batch_size=args.batch_size,
+                             shuffle=False,
+                             num_workers=0,
+                             pin_memory=True,
+                             drop_last=False)
+
+    runtime = []
+    model.eval()
+    with torch.no_grad():
+        for (rgb_batch, rgb_name) in tqdm(test_loader):
+            rgb_batch = rgb_batch.to(device)
+            rgb_name = rgb_name[0].split('/')[-1].replace('.png', '')
+
+            st = time.time()
+            recon_raw = model(rgb_batch)
+            if device.type == 'cuda':
+                torch.cuda.synchronize()  # CUDA kernels run async; sync so the timing is honest
+            tt = time.time() - st
+            runtime.append(tt)
+
+            recon_raw = recon_raw[0].detach().cpu().permute(1, 2, 0).numpy()
+
+            # Save as np.uint16, rescaled to the 12-bit range
+            assert recon_raw.shape[-1] == 4
+            recon_raw = (recon_raw * (2**12 - 1)).astype(np.uint16)
+            np.save(os.path.join(args.output, f'{rgb_name}.npy'), recon_raw)
+
+    print(f"Average inference time: {np.mean(runtime):.3f}s")
+    print(f"Results saved to: {args.output}")
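+
+# Reading a prediction back for a visual check is a one-liner with the helpers in
+# utils.py (a sketch; 'some_image' is a placeholder for an actual output file name):
+#   raw = np.load('./submission/some_image.npy').astype(np.float32) / (2**12 - 1)
+#   from utils import demosaic, postprocess_raw
+#   preview = postprocess_raw(demosaic(raw))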
+
+if __name__ == "__main__":
+    inference()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..f040d92
--- /dev/null
+++ b/model.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LayerNormFunction(torch.autograd.Function):
+    """Channel-wise LayerNorm with a hand-written backward (NAFNet-style)."""
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, eps):
+        ctx.eps = eps
+        N, C, H, W = x.size()
+        mu = x.mean(1, keepdim=True)
+        var = (x - mu).pow(2).mean(1, keepdim=True)
+        y = (x - mu) / (var + eps).sqrt()
+        ctx.save_for_backward(y, var, weight)
+        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        eps = ctx.eps
+        N, C, H, W = grad_output.size()
+        y, var, weight = ctx.saved_tensors  # saved_variables is deprecated
+        g = grad_output * weight.view(1, C, 1, 1)
+        mean_g = g.mean(dim=1, keepdim=True)
+        mean_gy = (g * y).mean(dim=1, keepdim=True)
+        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
+        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, channels, eps=1e-6):
+        super(LayerNorm2d, self).__init__()
+        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
+        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
+        self.eps = eps
+
+    def forward(self, x):
+        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+class SimpleGate(nn.Module):
+    def forward(self, x):
+        x1, x2 = x.chunk(2, dim=1)
+        return x1 * x2
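+
+# SimpleGate halves the channel dimension: (N, 2C, H, W) -> (N, C, H, W).
+# Tiny sanity check (a sketch, not part of the model):
+#   SimpleGate()(torch.ones(1, 4, 2, 2)).shape == torch.Size([1, 2, 2, 2])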
+
+
+class NAFBlock(nn.Module):
+    """NAFNet-style block: depthwise conv + SimpleGate + simplified channel attention."""
+
+    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
+        super().__init__()
+        dw_channel = c * DW_Expand
+        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, bias=True)
+        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        # Simplified Channel Attention (SCA): global average pool + 1x1 conv
+        self.sca = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, groups=1, bias=True),
+        )
+
+        self.sg = SimpleGate()
+
+        ffn_channel = FFN_Expand * c
+        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
+
+        self.norm1 = LayerNorm2d(c)
+        self.norm2 = LayerNorm2d(c)
+
+        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
+
+        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
+
+    def forward(self, inp):
+        x = inp
+        x = self.norm1(x)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.sg(x)
+        x = x * self.sca(x)
+        x = self.conv3(x)
+        x = self.dropout1(x)
+        y = inp + x * self.beta
+        x = self.conv4(self.norm2(y))
+        x = self.sg(x)
+        x = self.conv5(x)
+        x = self.dropout2(x)
+        return y + x * self.gamma
+
+
+class Block(nn.Module):
+    def __init__(self, in_ch, out_ch, dw_expand=2, ffn_expand=2, dropout=0.):
+        super().__init__()
+        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
+        self.naf_block = NAFBlock(
+            c=out_ch,
+            DW_Expand=dw_expand,
+            FFN_Expand=ffn_expand,
+            drop_out_rate=dropout
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.naf_block(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(self, chs=(3, 64, 128, 256)):
+        super().__init__()
+        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
+        self.pool = nn.MaxPool2d(2)
+
+    def forward(self, x):
+        ftrs = []
+        for block in self.enc_blocks:
+            x = block(x)
+            ftrs.append(x)
+            x = self.pool(x)
+        return ftrs
+
+
+class Decoder(nn.Module):
+    def __init__(self, chs=(256, 128, 64)):
+        super().__init__()
+        self.chs = chs
+        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
+        # After concatenating the skip connection the input has 2 * chs[i+1] channels,
+        # which matches Block(chs[i], chs[i+1]) since chs[i] == 2 * chs[i+1] here.
+        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
+
+    def forward(self, x, encoder_features):
+        for i in range(len(self.chs)-1):
+            x = self.upconvs[i](x)
+            enc_ftrs = encoder_features[i]
+            x = torch.cat([x, enc_ftrs], dim=1)
+            x = self.dec_blocks[i](x)
+        return x
+
+
+class UNet(nn.Module):
+    def __init__(self, enc_chs=(3, 32, 64, 128), dec_chs=(128, 64, 32), out_ch=4, out_sz=(252, 252)):
+        super().__init__()
+        self.encoder = Encoder(enc_chs)
+        self.decoder = Decoder(dec_chs)
+        self.head = nn.Conv2d(dec_chs[-1], out_ch, 1)
+        self.out_sz = out_sz
+
+    def forward(self, x):
+        enc_ftrs = self.encoder(x)
+        out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
+        out = self.head(out)
+        out = F.interpolate(out, self.out_sz)
+        out = torch.clamp(out, min=0., max=1.)
+        return out
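+
+# Shape walk-through for the training config (enc_chs=(3, 64, 128, 256),
+# dec_chs=(256, 128, 64), out_sz=(512, 512)). A 1024x1024 RGB input is an
+# assumption consistent with 2x2 RGGB packing of the 512x512x4 target:
+#   encoder features: 64@1024, 128@512, 256@256 (each followed by a 2x max-pool)
+#   decoder output:   64@1024 -> 1x1 head -> 4@1024 -> F.interpolate -> 4@512x512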
+
+
+class hard_log_loss(nn.Module):
+    """Penalizes large absolute errors harder than MSE: -log(1 - |x - y|)."""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        # clamp keeps the log argument in (0, 1]; 1e-6 avoids log(0)
+        loss = (-1 * torch.log(1 - torch.clamp(torch.abs(x - y), 0, 1) + 1e-6)).mean()
+        return loss
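+
+
+if __name__ == "__main__":
+    # Quick shape check (a sketch, not part of the pipeline). out_sz pins the output
+    # resolution, so any input whose sides are divisible by 4 exercises the skips.
+    net = UNet(enc_chs=(3, 64, 128, 256), dec_chs=(256, 128, 64),
+               out_ch=4, out_sz=(512, 512))
+    dummy = torch.randn(1, 3, 256, 256)
+    with torch.no_grad():
+        print(net(dummy).shape)  # torch.Size([1, 4, 512, 512])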
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..bbebe50
--- /dev/null
+++ b/train.py
@@ -0,0 +1,114 @@
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from collections import defaultdict
+
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from IPython import display
+import matplotlib.pyplot as plt
+from torchmetrics import StructuralSimilarityIndexMeasure
+
+from utils import LoadData, get_filenames, postprocess_raw, demosaic
+from model import UNet, hard_log_loss
+
+
+class CFG:
+    encoder = (3, 64, 128, 256)
+    decoder = (256, 128, 64)
+    out_ch = 4
+    out_sz = (512, 512)
+    lr = 1e-4
+    lr_decay = 1e-6  # passed to Adam as weight_decay
+    epochs = 40
+    loss = nn.MSELoss()
+    name = 'unet-rev-isp.pt'
+    out_dir = '../'
+    save_freq = 5
+
+
+def train():
+    # Setup
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    PATH = '/workspace/yx/rgb2RAW'  # Update this path
+    BATCH_TRAIN = 2
+    BATCH_TEST = 1
+    DEBUG = False
+
+    # Get data
+    train_raws, train_rgbs, valid_rgbs = get_filenames(PATH)
+    train_dataset = LoadData(root=PATH, rgb_files=train_rgbs, raw_files=train_raws, debug=DEBUG, test=False)
+    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_TRAIN, shuffle=True, num_workers=0,
+                              pin_memory=True, drop_last=True)
+
+    # Initialize model and training components
+    model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder, out_ch=CFG.out_ch, out_sz=CFG.out_sz).to(device)
+    opt = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.lr_decay)
+    scheduler = CosineAnnealingLR(opt, T_max=CFG.epochs)
+    criterion = CFG.loss
+    ssim_loss = StructuralSimilarityIndexMeasure().to(device)
+    hard_log_loss_fn = hard_log_loss().to(device)
+    metrics = defaultdict(list)
+
+    # Training loop
+    for epoch in range(CFG.epochs):
+        torch.cuda.empty_cache()
+        start_time = time.time()
+        train_loss = []
+        model.train()
+
+        for rgb_batch, raw_batch in tqdm(train_loader):
+            opt.zero_grad()
+            rgb_batch = rgb_batch.to(device)
+            raw_batch = raw_batch.to(device)
+
+            recon_raw = model(rgb_batch)
+
+            # Combined objective: pixel MSE, a structural term, and a hard-example log term
+            mse_loss = criterion(raw_batch, recon_raw)
+            ssim = ssim_loss(recon_raw, raw_batch)
+            ssim_loss_value = 1 - ssim
+            hard_log = hard_log_loss_fn(recon_raw, raw_batch)
+            combined_loss = mse_loss + 0.05 * ssim_loss_value + 0.1 * hard_log
+
+            combined_loss.backward()
+            opt.step()
+            train_loss.append(combined_loss.item())
+
+        metrics['train_loss'].append(np.mean(train_loss))
+        scheduler.step()
+
+        # Visualization
+        display.clear_output()
+        plt.figure(figsize=(20, 7))
+        ax1 = plt.subplot(2, 2, 1)
+        reconst_raw = postprocess_raw(demosaic(recon_raw[-1].detach().cpu().permute(1, 2, 0).numpy()))
+        gt_raw = postprocess_raw(demosaic(raw_batch[-1].detach().cpu().permute(1, 2, 0).numpy()))
+        cmp_raw_gt = np.concatenate([gt_raw, reconst_raw], axis=1)
+        ax1.imshow(cmp_raw_gt)
+        ax1.set_title('(left) GT / (right) Reconst. RAW')
+
+        ax2 = plt.subplot(2, 2, 3)
+        ax2.imshow(rgb_batch[-1].detach().cpu().permute(1, 2, 0).numpy())
+        ax2.set_title('RGB')
+
+        ax3 = plt.subplot(1, 2, 2)
+        ax3.plot(metrics['train_loss'], label='train')
+        ax3.set_xlabel('Epochs', fontsize=18)
+        ax3.set_ylabel('Loss', fontsize=18)
+        ax3.legend()
+        ax3.grid()
+
+        save_path = os.path.join(CFG.out_dir, f'epoch_{epoch + 1}_plot.png')
+        plt.savefig(save_path)
+        plt.show()
+        plt.close('all')  # release figures so 40 epochs of plots don't accumulate
+
+        print(f"Epoch {epoch + 1} of {CFG.epochs} took {time.time() - start_time:.3f}s\n")
+
+        if ((epoch + 1) % CFG.save_freq == 0):
+            torch.save(model.state_dict(), os.path.join(CFG.out_dir, f'{epoch + 1}.pt'))
+
+    torch.save(model.state_dict(), os.path.join(CFG.out_dir, CFG.name))
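+
+# Validation is not wired up; a minimal held-out check could reuse utils.PSNR on the
+# tensors produced in the loop above (a sketch, not part of the original script):
+#   from utils import PSNR
+#   psnr = PSNR(raw_batch[-1].cpu().numpy(), recon_raw[-1].detach().cpu().numpy())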
+
+if __name__ == "__main__":
+    train()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..b0871eb
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,142 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import cv2
+from glob import glob
+import torch
+from torch.utils.data import Dataset
+
+
+def load_img(filename, debug=False, norm=True, resize=None):
+    img = cv2.imread(filename)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    if norm:
+        img = img / 255.
+        img = img.astype(np.float32)
+    if debug:
+        print(img.shape, img.dtype, img.min(), img.max())
+
+    if resize:
+        # cv2.resize expects (width, height)
+        img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA)
+
+    return img
+
+
+def save_rgb(img, filename):
+    if np.max(img) <= 1:
+        img = img * 255
+
+    # cv2.imwrite expects 8-bit data for standard PNG/JPEG output
+    img = np.clip(img, 0, 255).astype(np.uint8)
+    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+    cv2.imwrite(filename, img)
+
+
+def load_raw(raw, max_val=2**12 - 1):
+    raw = np.load(raw) / max_val
+    return raw.astype(np.float32)
+
+
+def demosaic(raw):
+    """Simple demosaicing to visualize RAW images."""
+    assert raw.shape[-1] == 4
+    shape = raw.shape
+
+    red = raw[:, :, 0]
+    green_red = raw[:, :, 1]
+    green_blue = raw[:, :, 2]
+    blue = raw[:, :, 3]
+    avg_green = (green_red + green_blue) / 2
+    image = np.stack((red, avg_green, blue), axis=-1)
+    image = cv2.resize(image, (shape[1] * 2, shape[0] * 2))
+    return image
+
+
+def mosaic(rgb):
+    """Extracts RGGB Bayer planes from an RGB image."""
+    assert rgb.shape[-1] == 3
+    shape = rgb.shape
+
+    red = rgb[0::2, 0::2, 0]
+    green_red = rgb[0::2, 1::2, 1]
+    green_blue = rgb[1::2, 0::2, 1]
+    blue = rgb[1::2, 1::2, 2]
+
+    image = np.stack((red, green_red, green_blue, blue), axis=-1)
+    return image
+
+
+def gamma_compression(image):
+    """Converts from linear to gamma space."""
+    return np.maximum(image, 1e-8) ** (1.0 / 2.2)
+
+
+def tonemap(image):
+    """Simple S-curved global tonemap."""
+    return (3 * (image ** 2)) - (2 * (image ** 3))
+
+
+def postprocess_raw(raw):
+    """Simple post-processing to visualize demosaiced RAW images."""
+    raw = gamma_compression(raw)
+    raw = tonemap(raw)
+    raw = np.clip(raw, 0, 1)
+    return raw
+
+
+def plot_pair(rgb, raw, t1='RGB', t2='RAW', axis='off'):
+    fig = plt.figure(figsize=(12, 6), dpi=80)
+    plt.subplot(1, 2, 1)
+    plt.title(t1)
+    plt.axis(axis)
+    plt.imshow(rgb)
+
+    plt.subplot(1, 2, 2)
+    plt.title(t2)
+    plt.axis(axis)
+    plt.imshow(raw)
+    plt.show()
+
+
+def PSNR(y_true, y_pred):
+    mse = np.mean((y_true - y_pred) ** 2)
+    if (mse == 0):
+        return np.inf
+
+    max_pixel = np.max(y_true)
+    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
+    return psnr
+
+
+def get_filenames(path):
+    devices = ['samsung-s9', 'iphone-x', 'lq-iphone']
+    train_raws, train_rgbs = [], []
+    for device in devices:
+        train_raw = sorted(glob(f'{path}/train/{device}/*.npy'))
+        train_rgb = sorted(glob(f'{path}/train/{device}/*.png'))
+        train_raws.extend(train_raw)
+        train_rgbs.extend(train_rgb)
+    valid_rgbs = sorted(glob(f'{path}/rgbs/*'))
+    assert len(train_raws) == len(train_rgbs)
+    print(f'Training samples: {len(train_raws)} \t Validation samples: {len(valid_rgbs)}')
+    return train_raws, train_rgbs, valid_rgbs
+
+
+class LoadData(Dataset):
+    def __init__(self, root, rgb_files, raw_files=None, debug=False, test=None):
+        self.root = root
+        self.test = test
+        self.rgbs = sorted(rgb_files)
+        if self.test:
+            self.raws = None
+        else:
+            self.raws = sorted(raw_files)
+
+        self.debug = debug
+        if self.debug:
+            self.rgbs = self.rgbs[:100]
+            self.raws = self.raws[:100]
+
+    def __len__(self):
+        return len(self.rgbs)
+
+    def __getitem__(self, idx):
+        rgb = load_img(self.rgbs[idx], norm=True)
+        rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))
+
+        if self.test:
+            return rgb, self.rgbs[idx]
+        else:
+            raw = load_raw(self.raws[idx])
+            raw = torch.from_numpy(raw.transpose((2, 0, 1)))
+            return rgb, raw
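+
+
+if __name__ == "__main__":
+    # Round-trip sanity check (a sketch, not used by train.py or inference.py):
+    # mosaic packs RGB into 4 RGGB planes at half resolution, demosaic maps back
+    # to a viewable 3-channel image at the original size.
+    rgb = np.random.rand(8, 8, 3).astype(np.float32)
+    planes = mosaic(rgb)        # (4, 4, 4)
+    preview = demosaic(planes)  # (8, 8, 3)
+    print(planes.shape, preview.shape)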