Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import numpy as np | |
| import seaborn as sns | |
| from matplotlib import pyplot as plt | |
| import umap | |
def dim_reduction(target_embeddings, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """
    Project embeddings into a lower-dimensional space with UMAP.

    Args:
        target_embeddings: Input vectors, passed unchanged to
            ``UMAP.fit_transform``.
        umap_dim: Number of output components (UMAP ``n_components``).
        n_neighbors: UMAP ``n_neighbors`` setting.
        min_dist: UMAP ``min_dist`` setting.

    Returns:
        Whatever ``UMAP.fit_transform`` returns for ``target_embeddings``.
    """
    # Cosine metric and the seed (500) are fixed here; only the three
    # tunables above are exposed to callers.
    umap_settings = dict(
        n_neighbors=n_neighbors,
        n_components=umap_dim,
        min_dist=min_dist,
        metric='cosine',
        random_state=500,
    )
    return umap.UMAP(**umap_settings).fit_transform(target_embeddings)
def clustering_plot(target_label, embeddings, label_trues, model_preds=None, umap_dim=2, n_neighbors=15, min_dist=0.1):
    """
    Plot a UMAP projection of the samples whose true label is ``target_label``.

    The samples with ``label_trues == target_label`` are selected, reduced
    with ``dim_reduction`` and drawn as a scatter plot of the first two
    components coloured by true label. When ``model_preds`` is given, a
    second scatter plot coloured by predicted label is shown afterwards.

    Args:
        target_label: Integer label id selecting which samples to plot.
        embeddings: Per-sample embedding vectors; must support indexing with
            an integer index array (e.g. a 2-D numpy array).
        label_trues: Per-sample true label ids (any array-like).
        model_preds: Optional per-sample predicted label ids (array-like),
            aligned with ``label_trues``.
        umap_dim: Number of UMAP components (must be >= 2). Only the first
            two are plotted; the rest are kept in the returned frame.
        n_neighbors: Forwarded to ``dim_reduction``.
        min_dist: Forwarded to ``dim_reduction``.

    Returns:
        pandas.DataFrame with columns ``x``, ``y`` (plus ``dim3``, ``dim4``,
        ... when ``umap_dim > 2``), ``true`` and, if predictions were given,
        ``pred``. Label ids are mapped to section names; ids that are not in
        the mapping become NaN.

    Raises:
        ValueError: If ``umap_dim < 2`` or no sample has ``target_label``.
    """
    # Fail before the (expensive) UMAP fit: the scatter plots need x and y.
    if umap_dim < 2:
        raise ValueError(
            "umap_dim must be at least 2 to draw a scatter plot, got {!r}".format(umap_dim)
        )

    label_dict = {0:'Abstract', 1:'Introduction', 2:'Main', 3:'Methods', 4:'Summary', 5:'Captions'}

    # Coerce to an array so plain lists compare element-wise instead of
    # collapsing to a single False.
    label_trues = np.asarray(label_trues)
    target_index = np.where(label_trues == target_label)[0]
    if target_index.size == 0:
        raise ValueError(
            "no samples found with true label {!r}".format(target_label)
        )

    reduced = dim_reduction(
        embeddings[target_index],
        umap_dim=umap_dim,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
    )

    # Name the first two components x/y (used for plotting); any additional
    # components get dim3, dim4, ... so umap_dim > 2 no longer breaks the frame.
    columns = ['x', 'y'] + ['dim{}'.format(i) for i in range(3, umap_dim + 1)]
    df = pd.DataFrame(reduced, columns=columns)
    df['true'] = label_trues[target_index]
    df['true'] = df['true'].map(label_dict)

    hue_columns = ['true']
    if model_preds is not None:
        df['pred'] = np.asarray(model_preds)[target_index]
        df['pred'] = df['pred'].map(label_dict)
        hue_columns.append('pred')

    # One figure per colouring: true labels first, then predictions.
    for hue in hue_columns:
        sns.scatterplot(x='x', y='y', hue=hue, data=df, palette='Set2')
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
        plt.show()

    return df