gary-boon Claude Opus 4.5 committed on
Commit ed06dcb · 1 Parent(s): 3d9d9ee

Integrate mistral-common for correct Devstral tokenization

Root cause: Devstral's Tekken tokenizer is incompatible with HuggingFace's
standard tokenization. When [INST]/[/INST] are formatted as text and
tokenized with HF, the model receives corrupted tokens -> garbage output.

Solution:
- Add mistral-common>=1.5.0 dependency
- Create MistralTokenizerWrapper using MistralTokenizer.from_hf_hub()
- Use encode_chat_completion() for correct Tekken token encoding
- Relax numpy/pydantic version constraints for compatibility

The mistral-common library produces correct token sequences by encoding
chat messages directly to token IDs, bypassing text-based formatting.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
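
For reference, the encoding path this commit relies on can be exercised on its own. A minimal sketch, assuming mistral-common >= 1.5.0 is installed and the Hugging Face hub is reachable; the prompt strings are illustrative:

```python
from mistral_common.protocol.instruct.messages import SystemMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Load the Tekken tokenizer definition shipped in the model repo
tokenizer = MistralTokenizer.from_hf_hub("mistralai/Devstral-Small-2507")

# Encode a chat directly to token IDs (no intermediate [INST] text formatting)
request = ChatCompletionRequest(messages=[
    SystemMessage(content="You are a coding assistant."),  # illustrative
    UserMessage(content="def quicksort(arr):"),            # illustrative
])
tokenized = tokenizer.encode_chat_completion(request)

print(len(tokenized.tokens))                # token IDs ready for model input
print(tokenizer.decode(tokenized.tokens))   # round-trip back to text
```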

backend/mistral_tokenizer.py ADDED
@@ -0,0 +1,136 @@
+ """
+ Mistral Tokenizer Wrapper
+ Provides correct tokenization for Devstral using mistral-common library.
+
+ The Tekken tokenizer used by Devstral is incompatible with HuggingFace's
+ standard tokenization approach. This wrapper uses mistral-common to
+ produce correct token sequences for the model.
+ """
+
+ import logging
+ from typing import List, Optional
+
+ logger = logging.getLogger(__name__)
+
+
+ class MistralTokenizerWrapper:
+     """
+     Wrapper around mistral-common's MistralTokenizer for Devstral.
+
+     Uses encode_chat_completion() to produce correct token IDs
+     that the model actually expects, rather than HF's text-based approach
+     which produces corrupted tokens for Tekken-based models.
+     """
+
+     def __init__(self, model_name: str):
+         """
+         Initialize the Mistral tokenizer from HuggingFace hub.
+
+         Args:
+             model_name: HuggingFace model path (e.g., "mistralai/Devstral-Small-2507")
+         """
+         try:
+             from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+             self.tokenizer = MistralTokenizer.from_hf_hub(model_name)
+             self._available = True
+             logger.info(f"Loaded MistralTokenizer for {model_name}")
+         except ImportError as e:
+             logger.warning(f"mistral-common not available: {e}")
+             self._available = False
+             self.tokenizer = None
+         except Exception as e:
+             logger.error(f"Failed to load MistralTokenizer: {e}")
+             self._available = False
+             self.tokenizer = None
+
+     @property
+     def is_available(self) -> bool:
+         """Check if the tokenizer was loaded successfully."""
+         return self._available
+
+     def encode_chat(
+         self,
+         system_prompt: str,
+         user_prompt: str
+     ) -> List[int]:
+         """
+         Encode chat messages to token IDs using mistral-common.
+
+         This produces the correct token sequence for Devstral, including
+         proper handling of control tokens like [INST] and [/INST].
+
+         Args:
+             system_prompt: System message content
+             user_prompt: User message content (e.g., "def quicksort(arr):")
+
+         Returns:
+             List of token IDs ready for model input
+         """
+         if not self._available:
+             raise RuntimeError("MistralTokenizer not available")
+
+         from mistral_common.protocol.instruct.messages import (
+             SystemMessage, UserMessage
+         )
+         from mistral_common.protocol.instruct.request import ChatCompletionRequest
+
+         # Build messages list
+         messages = []
+         if system_prompt:
+             messages.append(SystemMessage(content=system_prompt))
+         messages.append(UserMessage(content=user_prompt))
+
+         # Encode using mistral-common's chat completion encoding
+         request = ChatCompletionRequest(messages=messages)
+         tokenized = self.tokenizer.encode_chat_completion(request)
+
+         logger.info(f"Encoded chat: {len(tokenized.tokens)} tokens")
+         return tokenized.tokens
+
+     def decode(self, token_ids: List[int]) -> str:
+         """
+         Decode token IDs back to text.
+
+         Args:
+             token_ids: List of token IDs to decode
+
+         Returns:
+             Decoded text string
+         """
+         if not self._available:
+             raise RuntimeError("MistralTokenizer not available")
+
+         return self.tokenizer.decode(token_ids)
+
+     def decode_token(self, token_id: int) -> str:
+         """
+         Decode a single token ID to text.
+
+         Args:
+             token_id: Single token ID to decode
+
+         Returns:
+             Decoded text for this token
+         """
+         if not self._available:
+             raise RuntimeError("MistralTokenizer not available")
+
+         return self.tokenizer.decode([token_id])
+
+
+ def create_mistral_tokenizer(model_name: str) -> Optional[MistralTokenizerWrapper]:
+     """
+     Factory function to create a MistralTokenizerWrapper.
+
+     Returns None if mistral-common is not available or loading fails.
+
+     Args:
+         model_name: HuggingFace model path
+
+     Returns:
+         MistralTokenizerWrapper instance or None
+     """
+     wrapper = MistralTokenizerWrapper(model_name)
+     if wrapper.is_available:
+         return wrapper
+     return None
backend/model_service.py CHANGED
@@ -229,6 +229,16 @@ class ModelManager:
          self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
          self.tokenizer.pad_token = self.tokenizer.eos_token

+         # For Devstral, also load MistralTokenizer for correct encoding
+         self.mistral_tokenizer = None
+         if self.model_id == "devstral-small":
+             from .mistral_tokenizer import create_mistral_tokenizer
+             self.mistral_tokenizer = create_mistral_tokenizer(self.model_name)
+             if self.mistral_tokenizer:
+                 logger.info("Loaded MistralTokenizer for Devstral (correct Tekken encoding)")
+             else:
+                 logger.warning("MistralTokenizer not available - Devstral may produce garbage output")
+
          # Create model adapter for multi-model support
          from .model_adapter import create_adapter
          try:
@@ -1514,11 +1524,22 @@ async def analyze_research_attention(request: Dict[str, Any], authenticated: boo
          temperature = model_config["recommended_temperature"]
          logger.info(f"Using model recommended temperature={temperature}")

-     # Tokenize and prepare
-     inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device)
-     prompt_length = inputs["input_ids"].shape[1]
-     prompt_token_ids = inputs["input_ids"][0].tolist()
-     prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids]
+     # Tokenize and prepare - use MistralTokenizer for Devstral
+     if manager.model_id == "devstral-small" and manager.mistral_tokenizer is not None:
+         # Use MistralTokenizer for correct Tekken encoding
+         system_prompt = system_prompt_override or (model_config.get("system_prompt") if model_config else "")
+         prompt_token_ids = manager.mistral_tokenizer.encode_chat(system_prompt, prompt)
+         inputs = {"input_ids": torch.tensor([prompt_token_ids]).to(manager.device)}
+         prompt_length = len(prompt_token_ids)
+         # Decode tokens using MistralTokenizer for accuracy
+         prompt_tokens = [manager.mistral_tokenizer.decode_token(tid) for tid in prompt_token_ids]
+         logger.info(f"Used MistralTokenizer for Devstral: {prompt_length} tokens")
+     else:
+         # Standard HF tokenization for other models
+         inputs = manager.tokenizer(formatted_prompt, return_tensors="pt").to(manager.device)
+         prompt_length = inputs["input_ids"].shape[1]
+         prompt_token_ids = inputs["input_ids"][0].tolist()
+         prompt_tokens = [manager.tokenizer.decode([tid], skip_special_tokens=False) for tid in prompt_token_ids]

      # Storage for generation
      generated_token_ids = []
requirements.txt CHANGED
@@ -3,16 +3,17 @@ fastapi==0.104.1
  uvicorn[standard]==0.24.0
  websockets==12.0
  python-multipart==0.0.6
- pydantic==2.5.0
+ pydantic>=2.0.0  # Relaxed for mistral-common compatibility

  # Machine Learning
  # torch 2.3+ required for transformers 4.44+ (pytree API compatibility)
  torch>=2.3.0
  transformers>=4.44.0
  accelerate>=0.30.0
+ mistral-common>=1.5.0  # Required for Devstral Tekken tokenizer

  # Utilities
- numpy==1.24.3
+ numpy>=1.24.0,<2.0  # Relaxed for mistral-common compatibility
  aiofiles==23.2.1
  python-dotenv==1.0.0
  zarr==2.14.2