"""Run RGB-to-RAW inference with a trained UNet and save predictions as .npy files.

Every entry found directly inside ``--folder`` is handed to ``LoadData``. Each
model prediction is scaled to the 12-bit range and written to ``--output`` as an
``np.uint16`` array named after its input file.
"""
import torch
import time
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
import argparse
import os
from glob import glob

from utils import LoadData
from model import UNet

# Largest 12-bit sample value; predictions in [0, 1] are scaled to [0, RAW_MAX].
RAW_MAX = 2**12 - 1


def get_filenames(path):
    """Return every entry directly inside ``path``, sorted by name, and print the count."""
    valid_rgbs = sorted(glob(f'{path}/*'))
    print(f'Validation samples: {len(valid_rgbs)}')
    return valid_rgbs


def parse_args():
    """Parse the command-line options for inference."""
    parser = argparse.ArgumentParser(description='RGB to RAW Inference')
    parser.add_argument('--folder', type=str, default='/workspace/yx/rgb2RAW',
                        help='Input folder containing RGB images')
    parser.add_argument('--output', type=str, default='./submission/',
                        help='Output folder for RAW predictions')
    parser.add_argument('--model', type=str, default='model.pt',
                        help='Path to the model weights')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for inference')
    return parser.parse_args()


def _sync(device):
    """Wait for queued CUDA work on ``device`` so wall-clock timings cover real compute."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _output_stem(name):
    """Return the base file name of ``name`` with a trailing '.png' (only) removed."""
    base = os.path.basename(name)
    return base[:-len('.png')] if base.endswith('.png') else base


def _to_uint16_raw(pred):
    """Convert one (C, H, W) prediction tensor to an (H, W, C) uint16 array in [0, RAW_MAX].

    Raises:
        ValueError: if the prediction does not have exactly 4 channels.
    """
    raw = pred.detach().cpu().permute(1, 2, 0).numpy()
    if raw.shape[-1] != 4:
        raise ValueError(f"Expected 4 output channels, got {raw.shape[-1]}")
    # Clip before the cast: negative floats are not representable in uint16 and
    # values > 1 would exceed the 12-bit range.
    # NOTE(review): assumes the model targets [0, 1] — confirm against training.
    return (np.clip(raw, 0.0, 1.0) * RAW_MAX).astype(np.uint16)


def inference():
    """Load the model, run it over ``--folder`` and save one .npy file per input.

    Prints the parameter count, the average per-batch inference time and the
    output location.
    """
    args = parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(args.output, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model; map to CPU first so GPU-saved checkpoints also load on CPU-only hosts.
    model = UNet(enc_chs=(3, 64, 128, 256), dec_chs=(256, 128, 64), out_ch=4, out_sz=(512, 512))
    model.load_state_dict(torch.load(args.model, map_location="cpu"))
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params_in_k = total_params / 1000
    print(f"Parameter quantity:{total_params_in_k:.2f}K")
    model = model.to(device)

    # Get test data
    valid_rgbs = get_filenames(args.folder)
    test_dataset = LoadData(root=args.folder, rgb_files=valid_rgbs, test=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=0, pin_memory=device.type == "cuda", drop_last=False)

    runtime = []  # seconds per batch (equals per image when batch_size == 1)
    model.eval()
    with torch.no_grad():
        # presumably LoadData yields (image_tensor, path_str), so the default
        # collate gives a sequence of names per batch — TODO confirm.
        for rgb_batch, rgb_names in tqdm(test_loader):
            rgb_batch = rgb_batch.to(device)

            _sync(device)
            st = time.perf_counter()
            recon_raw = model(rgb_batch)
            _sync(device)
            runtime.append(time.perf_counter() - st)

            # Save every sample in the batch, not just the first one.
            for pred, name in zip(recon_raw, rgb_names):
                out_path = os.path.join(args.output, f'{_output_stem(name)}.npy')
                np.save(out_path, _to_uint16_raw(pred))

    if runtime:
        print(f"Average inference time: {np.mean(runtime):.3f}s")
    else:
        print("No samples were processed.")
    print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    inference()