Spaces:
Paused
Paused
Andrei-Iulian SĂCELEANU commited on
Commit ·
37f6940
1
Parent(s): b922f84
fix error for lp and contr
Browse files- .gitignore +1 -0
- app.py +2 -2
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Pipfile*
|
app.py
CHANGED
|
@@ -56,12 +56,12 @@ def ssl_predict(in_text, model_type):
|
|
| 56 |
|
| 57 |
elif model_type == "contrastive_reg":
|
| 58 |
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
| 59 |
-
model.
|
| 60 |
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
| 61 |
|
| 62 |
elif model_type == "label_propagation":
|
| 63 |
model = LPModel()
|
| 64 |
-
model.
|
| 65 |
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
| 66 |
|
| 67 |
probs = list(preds[0].numpy())
|
|
|
|
| 56 |
|
| 57 |
elif model_type == "contrastive_reg":
|
| 58 |
model = FixMatchTune(encoder_name="readerbench/RoBERT-base")
|
| 59 |
+
model.load_weights("./checkpoints/contrastive")
|
| 60 |
preds, _ = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
| 61 |
|
| 62 |
elif model_type == "label_propagation":
|
| 63 |
model = LPModel()
|
| 64 |
+
model.load_weights("./checkpoints/label_prop")
|
| 65 |
preds = model([toks["input_ids"],toks["attention_mask"]], training=False)
|
| 66 |
|
| 67 |
probs = list(preds[0].numpy())
|