sRGB2RAW/train.py
import os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import defaultdict
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt
from torchmetrics import StructuralSimilarityIndexMeasure
from utils import LoadData, get_filenames, postprocess_raw, demosaic
from model import UNet, hard_log_loss


class CFG:
    """Training configuration for the sRGB-to-RAW UNet."""
    encoder = (3, 64, 128, 256)   # encoder channel widths
    decoder = (256, 128, 64)      # decoder channel widths
    out_ch = 4                    # 4-plane packed RAW output
    out_sz = (512, 512)           # output spatial size
    lr = 1e-4
    lr_decay = 1e-6               # passed as Adam weight_decay below
    epochs = 40
    loss = nn.MSELoss()
    name = 'unet-rev-isp.pt'      # filename of the final checkpoint
    out_dir = '../'
    save_freq = 5                 # save a numbered checkpoint every N epochs


def train():
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    PATH = '/workspace/yx/rgb2RAW'  # Update this path
    BATCH_TRAIN = 2
    BATCH_TEST = 1
    DEBUG = False

    # Get data (valid_rgbs is not used in this training script)
    train_raws, train_rgbs, valid_rgbs = get_filenames(PATH)
    train_dataset = LoadData(root=PATH, rgb_files=train_rgbs, raw_files=train_raws,
                             debug=DEBUG, test=False)
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_TRAIN, shuffle=True,
                              num_workers=0, pin_memory=True, drop_last=True)
    # Initialize model and training components
    model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder, out_ch=CFG.out_ch,
                 out_sz=CFG.out_sz).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.lr_decay)
    scheduler = CosineAnnealingLR(opt, T_max=CFG.epochs)
    criterion = CFG.loss
    ssim_loss = StructuralSimilarityIndexMeasure().to(device)
    hard_log_loss_fn = hard_log_loss().to(device)
    metrics = defaultdict(list)
    # Training loop
    for epoch in range(CFG.epochs):
        torch.cuda.empty_cache()
        start_time = time.time()
        train_loss = []
        model.train()

        for rgb_batch, raw_batch in tqdm(train_loader):
            opt.zero_grad()
            rgb_batch = rgb_batch.to(device)
            raw_batch = raw_batch.to(device)

            recon_raw = model(rgb_batch)

            # Combined objective: MSE plus weighted SSIM and hard-log terms
            mse_loss = criterion(recon_raw, raw_batch)
            ssim_loss_value = 1 - ssim_loss(recon_raw, raw_batch)
            hard_log = hard_log_loss_fn(recon_raw, raw_batch)
            combined_loss = mse_loss + 0.05 * ssim_loss_value + 0.1 * hard_log

            combined_loss.backward()
            opt.step()
            train_loss.append(combined_loss.item())

        metrics['train_loss'].append(np.mean(train_loss))
        scheduler.step()
        # Visualization (last sample of the final batch in this epoch)
        display.clear_output()
        plt.figure(figsize=(20, 7))

        ax1 = plt.subplot(2, 2, 1)
        reconst_raw = postprocess_raw(demosaic(recon_raw[-1].detach().cpu().permute(1, 2, 0).numpy()))
        gt_raw = postprocess_raw(demosaic(raw_batch[-1].detach().cpu().permute(1, 2, 0).numpy()))
        cmp_raw_gt = np.concatenate([gt_raw, reconst_raw], axis=1)
        ax1.imshow(cmp_raw_gt)
        ax1.set_title('(left) GT / (right) Reconst. RAW')

        ax2 = plt.subplot(2, 2, 3)
        ax2.imshow(rgb_batch[-1].detach().cpu().permute(1, 2, 0).numpy())
        ax2.set_title('RGB')

        ax3 = plt.subplot(1, 2, 2)
        ax3.plot(metrics['train_loss'], label='train')
        ax3.set_xlabel('Epochs', fontsize=18)
        ax3.set_ylabel('Loss', fontsize=18)
        ax3.legend()
        ax3.grid()

        save_path = os.path.join(CFG.out_dir, f'epoch_{epoch + 1}_plot.png')
        plt.savefig(save_path)
        plt.show()

        print(f"Epoch {epoch + 1} of {CFG.epochs} took {time.time() - start_time:.3f}s\n")
        # Periodic numbered checkpoint plus an always-current copy of the latest weights
        if (epoch + 1) % CFG.save_freq == 0:
            torch.save(model.state_dict(), os.path.join(CFG.out_dir, f'{epoch + 1}.pt'))
        torch.save(model.state_dict(), os.path.join(CFG.out_dir, CFG.name))
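

# ---------------------------------------------------------------------------
# Hedged example (not part of the original script): a minimal single-image
# inference sketch for a checkpoint written by train(). It only reuses calls
# that already appear above (UNet, demosaic, postprocess_raw); the expected
# input is an sRGB float tensor in [0, 1] with shape (3, H, W), which the
# caller must prepare.
# ---------------------------------------------------------------------------
@torch.no_grad()
def predict_raw(checkpoint_path, rgb_tensor, device='cpu'):
    """Reconstruct a 4-channel packed RAW image from a single sRGB tensor."""
    model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder,
                 out_ch=CFG.out_ch, out_sz=CFG.out_sz).to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()
    recon = model(rgb_tensor.unsqueeze(0).to(device))   # (1, 4, H', W')
    raw_np = recon[0].cpu().permute(1, 2, 0).numpy()    # (H', W', 4)
    # Same post-processing chain as the training-time visualization above.
    return postprocess_raw(demosaic(raw_np))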


if __name__ == "__main__":
    train()
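
# Usage: run this file directly (after updating PATH and CFG.out_dir for your
# environment), e.g. `python train.py` from inside the sRGB2RAW/ directory.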