Leesn465 committed on
Commit
d26b332
·
verified ·
1 Parent(s): 3ae9a70

Update util/keywordExtract.py

Browse files
Files changed (1) hide show
  1. util/keywordExtract.py +34 -22
util/keywordExtract.py CHANGED
@@ -21,7 +21,6 @@ def load_company_list(file_path='상장법인목록.xls'):
21
  summary_tokenizer = PreTrainedTokenizerFast.from_pretrained("gogamza/kobart-summarization")
22
  summary_model = BartForConditionalGeneration.from_pretrained("gogamza/kobart-summarization")
23
 
24
-
25
  def summarize_kobart(text):
26
  input_ids = summary_tokenizer.encode(text, return_tensors="pt")
27
  summary_ids = summary_model.generate(
@@ -58,34 +57,46 @@ kw_model = KeyBERT(model=kobert_embedder)
58
 
59
  STOPWORDS_FILE = "stopwords-ko.txt"
60
 
61
- # ✅ 감성 분석용 모델 (예: kykim/bert-kor-base 사용 가정)
62
- sentiment_model_name = "kykim/bert-kor-base"
63
  bert_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
64
  bert_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
- bert_model = bert_model.to(device)
67
 
68
- def classify_emotion(text):
69
- tokens = bert_tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(device)
70
- with torch.no_grad():
71
- prediction = bert_model(**tokens)
72
- prediction = F.softmax(prediction.logits, dim=1)
73
- output = prediction.argmax(dim=1).item()
74
- labels = ["부정적", "중립적", "긍정적"]
75
- return labels[output]
76
 
77
- sentiment_tokenizer = BertTokenizer.from_pretrained("kykim/bert-kor-base")
78
- sentiment_model = BertForSequenceClassification.from_pretrained("kykim/bert-kor-base")
79
 
80
  def analyze_sentiment(text):
81
- inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
 
 
 
 
 
 
 
82
  with torch.no_grad():
83
- outputs = sentiment_model(**inputs)
84
- probs = F.softmax(outputs.logits, dim=1)
85
- return {
86
- "positive": round(float(probs[0][1]), 4),
87
- "negative": round(float(probs[0][0]), 4)
88
- }
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  def get_or_download_stopwords():
91
  # 1. 파일이 있으면 읽어서 반환
@@ -125,8 +136,9 @@ def resultKeyword(content) :
125
 
126
  # 불용어 처리 후 요약 텍스트에서 키워드 추출
127
  filtered_summary = remove_stopwords(summary, korean_stopwords)
 
128
  keywords = kw_model.extract_keywords(
129
- filtered_summary,
130
  keyphrase_ngram_range=(1, 2), # 복합명사 유지 가능
131
  stop_words=None,
132
  top_n=5
 
21
  summary_tokenizer = PreTrainedTokenizerFast.from_pretrained("gogamza/kobart-summarization")
22
  summary_model = BartForConditionalGeneration.from_pretrained("gogamza/kobart-summarization")
23
 
 
24
  def summarize_kobart(text):
25
  input_ids = summary_tokenizer.encode(text, return_tensors="pt")
26
  summary_ids = summary_model.generate(
 
57
 
58
  STOPWORDS_FILE = "stopwords-ko.txt"
59
 
60
+ # ✅ 감성 분석용 모델 (예: snunlp/KR-FinBert-SC 사용)
61
+ sentiment_model_name = "snunlp/KR-FinBert-SC"
62
  bert_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
63
  bert_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
64
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
65
 
66
+ sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name) # 👈 tokenizer 정의
67
+ sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
 
 
 
 
 
 
68
 
69
+ sentiment_model = sentiment_model.to(device)
 
70
 
71
  def analyze_sentiment(text):
72
+ inputs = sentiment_tokenizer(
73
+ text,
74
+ return_tensors="pt",
75
+ truncation=True,
76
+ padding=True,
77
+ max_length=512 # 👈 추가
78
+ ).to(device)
79
+
80
+ # 모델 추론
81
  with torch.no_grad():
82
+ outputs = bert_model(**inputs)
83
+ logits = outputs.logits
84
+ #확률 계산
85
+ print("logits:", logits)
86
+ print("logits.shape:", logits.shape)
87
+
88
+ probs = F.softmax(logits, dim=1)[0]
89
+ #라벨링
90
+ label_idx = torch.argmax(probs).item()
91
+ labels = ["부정적", "중립적", "긍정적"]
92
+ label = labels[label_idx]
93
+
94
+ return {
95
+ "negative": round(float(probs[0]), 4),
96
+ "neutral": round(float(probs[1]), 4),
97
+ "positive": round(float(probs[2]), 4),
98
+ }
99
+
100
 
101
  def get_or_download_stopwords():
102
  # 1. 파일이 있으면 읽어서 반환
 
136
 
137
  # 불용어 처리 후 요약 텍스트에서 키워드 추출
138
  filtered_summary = remove_stopwords(summary, korean_stopwords)
139
+ filtered_content = remove_stopwords(content, korean_stopwords)
140
  keywords = kw_model.extract_keywords(
141
+ filtered_content,
142
  keyphrase_ngram_range=(1, 2), # 복합명사 유지 가능
143
  stop_words=None,
144
  top_n=5