| | from scipy.spatial.distance import cdist |
| | from scipy.optimize import linear_sum_assignment |
| | import numpy as np |
| |
|
def _undirected(a, b):
    """Return edge (a, b) as an order-independent key (smaller index first)."""
    return (a, b) if a <= b else (b, a)


def _as_edge_array(edges):
    """Return ``edges`` as an int array of shape (E, 2); empty input yields shape (0, 2)."""
    arr = np.asarray(edges, dtype=int)
    if arr.size == 0:
        return arr.reshape(0, 2)
    if arr.ndim != 2 or arr.shape[1] != 2:
        raise ValueError(f"edges must be index pairs of shape (E, 2), got shape {arr.shape}")
    return arr


def compute_WED(pd_vertices, pd_edges, gt_vertices, gt_edges, cv=1.0, ce=1.0, normalized=True, squared=False):
    """Compute the Wireframe Edit Distance (WED) between a predicted and a ground-truth wireframe.

    Predicted vertices are matched one-to-one to ground-truth vertices with
    ``scipy.optimize.linear_sum_assignment`` on their pairwise distances.
    The distance is the sum of:

    * vertex translation: ``cv`` times the summed matched distances (or, when
      ``squared``, ``cv`` times the square root of the summed squared distances);
    * vertex deletion / insertion: ``cv`` per unmatched predicted / ground-truth vertex;
    * edge deletion: ``ce`` times the length (in predicted coordinates) of every
      predicted edge that does not map onto a ground-truth edge, including
      edges touching an unmatched predicted vertex;
    * edge insertion: ``ce`` times the length (in ground-truth coordinates) of
      every ground-truth edge that no predicted edge maps onto.

    Edges are treated as undirected; duplicate edges are counted once.

    Args:
        pd_vertices: predicted vertex coordinates, shape (P, D).
        pd_edges: predicted edges as pairs of indices into ``pd_vertices``.
        gt_vertices: ground-truth vertex coordinates, shape (G, D).
        gt_edges: ground-truth edges as pairs of indices into ``gt_vertices``.
        cv: cost weight for vertex translation, insertion and deletion.
        ce: cost weight per unit length of inserted / deleted edges.
        normalized: if True, divide the result by the total length of ``gt_edges``.
        squared: if True, match vertices on squared Euclidean distance.

    Returns:
        The (optionally normalized) wireframe edit distance.

    Raises:
        ValueError: if an edge list is not a sequence of index pairs, or if
            ``normalized`` is requested but the ground-truth edges have zero
            total length.
    """
    pd_vertices = np.asarray(pd_vertices, dtype=float)
    gt_vertices = np.asarray(gt_vertices, dtype=float)
    # An empty vertex list arrives as a 1-D array; give it the other side's
    # dimensionality so cdist / linear_sum_assignment accept it.
    if pd_vertices.size == 0 and gt_vertices.ndim == 2:
        pd_vertices = pd_vertices.reshape(0, gt_vertices.shape[1])
    if gt_vertices.size == 0 and pd_vertices.ndim == 2:
        gt_vertices = gt_vertices.reshape(0, pd_vertices.shape[1])
    pd_edges = _as_edge_array(pd_edges)
    gt_edges = _as_edge_array(gt_edges)

    # Optimal one-to-one matching: predicted row_ind[k] <-> ground-truth col_ind[k].
    distances = cdist(pd_vertices, gt_vertices, metric='sqeuclidean' if squared else 'euclidean')
    row_ind, col_ind = linear_sum_assignment(distances)

    matched_cost = distances[row_ind, col_ind].sum()
    # In squared mode the translation cost is the L2 norm of all matched displacements.
    translation_costs = cv * (np.sqrt(matched_cost) if squared else matched_cost)

    # Vertices left out of the matching are deleted (predicted) or inserted (ground truth).
    deletion_costs = cv * (len(pd_vertices) - len(row_ind))
    insertion_costs = cv * (len(gt_vertices) - len(col_ind))

    # Express predicted edges in ground-truth indices. Predicted indices live in
    # row_ind, so the lookup must go predicted -> ground truth, not the reverse.
    pd_to_gt = dict(zip(row_ind.tolist(), col_ind.tolist()))
    gt_edge_set = {_undirected(a, b) for a, b in gt_edges.tolist()}

    matched_gt_edges = set()
    deletion_edge_costs = 0.0
    for a, b in {_undirected(a, b) for a, b in pd_edges.tolist()}:
        image = None
        if a in pd_to_gt and b in pd_to_gt:
            image = _undirected(pd_to_gt[a], pd_to_gt[b])
        if image in gt_edge_set:
            matched_gt_edges.add(image)
        else:
            # No ground-truth counterpart (this includes edges touching a deleted
            # vertex): delete it at the cost of its predicted length.
            deletion_edge_costs += np.linalg.norm(pd_vertices[a] - pd_vertices[b])
    deletion_edge_costs *= ce

    # Ground-truth edges that no predicted edge maps onto must be inserted.
    insertion_edge_costs = ce * sum(
        np.linalg.norm(gt_vertices[a] - gt_vertices[b]) for a, b in gt_edge_set - matched_gt_edges
    )

    wed = translation_costs + deletion_costs + insertion_costs + deletion_edge_costs + insertion_edge_costs

    if normalized:
        gt_lengths = np.linalg.norm(gt_vertices[gt_edges[:, 0]] - gt_vertices[gt_edges[:, 1]], axis=1)
        total_length_of_gt_edges = gt_lengths.sum()
        if total_length_of_gt_edges == 0:
            raise ValueError("cannot normalize WED: ground-truth edges have zero total length")
        wed = wed / total_length_of_gt_edges

    return wed