n00b001 committed
Commit d371d24 · unverified · 1 Parent(s): 6f7fa65

feat: Add comprehensive unit tests for app.py and update dependencies

Files changed (3)
  1. app.py +64 -51
  2. requirements.txt +488 -8
  3. tests/test_app.py +157 -0
app.py CHANGED
@@ -1,23 +1,28 @@
 import gradio as gr
 from huggingface_hub import HfApi, ModelCard, whoami
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
- import os
+
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier, GPTQModifier
 from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import AutoModelForCausalLM

 # --- Helper Functions ---

+
 def get_quantization_recipe(method, model_architecture):
     """
     Returns the appropriate llm-compressor recipe based on the selected method.
     """
     if method == "AWQ":
         mappings = [
-             AWQMapping("re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]),
+             AWQMapping(
+                 "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]
+             ),
             AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
-             AWQMapping("re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]),
+             AWQMapping(
+                 "re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]
+             ),
             AWQMapping("re:.*up_proj", ["re:.*down_proj"]),
         ]
         return [
@@ -25,7 +30,7 @@ def get_quantization_recipe(method, model_architecture):
                 ignore=["lm_head"],
                 scheme="W4A16_ASYM",
                 targets=["Linear"],
-                 mappings=mappings
+                 mappings=mappings,
             ),
         ]
     elif method == "GPTQ":
@@ -34,7 +39,9 @@ def get_quantization_recipe(method, model_architecture):
             "MistralForCausalLM": "MistralDecoderLayer",
             "MixtralForCausalLM": "MixtralDecoderLayer",
         }
-         sequential_target = sequential_target_map.get(model_architecture, "LlamaDecoderLayer")
+         sequential_target = sequential_target_map.get(
+             model_architecture, "LlamaDecoderLayer"
+         )

         return [
             GPTQModifier(
@@ -49,11 +56,9 @@ def get_quantization_recipe(method, model_architecture):
         if "Mixtral" in model_architecture:
             ignore_layers.append("re:.*block_sparse_moe.gate")

-         return QuantizationModifier(
-             scheme="FP8",
-             targets="Linear",
-             ignore=ignore_layers
-         )
+         return [QuantizationModifier(
+             scheme="FP8", targets="Linear", ignore=ignore_layers
+         )]
     else:
         raise ValueError(f"Unsupported quantization method: {method}")

@@ -62,18 +67,16 @@ def compress_and_upload(
     model_id: str,
     quant_method: str,
     oauth_token: gr.OAuthToken | None,
-     *,
-     request: gr.Request
 ):
     """
     Compresses a model using llm-compressor and uploads it to a new HF repo.
     """
     if not model_id:
         raise gr.Error("Please select a model from the search bar.")
-
+
     if oauth_token is None:
         raise gr.Error("Authentication error. Please log in to continue.")
-
+
     token = oauth_token.token

     try:
@@ -81,12 +84,15 @@ def compress_and_upload(
         username = whoami(token=token)["name"]

         # --- 1. Load Model and Tokenizer ---
-         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map=None, token=token)
-         tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
-
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id, torch_dtype="auto", device_map=None, token=token
+         )
+
         output_dir = f"{model_id.split('/')[-1]}-{quant_method}"

         # --- 2. Get Recipe ---
+         if not model.config.architectures:
+             raise gr.Error("Could not determine model architecture.")
         recipe = get_quantization_recipe(quant_method, model.config.architectures[0])

         # --- 3. Run Compression ---
@@ -140,41 +146,48 @@ For more details on the recipe used, refer to the `recipe.yaml` file in this rep

         return f'<h1>✅ Success!</h1><br/>Model compressed and saved to your new repo: <a href="{repo_url}" target="_blank" style="text-decoration:underline">{repo_id}</a>'

+     except gr.Error as e:
+         raise e
     except Exception as e:
         error_message = str(e).replace("\n", "<br/>")
         return f'<h1>❌ ERROR</h1><br/><pre style="white-space:pre-wrap;">{error_message}</pre>'

+
+
 # --- Gradio Interface ---
- with gr.Blocks(css="footer {display: none !important;}") as demo:
-     gr.Markdown("# LLM-Compressor My Repo")
-     gr.Markdown(
-         "Log in, choose a model, select a quantization method, and this Space will create a new compressed model repository on your Hugging Face profile."
-     )
-     with gr.Row():
-         login_button = gr.LoginButton(min_width=250)
-
-     gr.Markdown("### 1. Select a Model from the Hugging Face Hub")
-     model_input = HuggingfaceHubSearch(
-         label="Search for a Model",
-         search_type="model",
-     )
-
-     gr.Markdown("### 2. Choose a Quantization Method")
-     quant_method_dropdown = gr.Dropdown(
-         ["AWQ", "GPTQ", "FP8"],
-         label="Quantization Method",
-         value="AWQ"
-     )
-
-     compress_button = gr.Button("Compress and Create Repo", variant="primary")
-     output_html = gr.HTML(label="Result")
-
-     # The inputs list correctly provides 3 arguments. Gradio will add the 4th (request) implicitly.
-     # The function signature now correctly expects all 4.
-     compress_button.click(
-         fn=compress_and_upload,
-         inputs=[model_input, quant_method_dropdown, login_button],
-         outputs=output_html
-     )
-
-     demo.queue(max_size=5).launch()
+ def build_gradio_app():
+     with gr.Blocks(css="footer {display: none !important;}") as demo:
+         gr.Markdown("# LLM-Compressor My Repo")
+         gr.Markdown(
+             "Log in, choose a model, select a quantization method, and this Space will create a new compressed model repository on your Hugging Face profile."
+         )
+         with gr.Row():
+             login_button = gr.LoginButton(min_width=250)
+
+         gr.Markdown("### 1. Select a Model from the Hugging Face Hub")
+         model_input = HuggingfaceHubSearch(
+             label="Search for a Model",
+             search_type="model",
+         )
+
+         gr.Markdown("### 2. Choose a Quantization Method")
+         quant_method_dropdown = gr.Dropdown(
+             ["AWQ", "GPTQ", "FP8"], label="Quantization Method", value="AWQ"
+         )
+
+         compress_button = gr.Button("Compress and Create Repo", variant="primary")
+         output_html = gr.HTML(label="Result")
+
+         compress_button.click(
+             fn=compress_and_upload,
+             inputs=[model_input, quant_method_dropdown],
+             outputs=output_html,
+         )
+     return demo
+
+ def main():
+     demo = build_gradio_app()
+     demo.queue(max_size=5).launch()
+
+ if __name__ == "__main__":
+     main()
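A quick way to sanity-check the refactored helper outside the Space is a minimal smoke test; the sketch below is an assumption-laden example (it assumes llmcompressor is installed and that app.py is importable as app) and simply prints which modifier class each quantization method resolves to.

# Minimal smoke test for get_quantization_recipe (assumes llmcompressor is
# installed and app.py is importable as `app`); prints the modifier class
# returned for each supported method.
from app import get_quantization_recipe

for method in ("AWQ", "GPTQ", "FP8"):
    recipe = get_quantization_recipe(method, "LlamaForCausalLM")
    print(method, [type(m).__name__ for m in recipe])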
requirements.txt CHANGED
@@ -1,8 +1,488 @@
- gradio
- gradio_huggingfacehub_search
- huggingface-hub>=1.0.0
- torch
- accelerate
- datasets
- llmcompressor@git+https://github.com/vllm-project/llm-compressor.git
- transformers[torch]@git+https://github.com/huggingface/transformers.git
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ absl-py==2.3.1
4
+ # via rouge-score
5
+ accelerate==1.12.0
6
+ # via
7
+ # llm-compressor-my-repo (pyproject.toml)
8
+ # auto-round
9
+ # llmcompressor
10
+ # lm-eval
11
+ # peft
12
+ # transformers
13
+ aiofiles==24.1.0
14
+ # via gradio
15
+ aiohappyeyeballs==2.6.1
16
+ # via aiohttp
17
+ aiohttp==3.13.2
18
+ # via fsspec
19
+ aiosignal==1.4.0
20
+ # via aiohttp
21
+ annotated-doc==0.0.4
22
+ # via fastapi
23
+ annotated-types==0.7.0
24
+ # via pydantic
25
+ anyio==4.12.0
26
+ # via
27
+ # gradio
28
+ # httpx
29
+ # starlette
30
+ attrs==25.4.0
31
+ # via
32
+ # aiohttp
33
+ # jsonlines
34
+ authlib==1.6.5
35
+ # via gradio
36
+ auto-round @ git+https://github.com/intel/auto-round.git@5ffe56ddc51cbc69cd6fe87a0b8a7d91e28bf522
37
+ # via llmcompressor
38
+ brotli==1.2.0
39
+ # via gradio
40
+ certifi==2025.11.12
41
+ # via
42
+ # httpcore
43
+ # httpx
44
+ # requests
45
+ cffi==2.0.0
46
+ # via cryptography
47
+ chardet==5.2.0
48
+ # via mbstrdecoder
49
+ charset-normalizer==3.4.4
50
+ # via requests
51
+ click==8.3.1
52
+ # via
53
+ # nltk
54
+ # typer
55
+ # typer-slim
56
+ # uvicorn
57
+ colorama==0.4.6
58
+ # via
59
+ # sacrebleu
60
+ # tqdm-multiprocess
61
+ compressed-tensors==0.12.3a20251114
62
+ # via llmcompressor
63
+ cryptography==46.0.3
64
+ # via authlib
65
+ dataproperty==1.1.0
66
+ # via
67
+ # pytablewriter
68
+ # tabledata
69
+ datasets==4.4.1
70
+ # via
71
+ # llm-compressor-my-repo (pyproject.toml)
72
+ # auto-round
73
+ # evaluate
74
+ # llmcompressor
75
+ # lm-eval
76
+ dill==0.4.0
77
+ # via
78
+ # datasets
79
+ # evaluate
80
+ # lm-eval
81
+ # multiprocess
82
+ evaluate==0.4.6
83
+ # via lm-eval
84
+ fastapi==0.122.0
85
+ # via gradio
86
+ ffmpy==1.0.0
87
+ # via gradio
88
+ filelock==3.20.0
89
+ # via
90
+ # datasets
91
+ # huggingface-hub
92
+ # torch
93
+ # transformers
94
+ frozenlist==1.8.0
95
+ # via
96
+ # aiohttp
97
+ # aiosignal
98
+ fsspec==2025.10.0
99
+ # via
100
+ # datasets
101
+ # evaluate
102
+ # gradio-client
103
+ # huggingface-hub
104
+ # torch
105
+ gradio==5.50.0
106
+ # via
107
+ # llm-compressor-my-repo (pyproject.toml)
108
+ # gradio-huggingfacehub-search
109
+ gradio-client==1.14.0
110
+ # via gradio
111
+ gradio-huggingfacehub-search==0.0.12
112
+ # via llm-compressor-my-repo (pyproject.toml)
113
+ groovy==0.1.2
114
+ # via gradio
115
+ h11==0.16.0
116
+ # via
117
+ # httpcore
118
+ # uvicorn
119
+ hf-xet==1.2.0
120
+ # via
121
+ # llm-compressor-my-repo (pyproject.toml)
122
+ # huggingface-hub
123
+ httpcore==1.0.9
124
+ # via httpx
125
+ httpx==0.28.1
126
+ # via
127
+ # datasets
128
+ # gradio
129
+ # gradio-client
130
+ # huggingface-hub
131
+ # safehttpx
132
+ huggingface-hub==1.1.6
133
+ # via
134
+ # llm-compressor-my-repo (pyproject.toml)
135
+ # accelerate
136
+ # datasets
137
+ # evaluate
138
+ # gradio
139
+ # gradio-client
140
+ # peft
141
+ # tokenizers
142
+ # transformers
143
+ idna==3.11
144
+ # via
145
+ # anyio
146
+ # httpx
147
+ # requests
148
+ # yarl
149
+ itsdangerous==2.2.0
150
+ # via gradio
151
+ jinja2==3.1.6
152
+ # via
153
+ # gradio
154
+ # torch
155
+ joblib==1.5.2
156
+ # via
157
+ # nltk
158
+ # scikit-learn
159
+ jsonlines==4.0.0
160
+ # via lm-eval
161
+ llmcompressor @ git+https://github.com/vllm-project/llm-compressor.git@db0b68d9faf09066e9b7d679b39a977e484d9b91
162
+ # via llm-compressor-my-repo (pyproject.toml)
163
+ lm-eval==0.4.9.2
164
+ # via auto-round
165
+ loguru==0.7.3
166
+ # via
167
+ # compressed-tensors
168
+ # llmcompressor
169
+ lxml==6.0.2
170
+ # via sacrebleu
171
+ markdown-it-py==4.0.0
172
+ # via rich
173
+ markupsafe==3.0.3
174
+ # via
175
+ # gradio
176
+ # jinja2
177
+ mbstrdecoder==1.1.4
178
+ # via
179
+ # dataproperty
180
+ # pytablewriter
181
+ # typepy
182
+ mdurl==0.1.2
183
+ # via markdown-it-py
184
+ more-itertools==10.8.0
185
+ # via lm-eval
186
+ mpmath==1.3.0
187
+ # via sympy
188
+ multidict==6.7.0
189
+ # via
190
+ # aiohttp
191
+ # yarl
192
+ multiprocess==0.70.18
193
+ # via
194
+ # datasets
195
+ # evaluate
196
+ networkx==3.6
197
+ # via torch
198
+ nltk==3.9.2
199
+ # via rouge-score
200
+ numexpr==2.14.1
201
+ # via lm-eval
202
+ numpy==2.3.5
203
+ # via
204
+ # accelerate
205
+ # auto-round
206
+ # datasets
207
+ # evaluate
208
+ # gradio
209
+ # llmcompressor
210
+ # numexpr
211
+ # pandas
212
+ # peft
213
+ # rouge-score
214
+ # sacrebleu
215
+ # safetensors
216
+ # scikit-learn
217
+ # scipy
218
+ # transformers
219
+ nvidia-cublas-cu12==12.8.4.1
220
+ # via
221
+ # nvidia-cudnn-cu12
222
+ # nvidia-cusolver-cu12
223
+ # torch
224
+ nvidia-cuda-cupti-cu12==12.8.90
225
+ # via torch
226
+ nvidia-cuda-nvrtc-cu12==12.8.93
227
+ # via torch
228
+ nvidia-cuda-runtime-cu12==12.8.90
229
+ # via torch
230
+ nvidia-cudnn-cu12==9.10.2.21
231
+ # via torch
232
+ nvidia-cufft-cu12==11.3.3.83
233
+ # via torch
234
+ nvidia-cufile-cu12==1.13.1.3
235
+ # via torch
236
+ nvidia-curand-cu12==10.3.9.90
237
+ # via torch
238
+ nvidia-cusolver-cu12==11.7.3.90
239
+ # via torch
240
+ nvidia-cusparse-cu12==12.5.8.93
241
+ # via
242
+ # nvidia-cusolver-cu12
243
+ # torch
244
+ nvidia-cusparselt-cu12==0.7.1
245
+ # via torch
246
+ nvidia-ml-py==13.580.82
247
+ # via llmcompressor
248
+ nvidia-nccl-cu12==2.27.5
249
+ # via torch
250
+ nvidia-nvjitlink-cu12==12.8.93
251
+ # via
252
+ # nvidia-cufft-cu12
253
+ # nvidia-cusolver-cu12
254
+ # nvidia-cusparse-cu12
255
+ # torch
256
+ nvidia-nvshmem-cu12==3.3.20
257
+ # via torch
258
+ nvidia-nvtx-cu12==12.8.90
259
+ # via torch
260
+ orjson==3.11.4
261
+ # via gradio
262
+ packaging==25.0
263
+ # via
264
+ # accelerate
265
+ # auto-round
266
+ # datasets
267
+ # evaluate
268
+ # gradio
269
+ # gradio-client
270
+ # huggingface-hub
271
+ # peft
272
+ # safetensors
273
+ # transformers
274
+ # typepy
275
+ pandas==2.3.3
276
+ # via
277
+ # datasets
278
+ # evaluate
279
+ # gradio
280
+ pathvalidate==3.3.1
281
+ # via pytablewriter
282
+ peft==0.18.0
283
+ # via lm-eval
284
+ pillow==11.3.0
285
+ # via
286
+ # auto-round
287
+ # gradio
288
+ # llmcompressor
289
+ portalocker==3.2.0
290
+ # via sacrebleu
291
+ propcache==0.4.1
292
+ # via
293
+ # aiohttp
294
+ # yarl
295
+ psutil==7.1.3
296
+ # via
297
+ # accelerate
298
+ # peft
299
+ py-cpuinfo==9.0.0
300
+ # via auto-round
301
+ pyarrow==22.0.0
302
+ # via datasets
303
+ pybind11==3.0.1
304
+ # via lm-eval
305
+ pycparser==2.23
306
+ # via cffi
307
+ pydantic==2.12.3
308
+ # via
309
+ # compressed-tensors
310
+ # fastapi
311
+ # gradio
312
+ pydantic-core==2.41.4
313
+ # via pydantic
314
+ pydub==0.25.1
315
+ # via gradio
316
+ pygments==2.19.2
317
+ # via rich
318
+ pytablewriter==1.2.1
319
+ # via lm-eval
320
+ python-dateutil==2.9.0.post0
321
+ # via
322
+ # pandas
323
+ # typepy
324
+ python-multipart==0.0.20
325
+ # via gradio
326
+ pytz==2025.2
327
+ # via
328
+ # pandas
329
+ # typepy
330
+ pyyaml==6.0.3
331
+ # via
332
+ # accelerate
333
+ # datasets
334
+ # gradio
335
+ # huggingface-hub
336
+ # llmcompressor
337
+ # peft
338
+ # transformers
339
+ regex==2025.11.3
340
+ # via
341
+ # nltk
342
+ # sacrebleu
343
+ # transformers
344
+ requests==2.32.5
345
+ # via
346
+ # datasets
347
+ # evaluate
348
+ # llmcompressor
349
+ # transformers
350
+ rich==14.2.0
351
+ # via typer
352
+ rouge-score==0.1.2
353
+ # via lm-eval
354
+ ruff==0.14.7
355
+ # via gradio
356
+ sacrebleu==2.5.1
357
+ # via lm-eval
358
+ safehttpx==0.1.7
359
+ # via gradio
360
+ safetensors==0.7.0
361
+ # via
362
+ # accelerate
363
+ # huggingface-hub
364
+ # peft
365
+ # transformers
366
+ scikit-learn==1.7.2
367
+ # via lm-eval
368
+ scipy==1.16.3
369
+ # via scikit-learn
370
+ semantic-version==2.10.0
371
+ # via gradio
372
+ sentencepiece==0.2.1
373
+ # via auto-round
374
+ setuptools==80.9.0
375
+ # via
376
+ # pytablewriter
377
+ # torch
378
+ shellingham==1.5.4
379
+ # via
380
+ # huggingface-hub
381
+ # typer
382
+ six==1.17.0
383
+ # via
384
+ # python-dateutil
385
+ # rouge-score
386
+ sqlitedict==2.1.0
387
+ # via lm-eval
388
+ starlette==0.50.0
389
+ # via
390
+ # fastapi
391
+ # gradio
392
+ sympy==1.14.0
393
+ # via torch
394
+ tabledata==1.3.4
395
+ # via pytablewriter
396
+ tabulate==0.9.0
397
+ # via sacrebleu
398
+ tcolorpy==0.1.7
399
+ # via pytablewriter
400
+ threadpoolctl==3.6.0
401
+ # via
402
+ # auto-round
403
+ # scikit-learn
404
+ tokenizers==0.22.1
405
+ # via transformers
406
+ tomlkit==0.13.3
407
+ # via gradio
408
+ torch==2.9.1
409
+ # via
410
+ # llm-compressor-my-repo (pyproject.toml)
411
+ # accelerate
412
+ # auto-round
413
+ # compressed-tensors
414
+ # huggingface-hub
415
+ # llmcompressor
416
+ # lm-eval
417
+ # peft
418
+ # safetensors
419
+ # transformers
420
+ tqdm==4.67.1
421
+ # via
422
+ # auto-round
423
+ # datasets
424
+ # evaluate
425
+ # huggingface-hub
426
+ # llmcompressor
427
+ # nltk
428
+ # peft
429
+ # tqdm-multiprocess
430
+ # transformers
431
+ tqdm-multiprocess==0.0.11
432
+ # via lm-eval
433
+ transformers @ git+https://github.com/huggingface/transformers.git@cac0a28c83cf87b7a05495de3177099c635ba852
434
+ # via
435
+ # llm-compressor-my-repo (pyproject.toml)
436
+ # auto-round
437
+ # compressed-tensors
438
+ # llmcompressor
439
+ # lm-eval
440
+ # peft
441
+ triton==3.5.1
442
+ # via torch
443
+ typepy==1.3.4
444
+ # via
445
+ # dataproperty
446
+ # pytablewriter
447
+ # tabledata
448
+ typer==0.20.0
449
+ # via gradio
450
+ typer-slim==0.20.0
451
+ # via
452
+ # huggingface-hub
453
+ # transformers
454
+ typing-extensions==4.15.0
455
+ # via
456
+ # aiosignal
457
+ # anyio
458
+ # fastapi
459
+ # gradio
460
+ # gradio-client
461
+ # huggingface-hub
462
+ # pydantic
463
+ # pydantic-core
464
+ # starlette
465
+ # torch
466
+ # typer
467
+ # typer-slim
468
+ # typing-inspection
469
+ typing-inspection==0.4.2
470
+ # via pydantic
471
+ tzdata==2025.2
472
+ # via pandas
473
+ urllib3==2.5.0
474
+ # via requests
475
+ uvicorn==0.38.0
476
+ # via gradio
477
+ websockets==15.0.1
478
+ # via gradio-client
479
+ word2number==1.1
480
+ # via lm-eval
481
+ xxhash==3.6.0
482
+ # via
483
+ # datasets
484
+ # evaluate
485
+ yarl==1.22.0
486
+ # via aiohttp
487
+ zstandard==0.25.0
488
+ # via lm-eval
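Because every dependency is now pinned to an exact version, drift between this lockfile and the running environment is easy to detect. The sketch below is one way to report what is actually installed; the package names are taken from the `# via llm-compressor-my-repo (pyproject.toml)` annotations above, and the use of importlib.metadata assumes Python 3.8+.

# Reports the installed version of each direct dependency named in the
# lockfile annotations above; "not installed" flags a missing package.
from importlib.metadata import PackageNotFoundError, version

DIRECT_DEPS = [
    "accelerate", "datasets", "gradio", "gradio-huggingfacehub-search",
    "hf-xet", "huggingface-hub", "llmcompressor", "torch", "transformers",
]

for name in DIRECT_DEPS:
    try:
        print(f"{name}=={version(name)}")
    except PackageNotFoundError:
        print(f"{name}: not installed")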
tests/test_app.py ADDED
@@ -0,0 +1,157 @@
import pytest
from unittest.mock import MagicMock, patch
from app import get_quantization_recipe, compress_and_upload
import gradio as gr
from transformers import AutoModelForCausalLM
from huggingface_hub import HfApi, ModelCard, whoami
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier, GPTQModifier
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping

# Mock external dependencies for compress_and_upload
@pytest.fixture
def mock_hf_api():
    with patch('app.HfApi') as mock_api:
        mock_api_instance = mock_api.return_value
        mock_api_instance.create_repo.return_value = "https://huggingface.co/test_user/test_model-AWQ"
        yield mock_api_instance

@pytest.fixture
def mock_whoami():
    with patch('app.whoami') as mock_whoami_func:
        mock_whoami_func.return_value = {"name": "test_user"}
        yield mock_whoami_func

@pytest.fixture
def mock_auto_model_for_causal_lm():
    with patch('app.AutoModelForCausalLM') as mock_model_class:
        mock_model_instance = MagicMock()
        mock_model_instance.config.architectures = ["LlamaForCausalLM"]
        mock_model_class.from_pretrained.return_value = mock_model_instance
        yield mock_model_class

@pytest.fixture
def mock_oneshot():
    with patch('app.oneshot') as mock_oneshot_func:
        yield mock_oneshot_func

@pytest.fixture
def mock_model_card():
    with patch('app.ModelCard') as mock_card_class:
        mock_card_instance = MagicMock()
        mock_card_class.return_value = mock_card_instance
        yield mock_card_class

@pytest.fixture
def mock_gr_oauth_token():
    mock_token = MagicMock(spec=gr.OAuthToken)
    mock_token.token = "test_token"
    return mock_token

# --- Test get_quantization_recipe ---
def test_get_quantization_recipe_awq():
    recipe = get_quantization_recipe("AWQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], AWQModifier)

def test_get_quantization_recipe_gptq():
    recipe = get_quantization_recipe("GPTQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)

def test_get_quantization_recipe_gptq_mistral():
    recipe = get_quantization_recipe("GPTQ", "MistralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MistralDecoderLayer"]

def test_get_quantization_recipe_gptq_mixtral():
    recipe = get_quantization_recipe("GPTQ", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MixtralDecoderLayer"]

def test_get_quantization_recipe_fp8():
    recipe = get_quantization_recipe("FP8", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert recipe[0].ignore == ["lm_head"]

def test_get_quantization_recipe_fp8_mixtral():
    recipe = get_quantization_recipe("FP8", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert "re:.*block_sparse_moe.gate" in recipe[0].ignore

def test_get_quantization_recipe_unsupported():
    with pytest.raises(ValueError, match="Unsupported quantization method: INVALID"):
        get_quantization_recipe("INVALID", "LlamaForCausalLM")

# --- Test compress_and_upload ---
def test_compress_and_upload_no_model_id(mock_gr_oauth_token):
    with pytest.raises(gr.Error, match="Please select a model from the search bar."):
        compress_and_upload("", "AWQ", mock_gr_oauth_token)

def test_compress_and_upload_no_oauth_token():
    with pytest.raises(gr.Error, match="Authentication error. Please log in to continue."):
        compress_and_upload("test_model", "AWQ", None)

def test_compress_and_upload_success(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"
    result = compress_and_upload(model_id, quant_method, mock_gr_oauth_token)

    mock_whoami.assert_called_once_with(token="test_token")
    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id, torch_dtype="auto", device_map=None, token="test_token"
    )
    mock_oneshot.assert_called_once()
    assert mock_oneshot.call_args[1]["model"] == mock_auto_model_for_causal_lm.from_pretrained.return_value
    assert mock_oneshot.call_args[1]["recipe"] is not None
    assert mock_oneshot.call_args[1]["output_dir"] == f"test_model-{quant_method}"

    mock_hf_api.create_repo.assert_called_once_with(
        repo_id=f"test_user/test_model-{quant_method}", exist_ok=True
    )
    mock_hf_api.upload_folder.assert_called_once_with(
        folder_path=f"test_model-{quant_method}",
        repo_id=f"test_user/test_model-{quant_method}",
        commit_message=f"Upload {quant_method} compressed model",
    )
    mock_model_card.assert_called_once()
    mock_model_card.return_value.push_to_hub.assert_called_once_with(
        f"test_user/test_model-{quant_method}", token="test_token"
    )

    assert "✅ Success!" in result
    assert "https://huggingface.co/test_user/test_model-AWQ" in result

def test_compress_and_upload_model_no_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = []
    with pytest.raises(gr.Error, match="Could not determine model architecture."):
        compress_and_upload("test_model", "AWQ", mock_gr_oauth_token)

def test_compress_and_upload_generic_exception(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_whoami.side_effect = Exception("Network error")
    result = compress_and_upload("test_model", "AWQ", mock_gr_oauth_token)
    assert "❌ ERROR" in result
    assert "Network error" in result
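These tests import app at module level, so the repository root has to be importable when pytest is invoked from somewhere other than the project root. One way to make that explicit is a tests/conftest.py along the following lines; the file is a hypothetical sketch for illustration, not part of this commit.

# tests/conftest.py (hypothetical, not part of this commit): put the repository
# root on sys.path so `from app import ...` resolves regardless of the cwd.
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))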