Create app.py
app.py ADDED
@@ -0,0 +1,79 @@
import streamlit as st
import torch
import matplotlib.pyplot as plt
import numpy as np
from transformers import BertTokenizer, BertModel

# Load the pre-trained BERT model and tokenizer from Hugging Face
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

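# --- Illustrative sketch (not part of the original app) ---------------------------
# Streamlit re-executes this script on every interaction, so the two loaders above
# run again on each rerun. A common pattern is to wrap them in a cached loader such
# as st.cache_resource so the model is created only once per session. The function
# below is only a sketch and is not wired into the rest of the app.
@st.cache_resource
def load_bert_cached():
    """Load and cache the BERT tokenizer and model (with attention outputs enabled)."""
    tok = BertTokenizer.from_pretrained('bert-base-uncased')
    mdl = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
    return tok, mdl
# -----------------------------------------------------------------------------------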
# App title and description (raw string so the LaTeX backslashes are not escaped)
st.title("BERT Attention Map Visualizer")
st.write(r"""
## Introduction
This application visualizes the attention mechanism of the BERT model for a given input sentence.
The attention mechanism allows BERT to focus on different parts of the sentence when encoding each token,
providing insights into how the model understands the context and relationships between words.
This app showcases how a pre-trained BERT model generates attention maps for the input sentence.

### Attention Mechanism
The attention mechanism lets the model focus on the most relevant parts of the input sequence.
It computes a weighted sum of values (V) based on the similarity between queries (Q) and keys (K). The formulation is as follows:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V
$$

where:
- $Q$ (Query): the current token for which attention is being calculated.
- $K$ (Key): the tokens in the input sequence that the query is compared against.
- $V$ (Value): the values used to compute the attention-weighted sum.
- $d_k$: the dimension of the key vectors, used for scaling.

### Key, Query, and Value
- **Query (Q)**: captures the essence of the word/token we are focusing on.
- **Key (K)**: represents all words/tokens the query is compared against.
- **Value (V)**: contains the information of all tokens, aggregated according to the attention scores.

This mechanism allows the model to dynamically adjust its focus on different parts of the sentence, thereby improving contextual understanding.
""")
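# --- Illustrative sketch (not part of the original app) ---------------------------
# A minimal NumPy version of the scaled dot-product attention formula rendered above,
# included only to make the computation concrete. The app itself does not call it;
# it relies on the attention weights BERT computes internally.
def scaled_dot_product_attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V for 2-D arrays of shape (seq_len, d)."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                           # pairwise similarity scores
    scores = scores - scores.max(axis=-1, keepdims=True)      # subtract row max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ V                                        # attention-weighted sum of the values
# -----------------------------------------------------------------------------------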

# Input sentence from the user
sentence = st.text_input("Enter a sentence:", "The cat is on the mat")

# Tokenize and encode the sentence (adds the special [CLS] and [SEP] tokens)
inputs = tokenizer(sentence, return_tensors='pt', add_special_tokens=True)

# Get the attention weights from BERT (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs)
attention = outputs.attentions  # Attention weights from the pre-trained model, one tensor per layer
attention_weights = attention[-1].squeeze(0)  # Attention from the last layer
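# Note (added for reference, not in the original app): `outputs.attentions` is a tuple
# with one tensor per encoder layer (12 for bert-base-uncased), each of shape
# (batch_size, num_heads, seq_len, seq_len); after `squeeze(0)` the last layer's
# weights therefore have shape (num_heads=12, seq_len, seq_len).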

# Function to visualize attention weights as a heat map over the tokens
def visualize_attention(tokens, attention_weights):
    attention_weights = attention_weights.detach().numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    cax = ax.matshow(attention_weights, cmap='viridis')

    plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.yticks(range(len(tokens)), tokens)

    fig.colorbar(cax)
    plt.title("Attention Map")
    st.pyplot(fig)

# Extract tokens, including the special [CLS] and [SEP] tokens
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

# Remove special tokens for visualization
tokens_vis = [token for token in tokens if token not in tokenizer.all_special_tokens]

# Visualize the attention weights of the first head, excluding the special tokens
visualize_attention(tokens_vis, attention_weights[0, 1:-1, 1:-1])
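# Illustrative variant (not part of the original app): instead of a single head, the
# attention could be averaged over all heads of the last layer, which often gives a
# smoother, easier-to-read map, e.g.:
#   visualize_attention(tokens_vis, attention_weights.mean(dim=0)[1:-1, 1:-1])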

st.write("""
### About BERT
BERT (Bidirectional Encoder Representations from Transformers) is a transformer-based model designed to understand the context of words in a sentence. It uses the attention mechanism to weigh the importance of different words when generating word embeddings. This attention mechanism is crucial for tasks like language translation, sentiment analysis, and more.
""")