"""
Prompt Formatter Service
Handles formatting prompts appropriately for different model types:
- Completion models: Raw text continuation
- Instruction models: System prompt + user message with chat template
"""

from typing import Dict, Optional, Any
import logging

logger = logging.getLogger(__name__)


class PromptFormatter:
    """
    Unified prompt formatting for different model types.

    Completion models (CodeGen, Code Llama base):
        - Pass prompt through unchanged
        - Model treats it as text to continue

    Instruction models (Devstral, instruct variants):
        - Wrap with system prompt + user message
        - Use tokenizer's chat_template if available
        - Fallback to manual Mistral format
    """

    def format(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format a prompt appropriately for the model type.

        Args:
            prompt: The user's input (e.g., "def quicksort(arr):")
            model_config: Model configuration from model_config.py
            tokenizer: HuggingFace tokenizer for the model
            system_prompt_override: Optional override for the default system prompt

        Returns:
            Formatted prompt ready for tokenization
        """
        prompt_style = model_config.get("prompt_style", "completion")

        if prompt_style == "instruction":
            return self._format_instruction(
                prompt,
                model_config,
                tokenizer,
                system_prompt_override
            )

        # Completion style: return raw prompt
        return prompt

    def _format_instruction(
        self,
        prompt: str,
        model_config: Dict[str, Any],
        tokenizer: Any,
        system_prompt_override: Optional[str] = None
    ) -> str:
        """
        Format prompt for instruction-tuned models.

        Priority:
        1. Tokenizer's native chat_template (if available)
        2. Manual Mistral format fallback
        """
        # Get system prompt (override > model default > generic fallback)
        system_prompt = system_prompt_override or model_config.get("system_prompt")
        if not system_prompt:
            system_prompt = "You are a helpful coding assistant. Continue the code provided."

        # Build messages list
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ]

        # Try tokenizer's native chat template first
        if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None:
            try:
                formatted = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                logger.info("Used HF tokenizer chat_template")
                return formatted
            except Exception as e:
                logger.warning(f"chat_template failed: {e}, using manual format")

        # Fallback: Manual Mistral/Llama instruction format
        # Note: We DON'T include <s> - the tokenizer adds BOS automatically
        return self._manual_mistral_format(prompt, system_prompt)

    def _manual_mistral_format(self, prompt: str, system_prompt: str) -> str:
        """
        Manual Mistral instruction format.

        Format: [INST] {system}\n\n{user} [/INST]

        Note: BOS token (<s>) is NOT included - the tokenizer adds it
        automatically during tokenization with add_special_tokens=True (default).
        """
        logger.info("Using manual Mistral instruction format")
        return f"[INST] {system_prompt}\n\n{prompt} [/INST]"


# Singleton instance for convenience
_formatter = PromptFormatter()


def format_prompt(
    prompt: str,
    model_config: Dict[str, Any],
    tokenizer: Any,
    system_prompt_override: Optional[str] = None
) -> str:
    """
    Convenience function to format a prompt.

    Args:
        prompt: The user's input (e.g., "def quicksort(arr):")
        model_config: Model configuration from model_config.py
        tokenizer: HuggingFace tokenizer for the model
        system_prompt_override: Optional override for the default system prompt

    Returns:
        Formatted prompt ready for tokenization
    """
    return _formatter.format(prompt, model_config, tokenizer, system_prompt_override)
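

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the service).
    # The stub tokenizer below has no chat_template, so the instruction path
    # falls back to the manual Mistral format; real callers pass a HuggingFace
    # tokenizer and a config dict from model_config.py.
    logging.basicConfig(level=logging.INFO)

    class _StubTokenizer:
        chat_template = None  # forces the manual fallback path

    demo_prompt = "def quicksort(arr):"

    # Completion style: the prompt passes through unchanged.
    print(format_prompt(demo_prompt, {"prompt_style": "completion"}, _StubTokenizer()))

    # Instruction style: wrapped with the generic fallback system prompt.
    print(format_prompt(demo_prompt, {"prompt_style": "instruction"}, _StubTokenizer()))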