82 lines
2.9 KiB
Python
82 lines
2.9 KiB
Python
import torch
|
||
import time
|
||
from tqdm import tqdm
|
||
import numpy as np
|
||
from torch.utils.data import DataLoader
|
||
import argparse
|
||
import os
|
||
from glob import glob
|
||
from utils import LoadData
|
||
from model import UNet
|
||
|
||
def get_filenames(path):
    """Return every entry directly under ``path``, sorted lexicographically.

    The number of entries found is printed as a quick sanity check.

    Args:
        path: Directory to list. It is joined with ``/*`` verbatim, so the
            returned strings keep whatever prefix ``path`` was given.

    Returns:
        A sorted list of path strings as produced by ``glob``.
    """
    pattern = f'{path}/*'
    filenames = sorted(glob(pattern))
    count = len(filenames)
    print(f'Validation samples: {count}')
    return filenames
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options for RGB->RAW inference.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, exactly as before; passing a
            list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with ``folder``, ``output``, ``model`` and
        ``batch_size`` attributes.

    Raises:
        SystemExit: On invalid arguments, including ``--batch_size`` < 1
            (reported through ``parser.error`` like any other CLI error).
    """
    parser = argparse.ArgumentParser(description='RGB to RAW Inference')
    parser.add_argument('--folder', type=str, default='/workspace/yx/rgb2RAW',
                        help='Input folder containing RGB images')
    parser.add_argument('--output', type=str, default='./submission/',
                        help='Output folder for RAW predictions')
    parser.add_argument('--model', type=str, default='model.pt',
                        help='Path to the model weights')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for inference')
    args = parser.parse_args(argv)
    # Fail fast with a normal CLI error instead of a DataLoader traceback later.
    if args.batch_size < 1:
        parser.error(f'--batch_size must be >= 1, got {args.batch_size}')
    return args
|
||
|
||
def _load_model(weights_path, device):
    """Build the RGB->RAW UNet, load its weights, report its size, move it to ``device``."""
    model = UNet(enc_chs=(3, 64, 128, 256),
                 dec_chs=(256, 128, 64),
                 out_ch=4,
                 out_sz=(512, 512))
    # Load onto the CPU first so a checkpoint saved on a GPU also works on a
    # CPU-only host; the model is moved to `device` afterwards.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params_in_k = total_params / 1000
    print(f"Parameter quantity:{total_params_in_k:.2f}K")
    return model.to(device)


def _synchronize(device):
    """Wait for queued CUDA work to finish so wall-clock timings cover real compute."""
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


def _output_stem(name):
    """Return the base file name of ``name`` with a trailing '.png' removed."""
    base = os.path.basename(name)
    if base.endswith('.png'):
        base = base[:-len('.png')]
    return base


def _raw_to_uint16(raw):
    """Quantize one (H, W, 4) float RAW prediction to 12-bit values stored as uint16.

    The prediction is presumably normalized to [0, 1] (the scale factor
    2**12 - 1 implies this). It is clipped to that range first: casting a
    negative or >1 float straight to uint16 wraps around instead of
    saturating.

    Raises:
        ValueError: If the last axis does not hold exactly 4 channels.
    """
    if raw.shape[-1] != 4:
        raise ValueError(f"Expected 4 RAW channels in the last axis, got shape {raw.shape}")
    return (np.clip(raw, 0.0, 1.0) * (2**12 - 1)).astype(np.uint16)


def inference():
    """Run the RGB->RAW model on every file in ``--folder`` and save .npy predictions.

    Each prediction is written to ``<--output>/<image name without .png>.npy``
    as a (H, W, 4) uint16 array holding 12-bit values. The average forward-pass
    time per batch is printed at the end; with the default ``--batch_size 1``
    this is the time per image.
    """
    args = parse_args()
    os.makedirs(args.output, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(args.model, device)

    valid_rgbs = get_filenames(args.folder)
    test_dataset = LoadData(root=args.folder, rgb_files=valid_rgbs, test=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=0,
                             # Pinned host memory only speeds up host->GPU copies.
                             pin_memory=device.type == 'cuda',
                             drop_last=False)

    runtime = []
    model.eval()
    with torch.no_grad():
        # rgb_names is assumed to be the collated sequence of per-sample file
        # names returned by LoadData in test mode — TODO confirm against utils.
        for rgb_batch, rgb_names in tqdm(test_loader):
            rgb_batch = rgb_batch.to(device)

            # CUDA kernels launch asynchronously: synchronize on both sides so
            # the measurement covers the forward pass, not just its launch.
            _synchronize(device)
            st = time.perf_counter()
            recon_batch = model(rgb_batch)
            _synchronize(device)
            runtime.append(time.perf_counter() - st)

            # (N, C, H, W) -> (N, H, W, C); save every sample in the batch,
            # not only the first one.
            recon_batch = recon_batch.detach().cpu().permute(0, 2, 3, 1).numpy()
            for recon_raw, rgb_name in zip(recon_batch, rgb_names):
                out_path = os.path.join(args.output, f'{_output_stem(rgb_name)}.npy')
                np.save(out_path, _raw_to_uint16(recon_raw))

    # np.mean([]) would print nan with a RuntimeWarning when no input was found.
    if runtime:
        print(f"Average inference time: {np.mean(runtime):.3f}s")
    print(f"Results saved to: {args.output}")
|
||
|
||
if __name__ == "__main__":
    # Only run inference when executed as a script, not when imported.
    inference()