|
|
import os.path
|
|
|
import logging
|
|
|
import torch
|
|
|
import argparse
|
|
|
import json
|
|
|
import glob
|
|
|
|
|
|
from pprint import pprint
|
|
|
from fvcore.nn import FlopCountAnalysis
|
|
|
from utils.model_summary import get_model_activation, get_model_flops
|
|
|
from utils import utils_logger
|
|
|
from utils import utils_image as util
|
|
|
|
|
|
|
|
|
def select_model(args, device):
    """Instantiate the model selected by ``args.model_id`` and prepare it for inference.

    Args:
        args: parsed CLI namespace; only ``args.model_id`` is read here.
        device: ``torch.device`` the weights are loaded onto and the model moved to.

    Returns:
        Tuple ``(model, name, data_range, tile)`` where ``name`` is the logging
        identifier, ``data_range`` the pixel-range scale expected by the image
        converters, and ``tile`` the tile size for ``forward`` (None = whole image).

    Raises:
        NotImplementedError: if ``args.model_id`` is not a known model id.
    """
    model_id = args.model_id
    if model_id == 0:
        # Baseline entry: EFDN (team 00).
        from models.team00_EFDN import EFDN
        name, data_range = f"{model_id:02}_EFDN_baseline", 1.0
        model_path = os.path.join('model_zoo', 'team00_EFDN.pth')
        model = EFDN()
        # map_location keeps checkpoint loading working on CPU-only machines.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    elif model_id == 23:
        from models.team23_DSCF import DSCF
        name, data_range = f"{model_id:02}_DSCF", 1.0
        model_path = os.path.join('model_zoo', 'team23_DSCF.pth')
        model = DSCF(3, 3, feature_channels=26, upscale=4)
        state_dict = torch.load(model_path, map_location=device)
        # NOTE(review): strict=False silently drops missing/unexpected keys —
        # presumably intentional for this team's checkpoint; confirm with the team.
        model.load_state_dict(state_dict, strict=False)
    else:
        raise NotImplementedError(f"Model {model_id} is not implemented.")

    model.eval()
    tile = None  # no tiled inference by default
    # Freeze all parameters: this script only runs inference.
    for _, v in model.named_parameters():
        v.requires_grad = False
    model = model.to(device)
    return model, name, data_range, tile
