import os
import time
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt
from torchmetrics import StructuralSimilarityIndexMeasure
from torchmetrics.functional import structural_similarity_index_measure

from utils import LoadData, get_filenames, postprocess_raw, demosaic
from model import UNet, hard_log_loss


class CFG:
    """Hyper-parameters and output settings for the RGB -> RAW UNet training run."""

    encoder = (3, 64, 128, 256)   # UNet encoder channel widths (first entry presumably the RGB input)
    decoder = (256, 128, 64)      # UNet decoder channel widths
    out_ch = 4                    # number of UNet output channels (the reconstructed RAW)
    out_sz = (512, 512)           # passed to UNet as ``out_sz``
    lr = 1e-4
    # NOTE: despite its name, this value is passed to Adam as ``weight_decay``
    # (L2 regularisation); it does not decay the learning rate.
    lr_decay = 1e-6
    epochs = 40
    loss = nn.MSELoss()
    ssim_weight = 0.05            # weight of the (1 - SSIM) term in the combined loss
    hard_log_weight = 0.1         # weight of the hard_log_loss term in the combined loss
    name = 'unet-rev-isp.pt'      # filename of the final weights inside ``out_dir``
    out_dir = '../'
    save_freq = 5                 # write an extra ``<epoch>.pt`` checkpoint every N epochs


def _to_hwc_numpy(tensor):
    """Return a detached CPU NumPy copy of a 3-D tensor with axis 0 moved last (CHW -> HWC)."""
    return tensor.detach().cpu().permute(1, 2, 0).numpy()


def _save_epoch_plot(epoch, rgb_batch, raw_batch, recon_raw, losses):
    """Plot GT vs. reconstructed RAW, the input RGB and the loss curve, then save and close it.

    Only the last sample of each batch is drawn. The figure is written to
    ``CFG.out_dir/epoch_<epoch + 1>_plot.png``.
    """
    display.clear_output()
    fig = plt.figure(figsize=(20, 7))

    ax1 = plt.subplot(2, 2, 1)
    reconst_raw = postprocess_raw(demosaic(_to_hwc_numpy(recon_raw[-1])))
    gt_raw = postprocess_raw(demosaic(_to_hwc_numpy(raw_batch[-1])))
    ax1.imshow(np.concatenate([gt_raw, reconst_raw], axis=1))
    ax1.set_title('(left) GT / (right) Reconst. RAW')

    ax2 = plt.subplot(2, 2, 3)
    ax2.imshow(_to_hwc_numpy(rgb_batch[-1]))
    ax2.set_title('RGB')

    ax3 = plt.subplot(1, 2, 2)
    ax3.plot(losses, label='train')
    ax3.set_xlabel('Epochs', fontsize=18)
    ax3.set_ylabel('Loss', fontsize=18)
    ax3.grid()

    plt.savefig(os.path.join(CFG.out_dir, f'epoch_{epoch + 1}_plot.png'))
    plt.show()
    # Close explicitly: on backends where show() does not close the figure,
    # one figure per epoch would otherwise stay alive for the whole run.
    plt.close(fig)


def train():
    """Train the UNet to reconstruct RAW images from RGB inputs.

    Each batch minimises
    ``CFG.loss + CFG.ssim_weight * (1 - SSIM) + CFG.hard_log_weight * hard_log_loss``.
    The cosine LR schedule is stepped once per epoch. After every epoch a
    diagnostic plot is saved to ``CFG.out_dir``, and every ``CFG.save_freq``
    epochs a ``<epoch>.pt`` checkpoint is written. The final weights go to
    ``CFG.out_dir/CFG.name``.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    PATH = '/workspace/yx/rgb2RAW'  # Update this path
    BATCH_TRAIN = 2
    DEBUG = False
    os.makedirs(CFG.out_dir, exist_ok=True)

    # Get data (the validation RGB list returned by get_filenames is not used here)
    train_raws, train_rgbs, _valid_rgbs = get_filenames(PATH)
    train_dataset = LoadData(root=PATH, rgb_files=train_rgbs, raw_files=train_raws,
                             debug=DEBUG, test=False)
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_TRAIN, shuffle=True,
                              num_workers=0, pin_memory=device.type == 'cuda', drop_last=True)
    # With drop_last=True a dataset smaller than one batch yields nothing; fail
    # early instead of hitting undefined batch variables at plotting time.
    if len(train_loader) == 0:
        raise ValueError(
            f"Training loader is empty: {len(train_dataset)} samples with batch size "
            f"{BATCH_TRAIN} and drop_last=True."
        )

    # Initialize model and training components
    model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder, out_ch=CFG.out_ch,
                 out_sz=CFG.out_sz).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.lr_decay)
    scheduler = CosineAnnealingLR(opt, T_max=CFG.epochs)
    criterion = CFG.loss
    hard_log_loss_fn = hard_log_loss().to(device)
    metrics = defaultdict(list)

    # Training loop
    for epoch in range(CFG.epochs):
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        start_time = time.time()
        epoch_losses = []
        model.train()

        for rgb_batch, raw_batch in tqdm(train_loader):
            rgb_batch = rgb_batch.to(device)
            raw_batch = raw_batch.to(device)
            opt.zero_grad()

            recon_raw = model(rgb_batch)
            # PyTorch loss convention is (input, target).
            mse_loss = criterion(recon_raw, raw_batch)
            # Stateless functional SSIM: a Metric object used as a loss keeps
            # internal state across calls (some torchmetrics versions buffer
            # every batch's preds/targets), which is never needed here.
            ssim = structural_similarity_index_measure(recon_raw, raw_batch)
            hard_log = hard_log_loss_fn(recon_raw, raw_batch)
            combined_loss = (mse_loss
                             + CFG.ssim_weight * (1 - ssim)
                             + CFG.hard_log_weight * hard_log)

            combined_loss.backward()
            opt.step()
            epoch_losses.append(combined_loss.item())

        metrics['train_loss'].append(np.mean(epoch_losses))
        scheduler.step()

        # Visualization uses the last batch of the epoch.
        _save_epoch_plot(epoch, rgb_batch, raw_batch, recon_raw, metrics['train_loss'])
        print(f"Epoch {epoch + 1} of {CFG.epochs} took {time.time() - start_time:.3f}s\n")

        if (epoch + 1) % CFG.save_freq == 0:
            torch.save(model.state_dict(), os.path.join(CFG.out_dir, f'{epoch + 1}.pt'))

    torch.save(model.state_dict(), os.path.join(CFG.out_dir, CFG.name))


if __name__ == "__main__":
    train()