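"""Training script for a reverse-ISP UNet that maps sRGB images back to RAW.

The objective is a weighted sum of MSE, (1 - SSIM), and a hard log loss
(defined in model.py). Intermediate checkpoints are written every
CFG.save_freq epochs, a diagnostic plot is saved each epoch, and the final
weights are stored as CFG.name in CFG.out_dir.
"""
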
import os
import time

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import defaultdict

from torch.utils.data import DataLoader
from tqdm import tqdm
from IPython import display
import matplotlib.pyplot as plt
from torchmetrics import StructuralSimilarityIndexMeasure

from utils import LoadData, get_filenames, postprocess_raw, demosaic
from model import UNet, hard_log_loss


class CFG:
    encoder = (3, 64, 128, 256)  # encoder channel widths
    decoder = (256, 128, 64)     # decoder channel widths
    out_ch = 4                   # 4-channel RAW output (packed Bayer planes)
    out_sz = (512, 512)          # output spatial size
    lr = 1e-4
    lr_decay = 1e-6              # passed to Adam as weight_decay
    epochs = 40
    loss = nn.MSELoss()
    name = 'unet-rev-isp.pt'     # filename for the final checkpoint
    out_dir = '../'
    save_freq = 5                # save a checkpoint every N epochs


def train():
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    PATH = '/workspace/yx/rgb2RAW'  # Update this path
    BATCH_TRAIN = 2
    BATCH_TEST = 1  # unused in this function; kept for an evaluation loop
    DEBUG = False

    # Get data (valid_rgbs is returned by get_filenames but unused in training)
    train_raws, train_rgbs, valid_rgbs = get_filenames(PATH)
    train_dataset = LoadData(root=PATH, rgb_files=train_rgbs, raw_files=train_raws, debug=DEBUG, test=False)
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_TRAIN, shuffle=True, num_workers=0,
                              pin_memory=True, drop_last=True)

    # Initialize model and training components
    model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder, out_ch=CFG.out_ch, out_sz=CFG.out_sz).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.lr_decay)
    scheduler = CosineAnnealingLR(opt, T_max=CFG.epochs)
    criterion = CFG.loss
    ssim_loss = StructuralSimilarityIndexMeasure().to(device)
    hard_log_loss_fn = hard_log_loss().to(device)
    metrics = defaultdict(list)

    # Training loop
    for epoch in range(CFG.epochs):
        torch.cuda.empty_cache()
        start_time = time.time()
        train_loss = []
        model.train()

        for rgb_batch, raw_batch in tqdm(train_loader):
            opt.zero_grad()
            rgb_batch = rgb_batch.to(device)
            raw_batch = raw_batch.to(device)

            recon_raw = model(rgb_batch)

            # Weighted sum of reconstruction terms: MSE + 0.05 * (1 - SSIM) + 0.1 * hard log loss
            mse_loss = criterion(raw_batch, recon_raw)
            ssim = ssim_loss(recon_raw, raw_batch)
            ssim_loss_value = 1 - ssim  # SSIM measures similarity, so invert it into a loss
            hard_log = hard_log_loss_fn(recon_raw, raw_batch)
            combined_loss = mse_loss + 0.05 * ssim_loss_value + 0.1 * hard_log

            combined_loss.backward()
            opt.step()
            train_loss.append(combined_loss.item())

        metrics['train_loss'].append(np.mean(train_loss))
        scheduler.step()

        # Visualization: compare GT vs. reconstructed RAW for the last sample
        # of the final batch (demosaiced and postprocessed for display)
        display.clear_output()
        plt.figure(figsize=(20, 7))
        ax1 = plt.subplot(2, 2, 1)
        reconst_raw = postprocess_raw(demosaic(recon_raw[-1].detach().cpu().permute(1, 2, 0).numpy()))
        gt_raw = postprocess_raw(demosaic(raw_batch[-1].detach().cpu().permute(1, 2, 0).numpy()))
        cmp_raw_gt = np.concatenate([gt_raw, reconst_raw], axis=1)
        ax1.imshow(cmp_raw_gt)
        ax1.set_title('(left) GT / (right) Reconst. RAW')

        ax2 = plt.subplot(2, 2, 3)
        ax2.imshow(rgb_batch[-1].detach().cpu().permute(1, 2, 0).numpy())
        ax2.set_title('RGB')

        ax3 = plt.subplot(1, 2, 2)
        ax3.plot(metrics['train_loss'], label='train')
        ax3.set_xlabel('Epochs', fontsize=18)
        ax3.set_ylabel('Loss', fontsize=18)
        ax3.legend()  # without this, the 'train' label is never shown
        ax3.grid()

        save_path = os.path.join(CFG.out_dir, f'epoch_{epoch + 1}_plot.png')
        plt.savefig(save_path)
        plt.show()
        plt.close()  # release the figure so per-epoch plots don't accumulate in memory

print(f"Epoch {epoch + 1} of {CFG.epochs} took {time.time() - start_time:.3f}s\n")
|
|
|
|
        # Periodic checkpoint
        if (epoch + 1) % CFG.save_freq == 0:
            torch.save(model.state_dict(), os.path.join(CFG.out_dir, f'{epoch + 1}.pt'))

    # Save final weights
    torch.save(model.state_dict(), os.path.join(CFG.out_dir, CFG.name))


if __name__ == "__main__":
    train()
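
# Usage sketch (hypothetical; not part of the training flow above): load the
# final checkpoint and run a single forward pass. `rgb_tensor` is a placeholder
# for a preprocessed (1, 3, H, W) sRGB tensor.
#
#   model = UNet(enc_chs=CFG.encoder, dec_chs=CFG.decoder,
#                out_ch=CFG.out_ch, out_sz=CFG.out_sz)
#   model.load_state_dict(torch.load(os.path.join(CFG.out_dir, CFG.name)))
#   model.eval()
#   with torch.no_grad():
#       raw_pred = model(rgb_tensor)  # (1, 4, 512, 512) RAW prediction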