|
|
|
|
|
|
|
|
|
def select_dataset(data_dir, mode):
    """Collect (LR, HR) image-path pairs for the requested dataset split.

    Args:
        data_dir: root directory containing the ``DIV2K_LSDIR_*`` folders.
        mode: dataset split, ``"test"`` or ``"valid"``.

    Returns:
        List of ``(lr_path, hr_path)`` tuples sorted by HR filename. Each LR
        path is derived from its HR path by ``"_HR" -> "_LR"`` and
        ``".png" -> "x4.png"`` (x4 downscaled inputs).

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    # Single source of truth for the HR folder; the pair construction below
    # was previously duplicated per branch.
    if mode == "test":
        hr_folder = "DIV2K_LSDIR_test_HR"
    elif mode == "valid":
        hr_folder = "DIV2K_LSDIR_valid_HR"
    else:
        raise NotImplementedError(f"{mode} is not implemented in select_dataset")

    return [
        (p.replace("_HR", "_LR").replace(".png", "x4.png"), p)
        for p in sorted(glob.glob(os.path.join(data_dir, f"{hr_folder}/*.png")))
    ]
|
|
|
|
|
|
|
|
|
def forward(img_lq, model, tile=None, tile_overlap=32, scale=4):
    """Run super-resolution inference, optionally tiled with overlap averaging.

    Args:
        img_lq: low-quality input tensor of shape ``(b, c, h, w)``.
        model: callable mapping an LR tensor to an SR tensor upscaled by ``scale``.
        tile: tile side length in LR pixels; ``None`` processes the full image.
        tile_overlap: overlap between adjacent tiles, in LR pixels.
        scale: super-resolution factor (output spatial dims are ``scale`` times larger).

    Returns:
        SR tensor of shape ``(b, c, h*scale, w*scale)``.
    """
    if tile is None:
        # Whole-image inference: one forward pass.
        output = model(img_lq)
    else:
        # Tiled inference: run the model on overlapping patches and average
        # each output pixel by how many patches covered it.
        b, c, h, w = img_lq.size()
        tile = min(tile, h, w)  # a tile can never exceed the image itself
        sf = scale

        stride = tile - tile_overlap
        # Patch origins; the trailing entry guarantees the bottom/right edge
        # is always covered even when (h - tile) is not a multiple of stride.
        h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
        w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
        E = torch.zeros(b, c, h*sf, w*sf).type_as(img_lq)  # accumulated outputs
        W = torch.zeros_like(E)  # per-pixel coverage count

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
                W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
        # Divide accumulated outputs by coverage to average overlap regions.
        output = E.div_(W)

    return output
|
|
|
|
|
|
def run(model, model_name, data_range, tile, logger, device, args, mode="test"):
    """Evaluate ``model`` on one dataset split and collect quality/runtime metrics.

    Args:
        model: super-resolution network, already on ``device`` and in eval mode.
        model_name: identifier used for the save directory and logging.
        data_range: pixel-range scale passed to the uint<->tensor converters.
        tile: tile size forwarded to ``forward`` (None = full-image inference).
        logger: logger receiving per-image and summary messages.
        device: ``torch.device`` the LR inputs are moved to.
        args: CLI namespace; reads ``args.data_dir``, ``args.save_dir``, ``args.ssim``.
        mode: dataset split, "test" or "valid"; prefixes every result key.

    Returns:
        dict of per-image lists and averaged metrics keyed ``f"{mode}_..."``.
    """
    sf = 4  # super-resolution scale factor
    border = sf  # border pixels excluded from PSNR/SSIM
    results = dict()
    results[f"{mode}_runtime"] = []
    results[f"{mode}_psnr"] = []
    if args.ssim:
        results[f"{mode}_ssim"] = []

    data_path = select_dataset(args.data_dir, mode)
    save_path = os.path.join(args.save_dir, model_name, mode)
    util.mkdir(save_path)

    # CUDA events measure GPU-side elapsed time in milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i, (img_lr, img_hr) in enumerate(data_path):
        img_name, ext = os.path.splitext(os.path.basename(img_hr))
        # Load the LR image and convert it to a 4-D tensor on the target device.
        img_lr = util.imread_uint(img_lr, n_channels=3)
        img_lr = util.uint2tensor4(img_lr, data_range)
        img_lr = img_lr.to(device)

        # Time only the forward pass; synchronize before reading the timer.
        start.record()
        img_sr = forward(img_lr, model, tile)
        end.record()
        torch.cuda.synchronize()
        results[f"{mode}_runtime"].append(start.elapsed_time(end))
        img_sr = util.tensor2uint(img_sr, data_range)

        # Ground truth: crop so its size is an exact multiple of the scale.
        img_hr = util.imread_uint(img_hr, n_channels=3)
        img_hr = img_hr.squeeze()
        img_hr = util.modcrop(img_hr, sf)

        psnr = util.calculate_psnr(img_sr, img_hr, border=border)
        results[f"{mode}_psnr"].append(psnr)

        if args.ssim:
            # SSIM is optional because it is comparatively expensive.
            ssim = util.calculate_ssim(img_sr, img_hr, border=border)
            results[f"{mode}_ssim"].append(ssim)
            logger.info("{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.".format(img_name + ext, psnr, ssim))
        else:
            logger.info("{:s} - PSNR: {:.2f} dB".format(img_name + ext, psnr))

    # Peak GPU memory in MB plus per-split averages.
    results[f"{mode}_memory"] = torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024 ** 2
    results[f"{mode}_ave_runtime"] = sum(results[f"{mode}_runtime"]) / len(results[f"{mode}_runtime"])
    results[f"{mode}_ave_psnr"] = sum(results[f"{mode}_psnr"]) / len(results[f"{mode}_psnr"])
    if args.ssim:
        results[f"{mode}_ave_ssim"] = sum(results[f"{mode}_ssim"]) / len(results[f"{mode}_ssim"])

    logger.info("{:>16s} : {:<.3f} [M]".format("Max Memory", results[f"{mode}_memory"]))
    logger.info("------> Average runtime of ({}) is : {:.6f} milliseconds".format("test" if mode == "test" else "valid", results[f"{mode}_ave_runtime"]))
    logger.info("------> Average PSNR of ({}) is : {:.6f} dB".format("test" if mode == "test" else "valid", results[f"{mode}_ave_psnr"]))

    return results
|
|
|
|
|
|
|
|
|
def main(args):
    """Benchmark the selected model and persist results to JSON and a text table.

    Pipeline: set up logging and CUDA, load (or create) ``results.json``,
    evaluate the model on the valid (and optionally test) split, record
    model-complexity statistics (activations, FLOPs, parameters), then render
    a tab-separated summary table over all recorded models.
    """
    utils_logger.logger_info("NTIRE2025-EfficientSR", log_path="NTIRE2025-EfficientSR.log")
    logger = logging.getLogger("NTIRE2025-EfficientSR")

    # Touch the CUDA device to initialize the context, then start clean.
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resume previously collected results so repeated runs accumulate in one file.
    json_dir = os.path.join(os.getcwd(), "results.json")
    if not os.path.exists(json_dir):
        results = dict()
    else:
        with open(json_dir, "r") as f:
            results = json.load(f)

    model, model_name, data_range, tile = select_model(args, device)
    logger.info(model_name)

    # The validation split is always evaluated (removed a dead `if True:` wrapper).
    valid_results = run(model, model_name, data_range, tile, logger, device, args, mode="valid")
    results[model_name] = valid_results

    if args.include_test:
        test_results = run(model, model_name, data_range, tile, logger, device, args, mode="test")
        results[model_name].update(test_results)

    # Model-complexity statistics measured on a fixed 3x256x256 input.
    input_dim = (3, 256, 256)
    activations, num_conv = get_model_activation(model, input_dim)
    activations = activations / 10**6  # -> millions
    logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
    logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

    input_fake = torch.rand(1, 3, 256, 256).to(device)
    flops = FlopCountAnalysis(model, input_fake).total()
    flops = flops / 10**9  # -> GFLOPs
    logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

    num_parameters = sum(p.numel() for p in model.parameters())
    num_parameters = num_parameters / 10**6  # -> millions
    logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
    results[model_name].update({"activations": activations, "num_conv": num_conv,
                                "flops": flops, "num_parameters": num_parameters})

    with open(json_dir, "w") as f:
        json.dump(results, f)

    # Render a summary table covering every model recorded in results.json.
    if args.include_test:
        fmt = "{:20s}\t{:10s}\t{:10s}\t{:14s}\t{:14s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Test PSNR", "Val Time [ms]", "Test Time [ms]", "Ave Time [ms]",
                       "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    else:
        fmt = "{:20s}\t{:10s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Val Time [ms]", "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    for k, v in results.items():
        val_psnr = f"{v['valid_ave_psnr']:2.2f}"
        val_time = f"{v['valid_ave_runtime']:3.2f}"
        mem = f"{v['valid_memory']:2.2f}"
        num_param = f"{v['num_parameters']:2.3f}"
        flops = f"{v['flops']:2.2f}"
        acts = f"{v['activations']:2.2f}"
        conv = f"{v['num_conv']:4d}"
        if args.include_test:
            test_psnr = f"{v['test_ave_psnr']:2.2f}"
            test_time = f"{v['test_ave_runtime']:3.2f}"
            ave_time = f"{(v['valid_ave_runtime'] + v['test_ave_runtime']) / 2:3.2f}"
            s += fmt.format(k, val_psnr, test_psnr, val_time, test_time, ave_time, num_param, flops, acts, mem, conv)
        else:
            s += fmt.format(k, val_psnr, val_time, num_param, flops, acts, mem, conv)
    with open(os.path.join(os.getcwd(), 'results.txt'), "w") as f:
        f.write(s)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point for the NTIRE 2025 Efficient SR benchmark.
    cli = argparse.ArgumentParser("NTIRE2025-EfficientSR")
    cli.add_argument("--data_dir", type=str, default="../")
    cli.add_argument("--save_dir", type=str, default="../results")
    cli.add_argument("--model_id", type=int, default=0)
    cli.add_argument("--include_test", action="store_true", help="Inference on the `DIV2K_LSDIR_test` set")
    cli.add_argument("--ssim", action="store_true", help="Calculate SSIM")

    parsed = cli.parse_args()
    pprint(parsed)  # echo the effective configuration before running

    main(parsed)
|
|
|
|