"""Unit tests for the quantization Space's app module.

Covers recipe construction in get_quantization_recipe and the end-to-end
compress_and_upload flow with all external services mocked, so the suite
runs offline and without a GPU.
"""

import pytest
import torch
from unittest.mock import MagicMock, patch

import gradio as gr
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier

from app import compress_and_upload, get_quantization_recipe


# Mock external dependencies for compress_and_upload

@pytest.fixture
def mock_hf_api():
    """Stub app.HfApi; create_repo returns a canned repo URL."""
    with patch('app.HfApi') as mock_api:
        mock_api_instance = mock_api.return_value
        mock_api_instance.create_repo.return_value = "https://huggingface.co/test_user/test_model-AWQ"
        yield mock_api_instance


@pytest.fixture
def mock_whoami():
    """Stub app.whoami to report a fixed Hub username."""
    with patch('app.whoami') as mock_whoami_func:
        mock_whoami_func.return_value = {"name": "test_user"}
        yield mock_whoami_func


@pytest.fixture
def mock_auto_model_for_causal_lm():
    """Stub app.AutoModelForCausalLM with a Llama-architecture config."""
    with patch('app.AutoModelForCausalLM') as mock_model_class:
        mock_model_instance = MagicMock()
        mock_model_instance.config.architectures = ["LlamaForCausalLM"]
        mock_model_class.from_pretrained.return_value = mock_model_instance
        yield mock_model_class


@pytest.fixture
def mock_oneshot():
    """Stub app.oneshot so no compression actually runs."""
    with patch('app.oneshot') as mock_oneshot_func:
        yield mock_oneshot_func


@pytest.fixture
def mock_model_card():
    """Stub app.ModelCard; its instance records push_to_hub calls."""
    with patch('app.ModelCard') as mock_card_class:
        mock_card_instance = MagicMock()
        mock_card_class.return_value = mock_card_instance
        yield mock_card_class


@pytest.fixture
def mock_gr_oauth_token():
    """A gr.OAuthToken stand-in carrying a dummy token string."""
    mock_token = MagicMock(spec=gr.OAuthToken)
    mock_token.token = "test_token"
    return mock_token
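

# Helper mirroring the CUDA-dependent loading arguments the app is expected to
# pass to AutoModelForCausalLM.from_pretrained (float16 and device_map="auto"
# on GPU, otherwise dtype "auto" on CPU), so the duplicated expectation logic
# in the tests below lives in one place.
def _expected_load_kwargs():
    if torch.cuda.is_available():
        return {"torch_dtype": torch.float16, "device_map": "auto"}
    return {"torch_dtype": "auto", "device_map": "cpu"}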


# --- Test get_quantization_recipe ---

def test_get_quantization_recipe_awq():
    recipe = get_quantization_recipe("AWQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], AWQModifier)


def test_get_quantization_recipe_gptq():
    recipe = get_quantization_recipe("GPTQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)


def test_get_quantization_recipe_gptq_mistral():
    recipe = get_quantization_recipe("GPTQ", "MistralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MistralDecoderLayer"]


def test_get_quantization_recipe_gptq_mixtral():
    recipe = get_quantization_recipe("GPTQ", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MixtralDecoderLayer"]


def test_get_quantization_recipe_fp8():
    recipe = get_quantization_recipe("FP8", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert recipe[0].ignore == ["lm_head"]


def test_get_quantization_recipe_fp8_mixtral():
    recipe = get_quantization_recipe("FP8", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert "re:.*block_sparse_moe.gate" in recipe[0].ignore


def test_get_quantization_recipe_unsupported():
    with pytest.raises(ValueError, match="Unsupported quantization method: INVALID"):
        get_quantization_recipe("INVALID", "LlamaForCausalLM")


# --- Test compress_and_upload ---

def test_compress_and_upload_no_model_id(mock_gr_oauth_token):
    with pytest.raises(gr.Error, match="Please select a model from the search bar."):
        compress_and_upload("", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)


def test_compress_and_upload_no_oauth_token():
    with pytest.raises(gr.Error, match="Authentication error. Please log in to continue."):
        compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", None)


def test_compress_and_upload_success(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"
    model_type_selection = "Auto-detect (recommended)"

    result = compress_and_upload(model_id, quant_method, model_type_selection, mock_gr_oauth_token)

    mock_whoami.assert_called_once_with(token="test_token")

    # The expected device_map and torch_dtype depend on CUDA availability.
    expected = _expected_load_kwargs()
    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id,
        torch_dtype=expected["torch_dtype"],
        device_map=expected["device_map"],
        token="test_token",
        trust_remote_code=True,
    )

    mock_oneshot.assert_called_once()
    assert mock_oneshot.call_args[1]["model"] == mock_auto_model_for_causal_lm.from_pretrained.return_value
    assert mock_oneshot.call_args[1]["recipe"] is not None
    assert mock_oneshot.call_args[1]["output_dir"] == f"test_model-{quant_method}"

    mock_hf_api.create_repo.assert_called_once_with(
        repo_id=f"test_user/test_model-{quant_method}", exist_ok=True
    )
    mock_hf_api.upload_folder.assert_called_once_with(
        folder_path=f"test_model-{quant_method}",
        repo_id=f"test_user/test_model-{quant_method}",
        commit_message=f"Upload {quant_method} compressed model",
    )

    mock_model_card.assert_called_once()
    mock_model_card.return_value.push_to_hub.assert_called_once_with(
        f"test_user/test_model-{quant_method}", token="test_token"
    )

    assert "✅ Success!" in result
    assert "https://huggingface.co/test_user/test_model-AWQ" in result


def test_compress_and_upload_with_trust_remote_code(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    """The model must always be loaded with trust_remote_code=True."""
    model_id = "org/test_model"
    quant_method = "AWQ"
    model_type_selection = "Auto-detect (recommended)"

    compress_and_upload(model_id, quant_method, model_type_selection, mock_gr_oauth_token)

    # The expected device_map and torch_dtype depend on CUDA availability.
    expected = _expected_load_kwargs()
    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id,
        torch_dtype=expected["torch_dtype"],
        device_map=expected["device_map"],
        token="test_token",
        trust_remote_code=True,
    )


def test_compress_and_upload_model_no_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = []
    with pytest.raises(gr.Error, match="Could not determine model architecture."):
        compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)


def test_compress_and_upload_generic_exception(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    """Unexpected exceptions are caught and reported in the result string."""
    mock_whoami.side_effect = Exception("Network error")
    result = compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "Network error" in result


def test_compress_and_upload_unrecognized_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = ["UnrecognizedArchitecture"]
    result = compress_and_upload("test_model", "AWQ", "Auto-detect (recommended)", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "AWQ quantization is only supported for LlamaForCausalLM architectures, got UnrecognizedArchitecture" in result
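

# To run this suite locally (assuming pytest and the app's dependencies are
# installed; "test_app.py" is a placeholder for whatever this module is saved as):
#
#     pytest test_app.py -v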