import os import torch from transformers import AutoTokenizer, T5ForSequenceClassification from typing import Dict, List, Any class EndpointHandler: """ HuggingFace Inference Endpoint Handler for Java Vulnerability Detection CodeT5 기반 분류 모델 (LoRA fine-tuned) """ def __init__(self, path="."): """ 모델과 토크나이저를 초기화합니다. Args: path (str): 모델이 저장된 경로 (HuggingFace Hub에서 자동으로 설정됨) """ print(f"🚀 Loading Java Vulnerability Detection Model from {path}") # 디바이스 설정 self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f"📍 Device: {self.device}") # 토크나이저 로드 self.tokenizer = AutoTokenizer.from_pretrained(path) # T5ForSequenceClassification 모델 로드 self.model = T5ForSequenceClassification.from_pretrained( path, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ) # 모델을 평가 모드로 설정하고 디바이스로 이동 self.model.to(self.device) self.model.eval() print("✅ Model loaded successfully!") def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: """ 메인 추론 메서드 (HuggingFace Inference API가 호출) Args: data (dict): 입력 데이터 - "inputs" (str): Java 코드 또는 - "code" (str): Java 코드 Returns: list: 예측 결과 리스트 """ # 1. 전처리 inputs = self.preprocess(data) # 2. 추론 outputs = self.inference(inputs) # 3. 후처리 result = self.postprocess(outputs) return result def preprocess(self, request: Dict[str, Any]) -> Dict[str, torch.Tensor]: """ 입력 데이터를 전처리합니다. Args: request (dict): API 요청 데이터 Returns: dict: 토크나이즈된 입력 텐서 """ # 입력 텍스트 추출 if isinstance(request, dict): # "inputs" 또는 "code" 키에서 Java 코드 추출 code = request.get("inputs") or request.get("code") elif isinstance(request, list) and len(request) > 0: code = request[0].get("inputs") or request[0].get("code") elif isinstance(request, str): code = request else: raise ValueError( "Invalid request format. Expected {'inputs': 'Java code here'} " "or {'code': 'Java code here'}" ) if not code: raise ValueError("No code provided in request") # 프롬프트 템플릿 적용 input_text = f"Is this Java code vulnerable?:\n{code}" # 토크나이징 inputs = self.tokenizer( input_text, max_length=512, truncation=True, padding="max_length", return_tensors="pt" ) # 디바이스로 이동 inputs = {k: v.to(self.device) for k, v in inputs.items()} return inputs def inference(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor: """ 모델 추론을 수행합니다. Args: inputs (dict): 전처리된 입력 텐서 Returns: torch.Tensor: 모델 출력 로짓 """ with torch.no_grad(): outputs = self.model(**inputs) logits = outputs.logits return logits def postprocess(self, logits: torch.Tensor) -> List[Dict[str, Any]]: """ 모델 출력을 사람이 읽을 수 있는 형태로 변환합니다. Args: logits (torch.Tensor): 모델 출력 로짓 Returns: list: 예측 결과 리스트 """ # 로짓 처리 (단일 출력 vs 다중 클래스) if logits.shape[-1] == 1: # Binary classification with single output prob = torch.sigmoid(logits).item() predicted_class = 1 if prob > 0.5 else 0 confidence = prob if predicted_class == 1 else (1 - prob) probabilities = { "LABEL_0": 1 - prob, "LABEL_1": prob } else: # Multi-class classification probs = torch.softmax(logits, dim=1)[0] predicted_class = torch.argmax(logits, dim=1).item() confidence = probs[predicted_class].item() probabilities = { f"LABEL_{i}": probs[i].item() for i in range(len(probs)) } # 레이블 매핑 label_map = { 0: "safe", 1: "vulnerable" } # 결과 포맷팅 result = { "label": label_map.get(predicted_class, f"LABEL_{predicted_class}"), "score": confidence, "probabilities": probabilities, "details": { "is_vulnerable": predicted_class == 1, "confidence_percentage": f"{confidence * 100:.2f}%", "safe_probability": probabilities.get("LABEL_0", 0), "vulnerable_probability": probabilities.get("LABEL_1", 0) } } return [result] # 로컬 테스트용 코드 if __name__ == "__main__": # 로컬에서 테스트할 때 사용 handler = EndpointHandler(path=".") # 테스트 케이스 test_code = """ import java.sql.*; public class SQLInjectionVulnerable { public void getUser(String userInput) { String query = "SELECT * FROM users WHERE username = '" + userInput + "'"; Statement statement = connection.createStatement(); ResultSet resultSet = statement.executeQuery(query); } } """ # 추론 실행 request = {"inputs": test_code} result = handler(request) print("\n📊 Test Result:") print(result)