mangsense committed on
Commit
b7062c0
·
verified ·
1 Parent(s): 140fff0

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +197 -0
handler.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import AutoTokenizer, T5ForSequenceClassification
4
+ from typing import Dict, List, Any
5
+
6
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for Java vulnerability detection.

    Wraps a ``T5ForSequenceClassification`` checkpoint (described by the
    original author as a CodeT5-based classifier, LoRA fine-tuned) and exposes
    the pipeline ``preprocess`` -> ``inference`` -> ``postprocess`` through
    ``__call__``.
    """

    # Token budget per snippet; longer inputs are truncated by the tokenizer.
    MAX_LENGTH = 512

    # Class index -> readable label. Indices not listed here are reported as
    # "LABEL_<index>".
    LABEL_MAP = {0: "safe", 1: "vulnerable"}

    def __init__(self, path="."):
        """Load the tokenizer and the classification model.

        Args:
            path (str): Directory holding the model and tokenizer files
                (per the original author's note, supplied automatically when
                running on the Hugging Face Hub).
        """
        print(f"🚀 Loading Java Vulnerability Detection Model from {path}")

        # Prefer the GPU when one is visible to torch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"📍 Device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Half precision on GPU, full precision on CPU.
        self.model = T5ForSequenceClassification.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )

        # Move to the target device and switch to evaluation mode.
        self.model.to(self.device)
        self.model.eval()

        print("✅ Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the full pipeline on one request.

        Args:
            data (dict): Request payload carrying the Java code under the
                ``"inputs"`` key or the ``"code"`` key. The value may be a
                single string or a list of strings (one prediction each).

        Returns:
            list: One prediction dict per submitted snippet.

        Raises:
            ValueError: If the payload has no usable code (see ``preprocess``).
        """
        inputs = self.preprocess(data)
        outputs = self.inference(inputs)
        return self.postprocess(outputs)

    @staticmethod
    def _extract_code(request: Any) -> Any:
        """Pull the raw code value out of a request without validating it."""
        # A top-level list is reduced to its first entry; the remaining
        # entries are ignored, exactly as before.
        if isinstance(request, list) and len(request) > 0:
            request = request[0]

        if isinstance(request, dict):
            # "inputs" wins; "code" is the fallback when "inputs" is
            # missing or empty.
            return request.get("inputs") or request.get("code")
        if isinstance(request, str):
            return request

        raise ValueError(
            "Invalid request format. Expected {'inputs': 'Java code here'} "
            "or {'code': 'Java code here'}"
        )

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Turn a request into tokenized model inputs.

        Args:
            request: A dict with an ``"inputs"`` or ``"code"`` entry, a bare
                string, or a non-empty list whose first element is one of
                those. The code value itself may be a string or a list/tuple
                of strings (tokenized together as a batch).

        Returns:
            dict: Tokenizer output tensors moved to ``self.device``.

        Raises:
            ValueError: If the request shape is not recognised, no code is
                present, a snippet is blank, or a snippet is not a string.
        """
        code = self._extract_code(request)
        if not code:
            raise ValueError("No code provided in request")

        is_batch = isinstance(code, (list, tuple))
        snippets = list(code) if is_batch else [code]
        for snippet in snippets:
            # Reject non-strings explicitly: formatting them into the prompt
            # would silently feed their repr() to the model.
            if not isinstance(snippet, str):
                raise ValueError(
                    f"Code must be a string, got {type(snippet).__name__}"
                )
            if not snippet.strip():
                raise ValueError("No code provided in request")

        # Wrap every snippet in the fixed prompt template.
        prompts = [f"Is this Java code vulnerable?:\n{snippet}" for snippet in snippets]

        encoded = self.tokenizer(
            prompts if is_batch else prompts[0],
            max_length=self.MAX_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        return {name: tensor.to(self.device) for name, tensor in encoded.items()}

    def inference(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model forward pass with gradient tracking disabled.

        Args:
            inputs (dict): Tensors produced by ``preprocess``.

        Returns:
            torch.Tensor: The ``logits`` attribute of the model output.
        """
        with torch.no_grad():
            return self.model(**inputs).logits

    def _describe_row(self, row: torch.Tensor) -> Dict[str, Any]:
        """Convert the logits of a single example into a prediction dict."""
        if row.numel() == 1:
            # Single-logit binary head: sigmoid gives P(class 1).
            prob = torch.sigmoid(row).item()
            predicted_class = 1 if prob > 0.5 else 0
            confidence = prob if predicted_class == 1 else (1 - prob)
            probabilities = {"LABEL_0": 1 - prob, "LABEL_1": prob}
        else:
            # One logit per class: softmax over the class dimension.
            probs = torch.softmax(row, dim=-1)
            predicted_class = int(torch.argmax(row).item())
            confidence = probs[predicted_class].item()
            probabilities = {
                f"LABEL_{index}": value
                for index, value in enumerate(probs.tolist())
            }

        return {
            "label": self.LABEL_MAP.get(predicted_class, f"LABEL_{predicted_class}"),
            "score": confidence,
            "probabilities": probabilities,
            "details": {
                "is_vulnerable": predicted_class == 1,
                "confidence_percentage": f"{confidence * 100:.2f}%",
                "safe_probability": probabilities.get("LABEL_0", 0),
                "vulnerable_probability": probabilities.get("LABEL_1", 0),
            },
        }

    def postprocess(self, logits: torch.Tensor) -> List[Dict[str, Any]]:
        """Convert raw logits into human-readable predictions.

        Args:
            logits (torch.Tensor): Model output, one row per example. The
                last dimension is either 1 (single-logit binary head) or the
                number of classes.

        Returns:
            list: One dict per row with ``label``, ``score``,
            ``probabilities`` and a ``details`` sub-dict.
        """
        # The model runs in float16 on CUDA; compute probabilities in float32
        # so they are not rounded to half precision.
        scores = logits.detach().float()
        if scores.dim() < 2:
            scores = scores.reshape(1, -1)

        return [self._describe_row(row) for row in scores]
173
+
174
+
175
# Local smoke test; runs only when this file is executed directly.
if __name__ == "__main__":
    # Loads the model and tokenizer from the current directory.
    handler = EndpointHandler(path=".")

    # Sample input: a query built by string concatenation from user input
    # (a textbook SQL-injection pattern).
    test_code = """
    import java.sql.*;
    public class SQLInjectionVulnerable {
        public void getUser(String userInput) {
            String query = "SELECT * FROM users WHERE username = '" + userInput + "'";
            Statement statement = connection.createStatement();
            ResultSet resultSet = statement.executeQuery(query);
        }
    }
    """

    # Run one inference through the same entry point the endpoint uses.
    request = {"inputs": test_code}
    result = handler(request)

    print("\n📊 Test Result:")
    print(result)