gulupgulup committed (verified)
Commit d5e4d46 · 1 Parent(s): 4ba51ef

Update README.md

Files changed (1): README.md (+139 −1)
tags:
- textclassification
- distilbert
pipeline_tag: text-classification
---
# Model Card for gulupgulup/distilbert_nli

This is a Natural Language Inference (NLI) model built by fine-tuning DistilBERT-base-uncased on the GPT-3 NLI dataset. The model performs textual entailment classification: given two pieces of text (a premise and a hypothesis), it determines the logical relationship between them.

## Model Details

### Model Description

The model takes two text inputs, a premise (`text_a`) and a hypothesis (`text_b`), and classifies their relationship into one of three categories:

- **Entailment**: The hypothesis logically follows from the premise
- **Neutral**: The hypothesis is neither supported nor contradicted by the premise
- **Contradiction**: The hypothesis contradicts the premise

Use cases:

- Reading comprehension tasks
- Logical reasoning applications
- Question-answering systems
- Text coherence analysis
- Information verification tasks

**Architecture**: DistilBERT-based sequence classification model with 3 output classes, optimized for efficiency while maintaining strong performance on natural language understanding tasks.

This type of model is fundamental for applications requiring an understanding of logical relationships between text passages, such as fact-checking, automated reasoning, and reading comprehension systems.

## How to Get Started with the Model

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the model and tokenizer
model_name = "gulupgulup/distilbert_nli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
```

### Usage Example

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load the model and tokenizer
model_name = "gulupgulup/distilbert_nli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Example premise and hypothesis
premise = "A person is riding a bicycle in the park."
hypothesis = "Someone is exercising outdoors."

# Tokenize the premise-hypothesis pair jointly
inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True, padding=True)

# Make prediction
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(predictions, dim=-1)

# Map the class index to its label
id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
predicted_label = id2label[predicted_class.item()]

print(f"Premise: {premise}")
print(f"Hypothesis: {hypothesis}")
print(f"Predicted relationship: {predicted_label}")
print(f"Confidence scores: {predictions.squeeze().tolist()}")
```
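The post-processing at the end of the usage example (softmax, then argmax, then the `id2label` map) can be illustrated without loading the model at all. A pure-Python sketch on toy logits (the values are hypothetical, standing in for `outputs.logits`):

```python
import math

# Toy logits for (entailment, neutral, contradiction) -- hypothetical values
logits = [2.1, 0.3, -1.2]

# Softmax: exponentiate and normalize so the scores sum to 1
exps = [math.exp(x) for x in logits]
probs = [e / sum(exps) for e in exps]

# Argmax picks the most probable class; id2label names it
id2label = {0: "entailment", 1: "neutral", 2: "contradiction"}
predicted = id2label[max(range(len(probs)), key=probs.__getitem__)]
print(predicted)  # entailment -- the largest logit wins
```

Because softmax is monotonic, the predicted class is always the one with the largest raw logit; the probabilities only matter for the reported confidence scores.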

## Training Details

### Training Data

Dataset: [pietrolesci/gpt3_nli](https://huggingface.co/datasets/pietrolesci/gpt3_nli) - a natural language inference dataset containing premise-hypothesis pairs with three-class labels (entailment, neutral, contradiction). The dataset consists of text pairs (`text_a` and `text_b`) where the model learns to determine the logical relationship between the premise and hypothesis.

### Training Procedure

**Base Model**: DistilBERT-base-uncased fine-tuned for sequence classification with 3 output labels for natural language inference.

**Training Framework**: Hugging Face Transformers Trainer with Weights & Biases (wandb) integration for experiment tracking.

**Data Split**: The original training set was split into train (81%), validation (9%), and test (10%) sets using stratified sampling to maintain label distribution balance across splits.
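The card does not include the split code itself; the idea behind a stratified 81/9/10 split is simply to split each label's examples separately so every split keeps the same class proportions. A minimal pure-Python sketch on toy data (the example rows are hypothetical):

```python
import random

# Toy labeled examples -- hypothetical data, 100 rows per NLI class
random.seed(0)
examples = [(f"pair-{i}", lbl) for i, lbl in enumerate(
    ["entailment", "neutral", "contradiction"] * 100)]

def stratified_split(rows, fractions=(0.81, 0.09, 0.10)):
    """Split rows into train/val/test, preserving per-label proportions."""
    by_label = {}
    for row in rows:
        by_label.setdefault(row[1], []).append(row)
    splits = ([], [], [])
    for rows_for_label in by_label.values():
        random.shuffle(rows_for_label)
        n = len(rows_for_label)
        a = int(fractions[0] * n)          # end of train slice
        b = a + int(fractions[1] * n)      # end of validation slice
        splits[0].extend(rows_for_label[:a])
        splits[1].extend(rows_for_label[a:b])
        splits[2].extend(rows_for_label[b:])
    return splits

train, val, test = stratified_split(examples)
print(len(train), len(val), len(test))  # 243 27 30
```

In practice the same effect is usually obtained with `datasets.Dataset.train_test_split(stratify_by_column=...)` or scikit-learn's `train_test_split(stratify=...)`.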

#### Preprocessing

Text pairs are tokenized using DistilBERT's tokenizer with truncation and padding applied. The label column is cast to ClassLabel format with three categories: entailment, neutral, and contradiction.

**Data Handling**: Uses DataCollatorWithPadding for dynamic padding during training and tokenizes premise-hypothesis pairs jointly.
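Dynamic padding means each batch is padded only to its own longest sequence rather than to a global maximum length, which saves compute on short batches. A minimal pure-Python sketch of the idea (toy token ids, not the actual DataCollatorWithPadding implementation):

```python
# Toy token-id sequences of uneven length -- hypothetical values
batch = [[101, 7592, 102], [101, 7592, 2088, 999, 102], [101, 102]]
pad_id = 0

# Pad to the longest sequence in *this* batch, and build the matching mask
max_len = max(len(seq) for seq in batch)
input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in batch]
attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in batch]

print(input_ids[2])       # [101, 102, 0, 0, 0]
print(attention_mask[2])  # [1, 1, 0, 0, 0]
```

The attention mask tells the model to ignore the pad positions, so padding changes the tensor shape but not the prediction.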

#### Training Hyperparameters

- **Learning Rate**: 1e-5
- **Batch Size**: 64 (both training and evaluation)
- **Number of Epochs**: 5
- **Weight Decay**: 0.01
- **Max Gradient Norm**: 1.0
- **Optimizer**: AdamW (default)
- **Evaluation Strategy**: every epoch
- **Save Strategy**: every epoch
- **Logging Steps**: 100
- **Best Model Selection**: based on validation accuracy (higher is better)
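The hyperparameters above map onto a `TrainingArguments` configuration for the Hugging Face Trainer. This is a sketch reconstructed from the reported values, not the card's actual script; `output_dir` is hypothetical, and on transformers versions before 4.41 the `eval_strategy` argument is named `evaluation_strategy`:

```python
from transformers import TrainingArguments

# Sketch of the reported configuration (output_dir is hypothetical)
args = TrainingArguments(
    output_dir="distilbert_nli",
    learning_rate=1e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=5,
    weight_decay=0.01,
    max_grad_norm=1.0,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="wandb",
)
```

`load_best_model_at_end` together with `metric_for_best_model="accuracy"` implements the "best model selection" behavior: after training, the checkpoint with the highest validation accuracy is restored.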

## Evaluation

### Metrics

**Accuracy**: Primary evaluation metric; the percentage of correctly classified premise-hypothesis pairs across all three NLI categories.

**Precision (macro-averaged)**: Secondary metric; the average of per-class precision over entailment, neutral, and contradiction, giving each class equal weight regardless of support. This highlights performance on each NLI relationship type and is especially informative when class distributions are imbalanced.

Both metrics are computed using the `evaluate` library and rounded to 3 decimal places for reporting.
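As a sketch of what macro-averaging means, here is a pure-Python computation of both metrics on toy predictions (the labels and predictions are hypothetical, not the model's actual outputs):

```python
# Toy gold labels and predictions over the three NLI classes -- hypothetical
labels = ["entailment", "neutral", "contradiction"]
y_true = ["entailment", "entailment", "neutral", "contradiction", "neutral", "contradiction"]
y_pred = ["entailment", "neutral",    "neutral", "contradiction", "neutral", "entailment"]

# Accuracy: fraction of exact matches
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# Per-class precision: of the pairs predicted as cls, how many truly are cls
def precision(cls):
    predicted = [t for t, p in zip(y_true, y_pred) if p == cls]
    return sum(t == cls for t in predicted) / len(predicted) if predicted else 0.0

# Macro average: mean of per-class precisions, equal weight per class
macro_precision = sum(precision(c) for c in labels) / len(labels)
print(round(accuracy, 3), round(macro_precision, 3))  # 0.667 0.722
```

Note that macro precision (0.722 here) can differ substantially from accuracy when the classes are predicted with uneven reliability, which is exactly why it is reported alongside accuracy.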