sRGB2RAW/inference.py
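
"""Inference script: convert sRGB images to 4-channel RAW predictions with a
trained UNet and save the results as 12-bit values in uint16 .npy files."""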

import argparse
import os
import time
from glob import glob

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import UNet
from utils import LoadData


def get_filenames(path):
    # Collect and sort all input RGB file paths under `path`.
    valid_rgbs = sorted(glob(f'{path}/*'))
    print(f'Validation samples: {len(valid_rgbs)}')
    return valid_rgbs


def parse_args():
    parser = argparse.ArgumentParser(description='RGB to RAW Inference')
    parser.add_argument('--folder', type=str, default='/workspace/yx/rgb2RAW',
                        help='Input folder containing RGB images')
    parser.add_argument('--output', type=str, default='./submission/',
                        help='Output folder for RAW predictions')
    parser.add_argument('--model', type=str, default='model.pt',
                        help='Path to the model weights')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for inference')
    return parser.parse_args()


def inference():
    # Parse command-line arguments.
    args = parse_args()

    # Create the output directory if it doesn't exist.
    os.makedirs(args.output, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the model and load the trained weights onto the selected device.
    model = UNet(enc_chs=(3, 64, 128, 256),
                 dec_chs=(256, 128, 64),
                 out_ch=4,
                 out_sz=(512, 512))
    model.load_state_dict(torch.load(args.model, map_location=device))

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params_in_k = total_params / 1000
    print(f"Trainable parameters: {total_params_in_k:.2f}K")

    model = model.to(device)

    # Build the test dataset and loader.
    valid_rgbs = get_filenames(args.folder)
    test_dataset = LoadData(root=args.folder, rgb_files=valid_rgbs, test=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=0,
                             pin_memory=True,
                             drop_last=False)

    runtime = []
    model.eval()
    with torch.no_grad():
        for (rgb_batch, rgb_name) in tqdm(test_loader):
            rgb_batch = rgb_batch.to(device)
            rgb_name = os.path.basename(rgb_name[0]).replace('.png', '')

            # Note: CUDA kernels launch asynchronously, so this wall-clock
            # timing is only a rough per-image estimate without
            # torch.cuda.synchronize().
            st = time.time()
            recon_raw = model(rgb_batch)
            tt = time.time() - st
            runtime.append(tt)

            # Convert the CHW tensor to an HWC numpy array and check that the
            # prediction has the expected 4 channels.
            recon_raw = recon_raw[0].detach().cpu().permute(1, 2, 0).numpy()
            assert recon_raw.shape[-1] == 4

            # Scale to the 12-bit range and save as np.uint16.
            recon_raw = (recon_raw * (2**12 - 1)).astype(np.uint16)
            np.save(os.path.join(args.output, f'{rgb_name}.npy'), recon_raw)

    print(f"Average inference time: {np.mean(runtime):.3f}s")
    print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    inference()
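
# Example invocation (paths are the argparse defaults above; adjust for your
# environment):
#   python inference.py --folder /workspace/yx/rgb2RAW --output ./submission/ \
#       --model model.pt --batch_size 1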