Initial commit
This commit is contained in:
commit
0d05740a06
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
*.pt
|
||||
82
inference.py
Normal file
82
inference.py
Normal file
@ -0,0 +1,82 @@
|
||||
import torch
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
import argparse
|
||||
import os
|
||||
from glob import glob
|
||||
from utils import LoadData
|
||||
from model import UNet
|
||||
|
||||
def get_filenames(path):
    """List every entry directly inside ``path``, sorted by name.

    The number of entries found is printed before the list is returned.
    """
    found = glob(f'{path}/*')
    found.sort()
    print(f'Validation samples: {len(found)}')
    return found
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with the attributes ``folder``, ``output``,
        ``model`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='RGB to RAW Inference')
    # One row per option: (flag, type, default, help text).
    options = (
        ('--folder', str, '/workspace/yx/rgb2RAW',
         'Input folder containing RGB images'),
        ('--output', str, './submission/',
         'Output folder for RAW predictions'),
        ('--model', str, 'model.pt',
         'Path to the model weights'),
        ('--batch_size', int, 1,
         'Batch size for inference'),
    )
    for flag, kind, default, text in options:
        parser.add_argument(flag, type=kind, default=default, help=text)
    return parser.parse_args()
|
||||
|
||||
def inference():
    """Run the UNet over a folder of RGB images and save one RAW ``.npy`` per image.

    Command-line options come from ``parse_args()``. For every input image the
    model output is scaled by ``2**12 - 1``, cast to ``uint16`` and written to
    ``<output>/<image stem>.npy`` in (H, W, 4) layout. The mean forward-pass
    time per batch is printed at the end.
    """
    args = parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Load model
    model = UNet(enc_chs=(3, 64, 128, 256),
                 dec_chs=(256, 128, 64),
                 out_ch=4,
                 out_sz=(512, 512))
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(args.model, map_location=device))
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params_in_k = total_params / 1000
    print(f"Parameter quantity:{total_params_in_k:.2f}K")
    model = model.to(device)

    # Get test data
    valid_rgbs = get_filenames(args.folder)
    test_dataset = LoadData(root=args.folder, rgb_files=valid_rgbs, test=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=0,
                             pin_memory=True,
                             drop_last=False)

    runtime = []
    model.eval()
    with torch.no_grad():
        for rgb_batch, rgb_names in tqdm(test_loader):
            rgb_batch = rgb_batch.to(device)

            # CUDA kernels launch asynchronously; synchronize on both sides so
            # the measured interval covers the actual forward pass.
            if use_cuda:
                torch.cuda.synchronize()
            st = time.perf_counter()
            recon_batch = model(rgb_batch)
            if use_cuda:
                torch.cuda.synchronize()
            runtime.append(time.perf_counter() - st)

            # (B, C, H, W) -> (B, H, W, C) so each saved sample is channel-last.
            recon_batch = recon_batch.detach().cpu().permute(0, 2, 3, 1).numpy()

            # Save as np.uint16
            assert recon_batch.shape[-1] == 4
            recon_batch = (recon_batch * (2**12 - 1)).astype(np.uint16)

            # Write every sample of the batch, not only the first one.
            for name, recon_raw in zip(rgb_names, recon_batch):
                stem = os.path.splitext(os.path.basename(name))[0]
                np.save(os.path.join(args.output, f'{stem}.npy'), recon_raw)

    if runtime:
        print(f"Average inference time: {np.mean(runtime):.3f}s")
    else:
        print(f"No input images found in: {args.folder}")
    print(f"Results saved to: {args.output}")
|
||||
|
||||
# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    inference()
|
||||
154
model.py
Normal file
154
model.py
Normal file
@ -0,0 +1,154 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LayerNormFunction(torch.autograd.Function):
    """Normalization over dim 1 of a 4-D tensor with a hand-written backward.

    ``forward`` normalizes ``x`` to zero mean / unit variance along dim 1
    (independently at every position of the other dims) and then applies a
    per-channel affine transform ``weight * y + bias``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Return the normalized, affinely transformed tensor.

        Args:
            x: 4-D tensor; dim 1 is the normalized dimension.
            weight: 1-D scale with one entry per element of dim 1.
            bias: 1-D shift with one entry per element of dim 1.
            eps: value added to the variance before the square root.
        """
        ctx.eps = eps
        # Unpacking also enforces that the input is exactly 4-D.
        _, C, _, _ = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the normalized tensor (before the affine step) for backward.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, eps)``; ``eps`` gets None."""
        eps = ctx.eps
        _, C, _, _ = grad_output.size()
        # saved_tensors is the supported accessor; saved_variables is deprecated.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # weight/bias gradients: reduce over every dim except dim 1.
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None
|
||||
|
||||
class LayerNorm2d(nn.Module):
    """Module wrapper around ``LayerNormFunction`` with learnable scale and shift.

    ``weight`` starts at ones and ``bias`` at zeros, each with ``channels``
    entries; ``eps`` is forwarded unchanged to the function.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Attribute assignment registers the parameters under the same names.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        scale, shift = self.weight, self.bias
        return LayerNormFunction.apply(x, scale, shift, self.eps)
|
||||
|
||||
class SimpleGate(nn.Module):
    """Split the input in two along dim 1 and multiply the halves element-wise."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return torch.mul(first_half, second_half)
|
||||
|
||||
class NAFBlock(nn.Module):
    """Residual block made of two gated branches, each scaled by a learnable factor.

    Branch 1: norm -> 1x1 conv (expand by ``DW_Expand``) -> 3x3 depthwise conv
    -> gate -> rescale by a per-channel factor computed from the globally
    average-pooled features -> 1x1 conv -> dropout, added to the input after
    scaling by ``beta``.
    Branch 2: norm -> 1x1 conv (expand by ``FFN_Expand``) -> gate -> 1x1 conv
    -> dropout, added after scaling by ``gamma``.
    ``beta`` and ``gamma`` start at zero, so a fresh block is the identity.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()

        # Layers are created in the original order so that parameter
        # registration and random initialisation order are unchanged.
        dw_channel = c * DW_Expand
        gated_dw = dw_channel // 2  # the gate halves the channel count
        self.conv1 = nn.Conv2d(c, dw_channel, kernel_size=1)
        self.conv2 = nn.Conv2d(dw_channel, dw_channel, kernel_size=3,
                               padding=1, groups=dw_channel)
        self.conv3 = nn.Conv2d(gated_dw, c, kernel_size=1)

        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(gated_dw, gated_dw, kernel_size=1),
        )

        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(c, ffn_channel, kernel_size=1)
        self.conv5 = nn.Conv2d(ffn_channel // 2, c, kernel_size=1)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        if drop_out_rate > 0.:
            self.dropout1 = nn.Dropout(drop_out_rate)
            self.dropout2 = nn.Dropout(drop_out_rate)
        else:
            self.dropout1 = nn.Identity()
            self.dropout2 = nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)))
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)))

    def forward(self, inp):
        # First branch: depthwise path with pooled per-channel rescaling.
        branch = self.conv2(self.conv1(self.norm1(inp)))
        branch = self.sg(branch)
        branch = branch * self.sca(branch)
        branch = self.dropout1(self.conv3(branch))
        y = inp + branch * self.beta

        # Second branch: gated point-wise path.
        branch = self.sg(self.conv4(self.norm2(y)))
        branch = self.dropout2(self.conv5(branch))
        return y + branch * self.gamma
|
||||
|
||||
class Block(nn.Module):
    """A bias-free 3x3 convolution (``in_ch`` -> ``out_ch``) followed by a NAFBlock."""

    def __init__(self, in_ch, out_ch, dw_expand=2, ffn_expand=2, dropout=0.):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False)
        naf_kwargs = dict(c=out_ch, DW_Expand=dw_expand,
                          FFN_Expand=ffn_expand, drop_out_rate=dropout)
        self.naf_block = NAFBlock(**naf_kwargs)

    def forward(self, x):
        return self.naf_block(self.conv(x))
|
||||
|
||||
class Encoder(nn.Module):
    """Stack of ``Block`` stages, each followed by 2x2 max pooling.

    ``chs`` lists the channel count before the first stage and after every
    stage, so ``len(chs) - 1`` stages are built.
    """

    def __init__(self, chs=(3,64,128,256)):
        super().__init__()
        stages = [Block(c_in, c_out) for c_in, c_out in zip(chs[:-1], chs[1:])]
        self.enc_blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return the pre-pooling output of every stage, first stage first."""
        features = []
        for stage in self.enc_blocks:
            x = stage(x)
            features.append(x)
            x = self.pool(x)
        return features
|
||||
|
||||
class Decoder(nn.Module):
    """Upsampling path: transposed conv, concat with an encoder feature, then a Block.

    ``chs`` lists the channel count before the first step and after every
    step. Each step upsamples by 2 with a kernel-2 / stride-2 transposed
    convolution.
    """

    def __init__(self, chs=(256, 128, 64)):
        super().__init__()
        self.chs = chs
        steps = len(chs) - 1
        # All transposed convs are built before any Block, as in the original,
        # so parameter order and random initialisation order are unchanged.
        self.upconvs = nn.ModuleList(
            [nn.ConvTranspose2d(chs[k], chs[k + 1], 2, 2) for k in range(steps)])
        self.dec_blocks = nn.ModuleList(
            [Block(chs[k], chs[k + 1]) for k in range(steps)])

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features[i]`` as the skip input of step ``i``.

        NOTE(review): the concat presumably relies on the upsampled tensor and
        the skip feature sharing spatial size — confirm against callers.
        """
        for step, (up, block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            upsampled = up(x)
            merged = torch.cat([upsampled, encoder_features[step]], dim=1)
            x = block(merged)
        return x
|
||||
|
||||
class UNet(nn.Module):
    """Encoder/decoder network with a 1x1 output head.

    The head output is resized to ``out_sz`` with ``F.interpolate`` (default
    mode) and clamped to the range [0, 1].
    """

    def __init__(self, enc_chs=(3, 32, 64, 128), dec_chs=(128, 64, 32), out_ch=4, out_sz=(252, 252)):
        super().__init__()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.head = nn.Conv2d(dec_chs[-1], out_ch, 1)
        self.out_sz = out_sz

    def forward(self, x):
        # Deepest encoder feature first; the rest serve as skip connections.
        reversed_ftrs = self.encoder(x)[::-1]
        deepest, skips = reversed_ftrs[0], reversed_ftrs[1:]
        out = self.head(self.decoder(deepest, skips))
        out = F.interpolate(out, self.out_sz)
        return torch.clamp(out, min=0., max=1.)
|
||||
|
||||
class hard_log_loss(nn.Module):
    """Mean of ``-log(1 - |x - y| + 1e-6)`` with ``|x - y|`` clamped to [0, 1]."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        abs_err = (x - y).abs().clamp(0, 1)
        # The 1e-6 offset keeps the log finite when the clamped error reaches 1.
        neg_log = -1 * torch.log(1 - abs_err + 1e-6)
        return neg_log.mean()
|
||||
114
train.py
Normal file
114
train.py
Normal file
@ -0,0 +1,114 @@
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from collections import defaultdict
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from IPython import display
|
||||
import matplotlib.pyplot as plt
|
||||
from torchmetrics import StructuralSimilarityIndexMeasure
|
||||
|
||||
from utils import LoadData, get_filenames, postprocess_raw, demosaic
|
||||
from model import UNet, hard_log_loss
|
||||
|
||||
class CFG:
    """Training configuration read as class attributes by ``train()``."""

    encoder = (3, 64, 128, 256)  # passed to UNet as enc_chs
    decoder = (256, 128, 64)  # passed to UNet as dec_chs
    out_ch = 4  # passed to UNet as out_ch
    out_sz = (512, 512)  # passed to UNet as out_sz
    lr = 1e-4  # Adam learning rate
    lr_decay = 1e-6  # NOTE: passed to Adam as weight_decay, despite the name
    epochs = 40  # number of epochs; also T_max of the cosine LR schedule
    loss = nn.MSELoss()  # main criterion of the combined training loss
    name = 'unet-rev-isp.pt'  # file name of the final checkpoint
    out_dir = '../'  # directory for plots and checkpoints
    save_freq = 5  # save an intermediate checkpoint every save_freq epochs
|
||||
|
||||
def _save_epoch_plot(epoch, train_losses, rgb, raw, recon):
    """Render, save and show the per-epoch comparison figure, then close it.

    Args:
        epoch: zero-based epoch index (used in the output file name).
        train_losses: mean training loss of every epoch so far.
        rgb, raw, recon: one channel-first sample each (input RGB,
            ground-truth RAW, reconstructed RAW).
    """
    display.clear_output()
    fig = plt.figure(figsize=(20, 7))

    ax1 = plt.subplot(2, 2, 1)
    reconst_raw = postprocess_raw(demosaic(recon.detach().cpu().permute(1, 2, 0).numpy()))
    gt_raw = postprocess_raw(demosaic(raw.detach().cpu().permute(1, 2, 0).numpy()))
    cmp_raw_gt = np.concatenate([gt_raw, reconst_raw], axis=1)
    ax1.imshow(cmp_raw_gt)
    ax1.set_title('(left) GT / (right) Reconst. RAW')

    ax2 = plt.subplot(2, 2, 3)
    ax2.imshow(rgb.detach().cpu().permute(1, 2, 0).numpy())
    ax2.set_title('RGB')

    ax3 = plt.subplot(1, 2, 2)
    ax3.plot(train_losses, label='train')
    ax3.set_xlabel('Epochs', fontsize=18)
    ax3.set_ylabel('Loss', fontsize=18)
    ax3.grid()

    save_path = os.path.join(CFG.out_dir, f'epoch_{epoch + 1}_plot.png')
    plt.savefig(save_path)
    plt.show()
    # Without an explicit close, pyplot keeps one live figure per epoch.
    plt.close(fig)


def train():
    """Train the UNet on the RGB/RAW pairs and write checkpoints and plots.

    Each step minimises ``mse + 0.05 * (1 - ssim) + 0.1 * hard_log``. After
    every epoch a comparison plot is saved to ``CFG.out_dir``; every
    ``CFG.save_freq`` epochs an intermediate checkpoint is written, and the
    final weights are stored as ``CFG.name``.

    Raises:
        RuntimeError: if the training loader yields no batches.
    """
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    PATH = '/workspace/yx/rgb2RAW'  # Update this path
    BATCH_TRAIN = 2
    DEBUG = False

    # Get data
    train_raws, train_rgbs, _ = get_filenames(PATH)
    train_dataset = LoadData(root=PATH, rgb_files=train_rgbs, raw_files=train_raws, debug=DEBUG, test=False)
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_TRAIN, shuffle=True, num_workers=0,
                              pin_memory=True, drop_last=True)
    # Fail early: with no batches the plotting code below has nothing to show.
    if len(train_loader) == 0:
        raise RuntimeError(f'No training batches available under {PATH}')

    # Plots and checkpoints are written here; make sure it exists.
    os.makedirs(CFG.out_dir, exist_ok=True)

    # Initialize model and training components
    model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder, out_ch=CFG.out_ch, out_sz=CFG.out_sz).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.lr_decay)
    scheduler = CosineAnnealingLR(opt, T_max=CFG.epochs)
    criterion = CFG.loss
    ssim_loss = StructuralSimilarityIndexMeasure().to(device)
    hard_log_loss_fn = hard_log_loss().to(device)
    metrics = defaultdict(list)

    # Training loop
    for epoch in range(CFG.epochs):
        torch.cuda.empty_cache()
        start_time = time.time()
        train_loss = []
        model.train()

        for rgb_batch, raw_batch in tqdm(train_loader):
            opt.zero_grad()
            rgb_batch = rgb_batch.to(device)
            raw_batch = raw_batch.to(device)

            recon_raw = model(rgb_batch)

            mse_loss = criterion(raw_batch, recon_raw)
            ssim = ssim_loss(recon_raw, raw_batch)
            ssim_loss_value = 1 - ssim
            hard_log = hard_log_loss_fn(recon_raw, raw_batch)
            combined_loss = mse_loss + 0.05*ssim_loss_value + 0.1*hard_log

            combined_loss.backward()
            opt.step()
            train_loss.append(combined_loss.item())

        metrics['train_loss'].append(np.mean(train_loss))
        scheduler.step()

        # Visualization: last sample of the last batch of this epoch.
        _save_epoch_plot(epoch, metrics['train_loss'],
                         rgb_batch[-1], raw_batch[-1], recon_raw[-1])

        print(f"Epoch {epoch + 1} of {CFG.epochs} took {time.time() - start_time:.3f}s\n")

        if (epoch + 1) % CFG.save_freq == 0:
            torch.save(model.state_dict(), os.path.join(CFG.out_dir, f'{epoch + 1}.pt'))

    torch.save(model.state_dict(), os.path.join(CFG.out_dir, CFG.name))
|
||||
|
||||
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()
|
||||
142
utils.py
Normal file
142
utils.py
Normal file
@ -0,0 +1,142 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
from glob import glob
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms.functional as FF
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import random
|
||||
|
||||
def load_img(filename, debug=False, norm=True, resize=None):
    """Read an image file with OpenCV and return it as an RGB array.

    Args:
        filename: path of the image to read.
        debug: when true, print shape, dtype, min and max of the loaded image.
        norm: when true, divide by 255 and cast to float32.
        resize: optional pair forwarded to ``cv2.resize`` as
            ``(resize[0], resize[1])`` using INTER_AREA interpolation.

    Returns:
        The image array in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``filename``.
    """
    img = cv2.imread(filename)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f'Could not read image: {filename}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    if debug:
        print(img.shape, img.dtype, img.min(), img.max())

    if resize:
        img = cv2.resize(img, (resize[0], resize[1]), interpolation=cv2.INTER_AREA)

    return img
|
||||
|
||||
def save_rgb(img, filename):
    """Write an RGB array to ``filename`` with OpenCV.

    Arrays whose maximum is at most 1 are scaled by 255 first; the data is
    then cast to float32 and converted from RGB to BGR before writing.
    """
    scaled = img * 255 if np.max(img) <= 1 else img
    as_float = scaled.astype(np.float32)
    bgr = cv2.cvtColor(as_float, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, bgr)
|
||||
|
||||
def load_raw(raw, max_val=2**12-1):
    """Load the ``.npy`` file at path ``raw`` and return it divided by ``max_val`` as float32."""
    stored = np.load(raw)
    normalized = np.divide(stored, max_val)
    return normalized.astype(np.float32)
|
||||
|
||||
def demosaic(raw):
    """Simple demosaicing to visualize RAW images.

    Takes an (H, W, 4) array whose planes are read as red, green-red,
    green-blue and blue, averages the two green planes and returns a
    3-channel image resized to twice the height and width.
    """
    assert raw.shape[-1] == 4
    height, width = raw.shape[0], raw.shape[1]

    red, green_red, green_blue, blue = (raw[:, :, k] for k in range(4))
    green = (green_red + green_blue) / 2
    rgb = np.stack((red, green, blue), axis=-1)
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(rgb, (width * 2, height * 2))
|
||||
|
||||
def mosaic(rgb):
    """Extracts RGGB Bayer planes from an RGB image.

    Returns an (H/2, W/2, 4) array stacked as red, green-red, green-blue, blue.
    """
    assert rgb.shape[-1] == 3

    # (row offset, column offset, source channel) of each plane in a 2x2 cell.
    layout = ((0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 2))
    planes = [rgb[row::2, col::2, channel] for row, col, channel in layout]
    return np.stack(planes, axis=-1)
|
||||
|
||||
def gamma_compression(image):
    """Converts from linear to gamma space.

    Values are floored at 1e-8 and then raised to the power 1 / 2.2.
    """
    floored = np.maximum(image, 1e-8)
    return floored ** (1.0 / 2.2)
|
||||
|
||||
def tonemap(image):
    """Simple S-curved global tonemap: ``3 * image**2 - 2 * image**3``."""
    squared = image ** 2
    cubed = image ** 3
    return (3 * squared) - (2 * cubed)
|
||||
|
||||
def postprocess_raw(raw):
    """Simple post-processing to visualize demosaic RAW images.

    Applies gamma compression, then the tonemap, and clips the result to [0, 1].
    """
    toned = tonemap(gamma_compression(raw))
    return np.clip(toned, 0, 1)
|
||||
|
||||
def plot_pair(rgb, raw, t1='RGB', t2='RAW', axis='off'):
    """Show ``rgb`` and ``raw`` side by side, titled ``t1`` and ``t2``.

    ``axis`` is forwarded to ``plt.axis`` for both panels.
    """
    plt.figure(figsize=(12, 6), dpi=80)
    panels = ((t1, rgb), (t2, raw))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.axis(axis)
        plt.imshow(image)
    plt.show()
|
||||
|
||||
def PSNR(y_true, y_pred):
    """Return the peak signal-to-noise ratio (dB) between two arrays.

    The peak value is the maximum of ``y_true``. Inputs are converted to
    float64 first so that unsigned-integer images do not wrap around in the
    subtraction or the square.

    Args:
        y_true: reference array (any numeric dtype).
        y_pred: array with a shape broadcastable to ``y_true``.

    Returns:
        ``np.inf`` when the arrays are identical, otherwise
        ``20 * log10(max(y_true) / sqrt(mse))``.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return np.inf

    max_pixel = np.max(y_true)
    return 20 * np.log10(max_pixel / np.sqrt(mse))
|
||||
|
||||
def get_filenames(path):
    """Collect training RAW/RGB files and validation RGB files under ``path``.

    For each device folder in ``{path}/train`` the ``.npy`` and ``.png`` files
    are gathered in sorted order; validation files come from ``{path}/rgbs``.
    Asserts that the numbers of RAW and RGB training files match, prints both
    counts and returns ``(train_raws, train_rgbs, valid_rgbs)``.
    """
    train_raws, train_rgbs = [], []
    for device in ('samsung-s9', 'iphone-x', 'lq-iphone'):
        device_dir = f'{path}/train/{device}'
        train_raws += sorted(glob(f'{device_dir}/*.npy'))
        train_rgbs += sorted(glob(f'{device_dir}/*.png'))

    valid_rgbs = sorted(glob(f'{path}/rgbs/*'))
    assert len(train_raws) == len(train_rgbs)
    print(f'Training samples: {len(train_raws)} \t Validation samples: {len(valid_rgbs)}')
    return train_raws, train_rgbs, valid_rgbs
|
||||
|
||||
class LoadData(Dataset):
    """Dataset of RGB images, optionally paired with RAW ``.npy`` targets.

    In test mode each item is ``(rgb, rgb_path)``; otherwise it is
    ``(rgb, raw)``. Both tensors are channel-first. RGB and RAW files are
    paired by position after sorting each list independently.
    """

    def __init__(self, root, rgb_files, raw_files=None, debug=False, test=None):
        """Store the sorted file lists.

        Args:
            root: dataset root (stored, not used for loading).
            rgb_files: iterable of RGB image paths.
            raw_files: iterable of RAW ``.npy`` paths; required unless ``test``.
            debug: when true, keep only the first 100 samples.
            test: when truthy, no RAW targets are loaded.

        Raises:
            ValueError: if ``raw_files`` is missing in training mode.
        """
        self.root = root
        self.test = test
        self.rgbs = sorted(rgb_files)
        if self.test:
            self.raws = None
        else:
            if raw_files is None:
                raise ValueError('raw_files is required unless test=True')
            self.raws = sorted(raw_files)

        self.debug = debug
        if self.debug:
            self.rgbs = self.rgbs[:100]
            # In test mode there are no RAW files to truncate.
            if self.raws is not None:
                self.raws = self.raws[:100]

    def __len__(self):
        return len(self.rgbs)

    def __getitem__(self, idx):
        rgb = load_img(self.rgbs[idx], norm=True)
        # HWC -> CHW
        rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))

        if self.test:
            return rgb, self.rgbs[idx]

        raw = load_raw(self.raws[idx])
        raw = torch.from_numpy(raw.transpose((2, 0, 1)))
        return rgb, raw
|
||||
Loading…
x
Reference in New Issue
Block a user