# api/backend/prompt_formatter.py
"""
Prompt Formatter Service
Handles formatting prompts appropriately for different model types:
- Completion models: Raw text continuation
- Instruction models: System prompt + user message with chat template
"""
from typing import Dict, Optional, Any
import logging
logger = logging.getLogger(__name__)
class PromptFormatter:
    """
    Unified prompt formatting for different model types.

    Completion models (CodeGen, Code Llama base):
    - Pass prompt through unchanged
    - Model treats it as text to continue

    Instruction models (Devstral, instruct variants):
    - Wrap with system prompt + user message
    - Use tokenizer's chat_template if available
    - Fall back to manual Mistral format
    """

    def format(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format a prompt appropriately for the model type.

        Args:
            prompt: The user's input (e.g., "def quicksort(arr):")
            model_config: Model configuration from model_config.py
            tokenizer: HuggingFace tokenizer for the model
            system_prompt_override: Optional override for the default system prompt

        Returns:
            Formatted prompt ready for tokenization
        """
        prompt_style = model_config.get("prompt_style", "completion")

        if prompt_style == "instruction":
            return self._format_instruction(
                prompt,
                model_config,
                tokenizer,
                system_prompt_override
            )

        # Completion style: return raw prompt
        return prompt
    def _format_instruction(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format prompt for instruction-tuned models.

        Priority:
        1. Tokenizer's native chat_template (if available)
        2. Manual Mistral format fallback
        """
        # Get system prompt (override > model default > generic fallback)
        system_prompt = system_prompt_override or model_config.get("system_prompt")
        if not system_prompt:
            system_prompt = "You are a helpful coding assistant. Continue the code provided."

        # Build messages list
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # Try tokenizer's native chat template first
        if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None:
            try:
                formatted = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                logger.info("Used HF tokenizer chat_template")
                return formatted
            except Exception as e:
                logger.warning(f"chat_template failed: {e}, using manual format")

        # Fallback: Manual Mistral/Llama instruction format
        # Note: We DON'T include <s> - the tokenizer adds BOS automatically
        return self._manual_mistral_format(prompt, system_prompt)
    def _manual_mistral_format(self, prompt: str, system_prompt: str) -> str:
        """
        Manual Mistral instruction format.

        Format: [INST] {system}\n\n{user} [/INST]

        Note: BOS token (<s>) is NOT included - the tokenizer adds it
        automatically during tokenization with add_special_tokens=True (default).
        """
        logger.info("Using manual Mistral instruction format")
        return f"[INST] {system_prompt}\n\n{prompt} [/INST]"
# Singleton instance for convenience
_formatter = PromptFormatter()


def format_prompt(
    prompt: str,
    model_config: Dict[str, Any],
    tokenizer: Any,
    system_prompt_override: Optional[str] = None
) -> str:
    """
    Convenience function to format a prompt.

    Args:
        prompt: The user's input (e.g., "def quicksort(arr):")
        model_config: Model configuration from model_config.py
        tokenizer: HuggingFace tokenizer for the model
        system_prompt_override: Optional override for the default system prompt

    Returns:
        Formatted prompt ready for tokenization
    """
    return _formatter.format(prompt, model_config, tokenizer, system_prompt_override)
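# --- Usage sketch (assumption, not part of the original module) ---
# Minimal smoke test of both prompt styles using a stub tokenizer with no
# chat_template, so the manual Mistral fallback is exercised; real callers
# pass a HuggingFace tokenizer and a config dict from model_config.py.
if __name__ == "__main__":
    class _StubTokenizer:
        chat_template = None  # forces the manual [INST] fallback path

    completion_cfg = {"prompt_style": "completion"}  # hypothetical config
    instruction_cfg = {
        "prompt_style": "instruction",
        "system_prompt": "You are a helpful coding assistant.",
    }

    code = "def quicksort(arr):"

    # Completion style: the prompt passes through unchanged.
    print(format_prompt(code, completion_cfg, _StubTokenizer()))

    # Instruction style with no chat template: manual Mistral format, i.e.
    # "[INST] You are a helpful coding assistant.\n\ndef quicksort(arr): [/INST]"
    print(format_prompt(code, instruction_cfg, _StubTokenizer()))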