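"""Unit tests for the `app` module.

Covers recipe construction in get_quantization_recipe (AWQ, GPTQ, FP8) and the
compress_and_upload flow, with Hugging Face Hub, model loading, and llmcompressor
calls mocked out.
"""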
import pytest
import gradio as gr
from unittest.mock import MagicMock, patch

from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier

from app import get_quantization_recipe, compress_and_upload

# Mock external dependencies for compress_and_upload
@pytest.fixture
def mock_hf_api():
    with patch('app.HfApi') as mock_api:
        mock_api_instance = mock_api.return_value
        mock_api_instance.create_repo.return_value = "https://huggingface.co/test_user/test_model-AWQ"
        yield mock_api_instance

@pytest.fixture
def mock_whoami():
    with patch('app.whoami') as mock_whoami_func:
        mock_whoami_func.return_value = {"name": "test_user"}
        yield mock_whoami_func

@pytest.fixture
def mock_auto_model_for_causal_lm():
    with patch('app.AutoModelForCausalLM') as mock_model_class:
        mock_model_instance = MagicMock()
        mock_model_instance.config.architectures = ["LlamaForCausalLM"]
        mock_model_class.from_pretrained.return_value = mock_model_instance
        yield mock_model_class

@pytest.fixture
def mock_oneshot():
    with patch('app.oneshot') as mock_oneshot_func:
        yield mock_oneshot_func

@pytest.fixture
def mock_model_card():
    with patch('app.ModelCard') as mock_card_class:
        mock_card_instance = MagicMock()
        mock_card_class.return_value = mock_card_instance
        yield mock_card_class

@pytest.fixture
def mock_gr_oauth_token():
    mock_token = MagicMock(spec=gr.OAuthToken)
    mock_token.token = "test_token"
    return mock_token

# --- Test get_quantization_recipe ---
def test_get_quantization_recipe_awq():
    recipe = get_quantization_recipe("AWQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], AWQModifier)

def test_get_quantization_recipe_gptq():
    recipe = get_quantization_recipe("GPTQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)

def test_get_quantization_recipe_gptq_mistral():
    recipe = get_quantization_recipe("GPTQ", "MistralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MistralDecoderLayer"]

def test_get_quantization_recipe_gptq_mixtral():
    recipe = get_quantization_recipe("GPTQ", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MixtralDecoderLayer"]

def test_get_quantization_recipe_fp8():
    recipe = get_quantization_recipe("FP8", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert recipe[0].ignore == ["lm_head"]

def test_get_quantization_recipe_fp8_mixtral():
    recipe = get_quantization_recipe("FP8", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert "re:.*block_sparse_moe.gate" in recipe[0].ignore

def test_get_quantization_recipe_unsupported():
    with pytest.raises(ValueError, match="Unsupported quantization method: INVALID"):
        get_quantization_recipe("INVALID", "LlamaForCausalLM")

# --- Test compress_and_upload ---
def test_compress_and_upload_no_model_id(mock_gr_oauth_token):
    with pytest.raises(gr.Error, match="Please select a model from the search bar."):
        compress_and_upload("", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)

def test_compress_and_upload_no_oauth_token():
    with pytest.raises(gr.Error, match="Authentication error. Please log in to continue."):
        compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", None)

def test_compress_and_upload_success(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"
    model_type_selection = "Auto-detect (recommended)"
    result = compress_and_upload(model_id, quant_method, model_type_selection, mock_gr_oauth_token)

    mock_whoami.assert_called_once_with(token="test_token")

    # The device_map and torch_dtype should depend on CUDA availability
    import torch
    if torch.cuda.is_available():
        expected_torch_dtype = torch.float16
        expected_device_map = "auto"
    else:
        expected_torch_dtype = "auto"
        expected_device_map = "cpu"

    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id, torch_dtype=expected_torch_dtype, device_map=expected_device_map, token="test_token", trust_remote_code=True
    )
    mock_oneshot.assert_called_once()
    assert mock_oneshot.call_args[1]["model"] == mock_auto_model_for_causal_lm.from_pretrained.return_value
    assert mock_oneshot.call_args[1]["recipe"] is not None
    assert mock_oneshot.call_args[1]["output_dir"] == f"test_model-{quant_method}"

    mock_hf_api.create_repo.assert_called_once_with(
        repo_id=f"test_user/test_model-{quant_method}", exist_ok=True
    )
    mock_hf_api.upload_folder.assert_called_once_with(
        folder_path=f"test_model-{quant_method}",
        repo_id=f"test_user/test_model-{quant_method}",
        commit_message=f"Upload {quant_method} compressed model",
    )
    mock_model_card.assert_called_once()
    mock_model_card.return_value.push_to_hub.assert_called_once_with(
        f"test_user/test_model-{quant_method}", token="test_token"
    )

    assert "✅ Success!" in result
    assert "https://huggingface.co/test_user/test_model-AWQ" in result

def test_compress_and_upload_with_trust_remote_code(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"
    model_type_selection = "Auto-detect (recommended)"
    compress_and_upload(model_id, quant_method, model_type_selection, mock_gr_oauth_token)

    # The device_map and torch_dtype should depend on CUDA availability
    import torch
    if torch.cuda.is_available():
        expected_torch_dtype = torch.float16
        expected_device_map = "auto"
    else:
        expected_torch_dtype = "auto"
        expected_device_map = "cpu"

    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id, torch_dtype=expected_torch_dtype, device_map=expected_device_map, token="test_token", trust_remote_code=True
    )


def test_compress_and_upload_model_no_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = []
    with pytest.raises(gr.Error, match="Could not determine model architecture."):
        compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)

def test_compress_and_upload_generic_exception(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_whoami.side_effect = Exception("Network error")
    result = compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "Network error" in result

def test_compress_and_upload_unrecognized_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = ["UnrecognizedArchitecture"]
    result = compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "AWQ quantization is only supported for LlamaForCausalLM architectures, got UnrecognizedArchitecture" in result