Spaces:
Runtime error
Runtime error
Blair Yang commited on
Commit ·
db9289d
1
Parent(s): a53cf00
debug
Browse files
app.py
CHANGED
|
@@ -24,25 +24,26 @@ def generate_plot(meta_index, topic_index):
|
|
| 24 |
|
| 25 |
data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
|
| 26 |
|
| 27 |
-
topic_data = data[data['sub_topic'] == topic]
|
| 28 |
|
| 29 |
# Compute human and llm accuracy
|
| 30 |
topic_data['human_acc'] = topic_data['no_correct_human'] / topic_data['no_responses_human'].replace(0, np.nan)
|
| 31 |
topic_data['llm_acc'] = topic_data['no_correct_llm'] / topic_data['no_responses_llm'].replace(0, np.nan)
|
| 32 |
|
| 33 |
-
#
|
| 34 |
-
|
| 35 |
-
|
|
|
|
| 36 |
|
| 37 |
# Prepare the plot data
|
| 38 |
plot_data = []
|
| 39 |
|
| 40 |
-
# Define a consistent color scheme
|
| 41 |
colors = ['#FFA07A', '#20B2AA', '#778899'] # Light Salmon, Light Sea Green, Light Slate Gray
|
| 42 |
-
|
| 43 |
-
|
| 44 |
# Add bars with error bars for the averages
|
| 45 |
-
for acc_type, color
|
| 46 |
plot_data.append(go.Bar(
|
| 47 |
x=mean_data['model_name'],
|
| 48 |
y=mean_data[acc_type],
|
|
@@ -52,7 +53,7 @@ def generate_plot(meta_index, topic_index):
|
|
| 52 |
visible=True
|
| 53 |
),
|
| 54 |
name=acc_type.split('_')[0].capitalize(),
|
| 55 |
-
marker=dict(color=color
|
| 56 |
))
|
| 57 |
|
| 58 |
# Layout
|
|
|
|
| 24 |
|
| 25 |
data = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
|
| 26 |
|
| 27 |
+
topic_data = data.loc[data['sub_topic'] == topic].copy()
|
| 28 |
|
| 29 |
# Compute human and llm accuracy
|
| 30 |
topic_data['human_acc'] = topic_data['no_correct_human'] / topic_data['no_responses_human'].replace(0, np.nan)
|
| 31 |
topic_data['llm_acc'] = topic_data['no_correct_llm'] / topic_data['no_responses_llm'].replace(0, np.nan)
|
| 32 |
|
| 33 |
+
# Selecting only numeric columns for aggregation
|
| 34 |
+
numeric_cols = ['no_responses_human', 'no_correct_human', 'no_responses_llm', 'no_correct_llm', 'oracle_acc', 'human_acc', 'llm_acc']
|
| 35 |
+
mean_data = topic_data.groupby('model_name')[numeric_cols].mean().reset_index()
|
| 36 |
+
std_deviation = topic_data.groupby('model_name')[numeric_cols].std().reset_index()
|
| 37 |
|
| 38 |
# Prepare the plot data
|
| 39 |
plot_data = []
|
| 40 |
|
| 41 |
+
# Define a consistent color scheme with different opacities
|
| 42 |
colors = ['#FFA07A', '#20B2AA', '#778899'] # Light Salmon, Light Sea Green, Light Slate Gray
|
| 43 |
+
acc_types = ['oracle_acc', 'human_acc', 'llm_acc']
|
| 44 |
+
|
| 45 |
# Add bars with error bars for the averages
|
| 46 |
+
for acc_type, color in zip(acc_types, colors):
|
| 47 |
plot_data.append(go.Bar(
|
| 48 |
x=mean_data['model_name'],
|
| 49 |
y=mean_data[acc_type],
|
|
|
|
| 53 |
visible=True
|
| 54 |
),
|
| 55 |
name=acc_type.split('_')[0].capitalize(),
|
| 56 |
+
marker=dict(color=color)
|
| 57 |
))
|
| 58 |
|
| 59 |
# Layout
|