import os.path
import logging
import torch
import argparse
import json
import glob

from pprint import pprint
from fvcore.nn import FlopCountAnalysis

from utils.model_summary import get_model_activation, get_model_flops
from utils import utils_logger
from utils import utils_image as util


def select_model(args, device):
    # Model ID is assigned according to the order of the submissions.
    # Different networks are trained with an input range of either [0, 1] or [0, 255]. The range is determined manually.
    model_id = args.model_id
    if model_id == 0:
        # Baseline: 1st place of the `Overall Performance` track of the NTIRE 2023 Efficient SR Challenge
        # Edge-enhanced Feature Distillation Network for Efficient Super-Resolution
        # arXiv: https://arxiv.org/pdf/2204.08759
        # Original Code: https://github.com/icandle/EFDN
        # Ckpts: EFDN_gv.pth
        from models.team00_EFDN import EFDN
        name, data_range = f"{model_id:02}_EFDN_baseline", 1.0
        model_path = os.path.join('model_zoo', 'team00_EFDN.pth')
        model = EFDN()
        model.load_state_dict(torch.load(model_path), strict=True)
    elif model_id == 23:
        from models.team23_DSCF import DSCF
        name, data_range = f"{model_id:02}_DSCF", 1.0
        model_path = os.path.join('model_zoo', 'team23_DSCF.pth')
        model = DSCF(3, 3, feature_channels=26, upscale=4)
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict, strict=False)
    else:
        raise NotImplementedError(f"Model {model_id} is not implemented.")

    # print(model)
    model.eval()
    tile = None
    for k, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    return model, name, data_range, tile


def select_dataset(data_dir, mode):
    # inference on the DIV2K_LSDIR_test set
    if mode == "test":
        path = [
            (
                p.replace("_HR", "_LR").replace(".png", "x4.png"),
                p
            ) for p in sorted(glob.glob(os.path.join(data_dir, "DIV2K_LSDIR_test_HR/*.png")))
        ]
    # inference on the DIV2K_LSDIR_valid set
    elif mode == "valid":
        path = [
            (
                p.replace("_HR", "_LR").replace(".png", "x4.png"),
                p
            ) for p in sorted(glob.glob(os.path.join(data_dir, "DIV2K_LSDIR_valid_HR/*.png")))
        ]
    else:
        raise NotImplementedError(f"{mode} is not implemented in select_dataset")
    return path


def forward(img_lq, model, tile=None, tile_overlap=32, scale=4):
    if tile is None:
        # test the image as a whole
        output = model(img_lq)
    else:
        # test the image tile by tile
        b, c, h, w = img_lq.size()
        tile = min(tile, h, w)
        sf = scale

        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        E = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
        W = torch.zeros_like(E)

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[..., h_idx * sf:(h_idx + tile) * sf, w_idx * sf:(w_idx + tile) * sf].add_(out_patch)
                W[..., h_idx * sf:(h_idx + tile) * sf, w_idx * sf:(w_idx + tile) * sf].add_(out_patch_mask)
        output = E.div_(W)

    return output


def run(model, model_name, data_range, tile, logger, device, args, mode="test"):
    sf = 4
    border = sf
    results = dict()
    results[f"{mode}_runtime"] = []
    results[f"{mode}_psnr"] = []
    if args.ssim:
        results[f"{mode}_ssim"] = []
    # results[f"{mode}_psnr_y"] = []
    # results[f"{mode}_ssim_y"] = []

    # --------------------------------
    # dataset path
    # --------------------------------
    data_path = select_dataset(args.data_dir, mode)
    save_path = os.path.join(args.save_dir, model_name, mode)
    util.mkdir(save_path)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
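    # Timing note: start.record()/end.record() enqueue events on the GPU stream;
    # start.elapsed_time(end), read after torch.cuda.synchronize(), is in milliseconds.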
    for i, (img_lr, img_hr) in enumerate(data_path):

        # --------------------------------
        # (1) img_lr
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img_hr))
        img_lr = util.imread_uint(img_lr, n_channels=3)
        img_lr = util.uint2tensor4(img_lr, data_range)
        img_lr = img_lr.to(device)

        # --------------------------------
        # (2) img_sr
        # --------------------------------
        start.record()
        img_sr = forward(img_lr, model, tile)
        end.record()
        torch.cuda.synchronize()
        results[f"{mode}_runtime"].append(start.elapsed_time(end))  # milliseconds
        img_sr = util.tensor2uint(img_sr, data_range)

        # --------------------------------
        # (3) img_hr
        # --------------------------------
        img_hr = util.imread_uint(img_hr, n_channels=3)
        img_hr = img_hr.squeeze()
        img_hr = util.modcrop(img_hr, sf)

        # --------------------------------
        # PSNR and SSIM
        # --------------------------------
        # print(img_sr.shape, img_hr.shape)
        psnr = util.calculate_psnr(img_sr, img_hr, border=border)
        results[f"{mode}_psnr"].append(psnr)

        if args.ssim:
            ssim = util.calculate_ssim(img_sr, img_hr, border=border)
            results[f"{mode}_ssim"].append(ssim)
            logger.info("{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.".format(img_name + ext, psnr, ssim))
        else:
            logger.info("{:s} - PSNR: {:.2f} dB".format(img_name + ext, psnr))

        # if np.ndim(img_hr) == 3:  # RGB image
        #     img_sr_y = util.rgb2ycbcr(img_sr, only_y=True)
        #     img_hr_y = util.rgb2ycbcr(img_hr, only_y=True)
        #     psnr_y = util.calculate_psnr(img_sr_y, img_hr_y, border=border)
        #     ssim_y = util.calculate_ssim(img_sr_y, img_hr_y, border=border)
        #     results[f"{mode}_psnr_y"].append(psnr_y)
        #     results[f"{mode}_ssim_y"].append(ssim_y)

        # print(os.path.join(save_path, img_name + ext))
        # --- Save Restored Images ---
        # util.imsave(img_sr, os.path.join(save_path, img_name + ext))

    results[f"{mode}_memory"] = torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024 ** 2
    results[f"{mode}_ave_runtime"] = sum(results[f"{mode}_runtime"]) / len(results[f"{mode}_runtime"])  # / 1000.0
    results[f"{mode}_ave_psnr"] = sum(results[f"{mode}_psnr"]) / len(results[f"{mode}_psnr"])
    if args.ssim:
        results[f"{mode}_ave_ssim"] = sum(results[f"{mode}_ssim"]) / len(results[f"{mode}_ssim"])
    # results[f"{mode}_ave_psnr_y"] = sum(results[f"{mode}_psnr_y"]) / len(results[f"{mode}_psnr_y"])
    # results[f"{mode}_ave_ssim_y"] = sum(results[f"{mode}_ssim_y"]) / len(results[f"{mode}_ssim_y"])

    logger.info("{:>16s} : {:<.3f} [M]".format("Max Memory", results[f"{mode}_memory"]))
    logger.info("------> Average runtime of ({}) is : {:.6f} milliseconds".format("test" if mode == "test" else "valid", results[f"{mode}_ave_runtime"]))
    logger.info("------> Average PSNR of ({}) is : {:.6f} dB".format("test" if mode == "test" else "valid", results[f"{mode}_ave_psnr"]))

    return results


def main(args):
    utils_logger.logger_info("NTIRE2025-EfficientSR", log_path="NTIRE2025-EfficientSR.log")
    logger = logging.getLogger("NTIRE2025-EfficientSR")

    # --------------------------------
    # basic settings
    # --------------------------------
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    json_dir = os.path.join(os.getcwd(), "results.json")
    if not os.path.exists(json_dir):
        results = dict()
    else:
        with open(json_dir, "r") as f:
            results = json.load(f)

    # --------------------------------
    # load model
    # --------------------------------
    model, model_name, data_range, tile = select_model(args, device)
    logger.info(model_name)

    # if model not in results:
    if True:
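        # `if True:` stands in for the commented-out `if model not in results:` gate above,
        # so results for this model are always recomputed and overwritten in results.json.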
        # --------------------------------
        # restore image
        # --------------------------------
        # inference on the DIV2K_LSDIR_valid set
        valid_results = run(model, model_name, data_range, tile, logger, device, args, mode="valid")
        # record PSNR, runtime
        results[model_name] = valid_results

        # inference conducted by the Organizer on the DIV2K_LSDIR_test set
        if args.include_test:
            test_results = run(model, model_name, data_range, tile, logger, device, args, mode="test")
            results[model_name].update(test_results)

        input_dim = (3, 256, 256)  # set the input dimension
        activations, num_conv = get_model_activation(model, input_dim)
        activations = activations / 10 ** 6
        logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
        logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

        # fvcore is used in NTIRE2025_ESR for FLOPs calculation
        input_fake = torch.rand(1, 3, 256, 256).to(device)
        flops = FlopCountAnalysis(model, input_fake).total()
        flops = flops / 10 ** 9
        logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

        num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
        num_parameters = num_parameters / 10 ** 6
        logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
        results[model_name].update({"activations": activations, "num_conv": num_conv, "flops": flops, "num_parameters": num_parameters})

        with open(json_dir, "w") as f:
            json.dump(results, f)

    if args.include_test:
        fmt = "{:20s}\t{:10s}\t{:10s}\t{:14s}\t{:14s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Test PSNR", "Val Time [ms]", "Test Time [ms]", "Ave Time [ms]",
                       "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    else:
        fmt = "{:20s}\t{:10s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Val Time [ms]", "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")

    for k, v in results.items():
        val_psnr = f"{v['valid_ave_psnr']:2.2f}"
        val_time = f"{v['valid_ave_runtime']:3.2f}"
        mem = f"{v['valid_memory']:2.2f}"
        num_param = f"{v['num_parameters']:2.3f}"
        flops = f"{v['flops']:2.2f}"
        acts = f"{v['activations']:2.2f}"
        conv = f"{v['num_conv']:4d}"
        if args.include_test:
            # from IPython import embed; embed()
            test_psnr = f"{v['test_ave_psnr']:2.2f}"
            test_time = f"{v['test_ave_runtime']:3.2f}"
            ave_time = f"{(v['valid_ave_runtime'] + v['test_ave_runtime']) / 2:3.2f}"
            s += fmt.format(k, val_psnr, test_psnr, val_time, test_time, ave_time, num_param, flops, acts, mem, conv)
        else:
            s += fmt.format(k, val_psnr, val_time, num_param, flops, acts, mem, conv)

    with open(os.path.join(os.getcwd(), 'results.txt'), "w") as f:
        f.write(s)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("NTIRE2025-EfficientSR")
    parser.add_argument("--data_dir", default="../", type=str)
    parser.add_argument("--save_dir", default="../results", type=str)
    parser.add_argument("--model_id", default=0, type=int)
    parser.add_argument("--include_test", action="store_true", help="Inference on the `DIV2K_LSDIR_test` set")
    parser.add_argument("--ssim", action="store_true", help="Calculate SSIM")

    args = parser.parse_args()
    pprint(args)

    main(args)
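
# Example invocations (assuming this script is saved as test_demo.py and the
# DIV2K_LSDIR_valid_HR / DIV2K_LSDIR_valid_LR folders sit under --data_dir,
# matching the folder names expected by select_dataset above):
#
#   python test_demo.py --data_dir ../ --model_id 0 --ssim
#   python test_demo.py --data_dir ../ --model_id 23 --include_test
#
# The first call evaluates the EFDN baseline on the validation split and also
# reports SSIM; the second evaluates the team-23 DSCF entry on both the valid
# and test splits.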