"""Train and evaluate RamModel: a transformer-based head that predicts
pairwise object relations from precomputed object features."""
import os
import time
from datetime import timedelta

import numpy as np
import torch
import torch.nn as nn

from mmengine.config import Config
from mmengine.utils import ProgressBar
from transformers import AutoConfig, AutoModel


class RamDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, is_train=True, num_relation_classes=56):
        super().__init__()
        self.num_relation_classes = num_relation_classes
        data = np.load(data_path, allow_pickle=True)
        self.samples = data["arr_0"]
        sample_num = self.samples.size
        # Keep only the samples whose split matches is_train.
        self.sample_idx_list = []
        for idx in range(sample_num):
            if self.samples[idx]["is_train"] == is_train:
                self.sample_idx_list.append(idx)

    def __getitem__(self, idx):
        sample = self.samples[self.sample_idx_list[idx]]
        object_num = sample["feat"].shape[0]
        embedding = torch.from_numpy(sample["feat"])
        gt_rels = sample["relations"]
        rel_target = self._get_target(object_num, gt_rels)
        return embedding, rel_target, gt_rels

    def __len__(self):
        return len(self.sample_idx_list)

    def _get_target(self, object_num, gt_rels):
        # Multi-hot target of shape (num_relation_classes, object_num, object_num).
        rel_target = torch.zeros([self.num_relation_classes, object_num, object_num])
        for ii, jj, cls_relationship in gt_rels:
            rel_target[cls_relationship, ii, jj] = 1
        return rel_target
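
    # Illustrative example (hypothetical values): with object_num=3 and
    # gt_rels=[[0, 2, 5]], _get_target returns a (num_relation_classes, 3, 3)
    # tensor that is all zeros except rel_target[5, 0, 2] == 1, i.e. subject 0
    # relates to object 2 under relation class 5.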


class RamModel(nn.Module):
    def __init__(
        self,
        pretrained_model_name_or_path,
        load_pretrained_weights=True,
        num_transformer_layer=2,
        input_feature_size=256,
        output_feature_size=768,
        cls_feature_size=512,
        num_relation_classes=56,
        pred_type="attention",
        loss_type="bce",
    ):
        super().__init__()

        self.cls_feature_size = cls_feature_size
        self.num_relation_classes = num_relation_classes
        self.pred_type = pred_type
        self.loss_type = loss_type

        # Project input object features to the transformer hidden size, and
        # re-project its output, each followed by LayerNorm.
        self.fc_input = nn.Sequential(
            nn.Linear(input_feature_size, output_feature_size),
            nn.LayerNorm(output_feature_size),
        )
        self.fc_output = nn.Sequential(
            nn.Linear(output_feature_size, output_feature_size),
            nn.LayerNorm(output_feature_size),
        )

        if load_pretrained_weights:
            self.model = AutoModel.from_pretrained(pretrained_model_name_or_path)
        else:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            self.model = AutoModel.from_config(config)
        # Optionally truncate the encoder to its first num_transformer_layer
        # layers; pass "all" to keep the full stack.
        if isinstance(num_transformer_layer, int):
            self.model.encoder.layer = self.model.encoder.layer[:num_transformer_layer]

        # Per-relation subject/object projections: each token gets one
        # cls_feature_size-dim embedding per relation class.
        self.cls_sub = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
        self.cls_obj = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)

        if self.loss_type == "bce":
            self.bce_loss = nn.BCEWithLogitsLoss()
        elif self.loss_type == "multi_label_ce":
            print("Use Multi Label Cross Entropy Loss.")

    def forward(self, embeds, attention_mask=None):
        """
        embeds: (batch_size, token_num, feature_size)
        attention_mask: (batch_size, token_num)
        """
        embeds = self.fc_input(embeds)

        # Constant position ids: every token shares position 1, so the encoder
        # treats the objects as an unordered set.
        position_ids = torch.ones([1, embeds.shape[1]], dtype=torch.long, device=embeds.device)
        outputs = self.model(inputs_embeds=embeds, attention_mask=attention_mask, position_ids=position_ids)
        embeds = outputs["last_hidden_state"]

        embeds = self.fc_output(embeds)

        batch_size, token_num, feature_size = embeds.shape
        sub_embeds = self.cls_sub(embeds).reshape([batch_size, token_num, self.num_relation_classes, self.cls_feature_size]).permute([0, 2, 1, 3])
        obj_embeds = self.cls_obj(embeds).reshape([batch_size, token_num, self.num_relation_classes, self.cls_feature_size]).permute([0, 2, 1, 3])
        if self.pred_type == "attention":
            cls_pred = sub_embeds @ torch.transpose(obj_embeds, 2, 3) / self.cls_feature_size**0.5
        elif self.pred_type == "einsum":
            cls_pred = torch.einsum("nrsc,nroc->nrso", sub_embeds, obj_embeds)
        return cls_pred
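
    # Both pred_type branches compute the same bilinear score
    # cls_pred[n, r, s, o] = <sub_embeds[n, r, s], obj_embeds[n, r, o]>;
    # the "attention" variant additionally divides by sqrt(cls_feature_size),
    # as in scaled dot-product attention.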

    def loss(self, pred, target, attention_mask):
        loss_dict = dict()
        batch_size, relation_num, _, _ = pred.shape

        # Mask out padded tokens: only pairs among the first n valid objects of
        # each sample contribute; padded positions get a large negative logit.
        mask = torch.zeros_like(pred)
        for idx in range(batch_size):
            n = torch.sum(attention_mask[idx]).to(torch.int)
            mask[idx, :, :n, :n] = 1
        pred = pred * mask - 9999 * (1 - mask)

        if self.loss_type == "bce":
            loss = self.bce_loss(pred, target)
        elif self.loss_type == "multi_label_ce":
            # Move the relation axis first, then flatten, so each relation
            # class becomes one multi-label classification problem.
            input_tensor = torch.permute(pred, (1, 0, 2, 3)).reshape([relation_num, -1])
            target_tensor = torch.permute(target, (1, 0, 2, 3)).reshape([relation_num, -1])
            loss = self.multilabel_categorical_crossentropy(target_tensor, input_tensor)
            # Re-weight each relation class by its loss relative to the largest.
            weight = loss / loss.max()
            loss = (loss * weight).mean()
        loss_dict["loss"] = loss

        recall_20 = get_recall_N(pred, target, object_num=20)
        loss_dict["recall@20"] = recall_20
        return loss_dict

    def multilabel_categorical_crossentropy(self, y_true, y_pred):
        """
        Multi-label categorical cross-entropy from https://kexue.fm/archives/7359:
        two logsumexp terms push positive-class logits above 0 and negative-class
        logits below 0; the 9999 offsets exclude the opposite class set from
        each term.
        """
        y_pred = (1 - 2 * y_true) * y_pred
        y_pred_neg = y_pred - y_true * 9999
        y_pred_pos = y_pred - (1 - y_true) * 9999
        zeros = torch.zeros_like(y_pred[..., :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return neg_loss + pos_loss
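
    # Worked toy example (illustrative): with y_true = [1, 0], the loss is near
    # zero for logits y_pred = [+5, -5] (about 0.013) but large for
    # y_pred = [-5, +5] (about 10.0), since the two logsumexp terms reward
    # logits that sit on the correct side of 0.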


def get_recall_N(y_pred, y_true, object_num=20):
    """
    Recall@N averaged over the batch, in percent.

    y_pred: [batch_size, 56, object_num, object_num]
    y_true: [batch_size, 56, object_num, object_num]
    """
    device = y_pred.device
    recall_list = []

    for idx in range(len(y_true)):
        sample_y_true = []
        sample_y_pred = []

        # Recover the ground-truth triplets of this sample from the top-k
        # entries of the flattened target.
        _, topk_indices = torch.topk(y_true[idx : idx + 1].reshape([-1]), k=object_num)
        for index in topk_indices:
            pred_cls = index // (y_true.shape[2] ** 2)
            index_subject_object = index % (y_true.shape[2] ** 2)
            pred_subject = index_subject_object // y_true.shape[2]
            pred_object = index_subject_object % y_true.shape[2]
            if y_true[idx, pred_cls, pred_subject, pred_object] == 0:
                continue
            sample_y_true.append([pred_subject, pred_object, pred_cls])

        # Take the top-k predicted triplets.
        _, topk_indices = torch.topk(y_pred[idx : idx + 1].reshape([-1]), k=object_num)
        for index in topk_indices:
            pred_cls = index // (y_pred.shape[2] ** 2)
            index_subject_object = index % (y_pred.shape[2] ** 2)
            pred_subject = index_subject_object // y_pred.shape[2]
            pred_object = index_subject_object % y_pred.shape[2]
            sample_y_pred.append([pred_subject, pred_object, pred_cls])

        recall = len([x for x in sample_y_pred if x in sample_y_true]) / (len(sample_y_true) + 1e-8)
        recall_list.append(recall)

    recall = torch.tensor(recall_list).to(device).mean() * 100
    return recall
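
# Index decoding note (illustrative numbers): flattening a (56, N, N) tensor
# maps a triplet to flat index = cls * N * N + subject * N + object. For
# N = 10, flat index 1234 decodes to cls = 12, subject = 3, object = 4, which
# is what the integer division/modulo chain above recovers.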


class RamTrainer(object):
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._build_dataset()
        self._build_dataloader()
        self._build_model()
        self._build_optimizer()
        self._build_lr_scheduler()

    def _build_dataset(self):
        self.dataset = RamDataset(**self.config.dataset)

    def _build_dataloader(self):
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.config.dataloader.batch_size,
            shuffle=self.config.dataset.is_train,
        )

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            self.model.load_state_dict(torch.load(self.config.load_from))
        self.model.train()

    def _build_optimizer(self):
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
            eps=self.config.optim.eps,
            betas=self.config.optim.betas,
        )

    def _build_lr_scheduler(self):
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.config.optim.lr_scheduler.step,
            gamma=self.config.optim.lr_scheduler.gamma,
        )
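
    # Schedule note: with the __main__ config below (lr=1e-4, milestones
    # [6, 10], gamma=0.1), the learning rate drops to 1e-5 after epoch 6 and
    # to 1e-6 after epoch 10, since lr_scheduler.step() runs once per epoch.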

    def train(self):
        t_start = time.time()
        running_avg_loss = 0
        for epoch_idx in range(self.config.num_epoch):
            for batch_idx, batch_data in enumerate(self.dataloader):
                batch_embeds = batch_data[0].to(torch.float32).to(self.device)
                batch_target = batch_data[1].to(torch.float32).to(self.device)
                attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
                batch_pred = self.model(batch_embeds, attention_mask)
                loss_dict = self.model.loss(batch_pred, batch_target, attention_mask)
                loss = loss_dict["loss"]
                recall_20 = loss_dict["recall@20"]
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.optim.max_norm, self.config.optim.norm_type)
                self.optimizer.step()
                running_avg_loss += loss.item()

                if batch_idx % 100 == 0:
                    t_current = time.time()
                    num_finished_step = epoch_idx * len(self.dataloader) + batch_idx + 1
                    num_to_do_step = (self.config.num_epoch - epoch_idx - 1) * len(self.dataloader) + (len(self.dataloader) - batch_idx - 1)
                    avg_speed = num_finished_step / (t_current - t_start)
                    eta = num_to_do_step / avg_speed
                    print(
                        "ETA={:0>8}, Epoch={}, Batch={}/{}, LR={}, Loss={:.4f}, RunningAvgLoss={:.4f}, Recall@20={:.2f}%".format(
                            str(timedelta(seconds=int(eta))),
                            epoch_idx + 1,
                            batch_idx,
                            len(self.dataloader),
                            self.lr_scheduler.get_last_lr()[0],
                            loss.item(),
                            running_avg_loss / num_finished_step,
                            recall_20.item(),
                        )
                    )
            self.lr_scheduler.step()
            # Save a checkpoint at the end of every epoch.
            if not os.path.exists(self.config.output_dir):
                os.makedirs(self.config.output_dir)
            save_path = os.path.join(self.config.output_dir, "epoch_{}.pth".format(epoch_idx + 1))
            print("Save epoch={} checkpoint to {}".format(epoch_idx + 1, save_path))
            torch.save(self.model.state_dict(), save_path)
        return save_path


class RamPredictor(object):
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._build_dataset()
        self._build_dataloader()
        self._build_model()

    def _build_dataset(self):
        self.dataset = RamDataset(**self.config.dataset)

    def _build_dataloader(self):
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.config.dataloader.batch_size,
            shuffle=False,
        )

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            self.model.load_state_dict(torch.load(self.config.load_from))
        self.model.eval()

    def predict(self, batch_embeds, pred_keep_num=100):
        """
        Parameters
        ----------
        batch_embeds: (batch_size=1, token_num, feature_size)
        pred_keep_num: int

        Returns
        -------
        batch_pred: (batch_size, relation_num, object_num, object_num)
        pred_rels: [[sub_id, obj_id, rel_id], ...]
        """
        if not isinstance(batch_embeds, torch.Tensor):
            batch_embeds = torch.asarray(batch_embeds)
        batch_embeds = batch_embeds.to(torch.float32).to(self.device)
        attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
        # Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            batch_pred = self.model(batch_embeds, attention_mask)
        # Suppress self-relations on the diagonal before applying sigmoid.
        for idx_i in range(batch_pred.shape[2]):
            batch_pred[:, :, idx_i, idx_i] = -9999
        batch_pred = batch_pred.sigmoid()

        # Keep the pred_keep_num highest-scoring triplets, decoded from the
        # flattened prediction tensor.
        pred_rels = []
        _, topk_indices = torch.topk(batch_pred.reshape([-1]), k=pred_keep_num)
        for index in topk_indices:
            pred_relation = index // (batch_pred.shape[2] ** 2)
            index_subject_object = index % (batch_pred.shape[2] ** 2)
            pred_subject = index_subject_object // batch_pred.shape[2]
            pred_object = index_subject_object % batch_pred.shape[2]
            pred_rels.append([pred_subject.item(), pred_object.item(), pred_relation.item()])
        return batch_pred, pred_rels
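
    # Minimal usage sketch (illustrative; assumes a trained checkpoint and
    # 256-dim object features, matching the __main__ config below):
    #   predictor = RamPredictor(config)
    #   batch_pred, pred_rels = predictor.predict(np.random.rand(1, 10, 256))
    #   # pred_rels holds [subject_id, object_id, relation_id] triplets ranked
    #   # by predicted probability.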

    def eval(self):
        sum_recall_20 = 0.0
        sum_recall_50 = 0.0
        sum_recall_100 = 0.0
        prog_bar = ProgressBar(len(self.dataloader))
        for batch_idx, batch_data in enumerate(self.dataloader):
            batch_embeds = batch_data[0]
            batch_target = batch_data[1]
            gt_rels = batch_data[2]
            batch_pred, pred_rels = self.predict(batch_embeds)
            this_recall_20 = get_recall_N(batch_pred, batch_target, object_num=20)
            this_recall_50 = get_recall_N(batch_pred, batch_target, object_num=50)
            this_recall_100 = get_recall_N(batch_pred, batch_target, object_num=100)
            sum_recall_20 += this_recall_20.item()
            sum_recall_50 += this_recall_50.item()
            sum_recall_100 += this_recall_100.item()
            prog_bar.update()
        recall_20 = sum_recall_20 / len(self.dataloader)
        recall_50 = sum_recall_50 / len(self.dataloader)
        recall_100 = sum_recall_100 / len(self.dataloader)
        metric = {
            "recall_20": recall_20,
            "recall_50": recall_50,
            "recall_100": recall_100,
        }
        return metric


if __name__ == "__main__":
    # Train.
    config = dict(
        dataset=dict(
            data_path="./data/feat_0420.npz",
            is_train=True,
            num_relation_classes=56,
        ),
        dataloader=dict(
            batch_size=4,
        ),
        model=dict(
            pretrained_model_name_or_path="bert-base-uncased",
            load_pretrained_weights=True,
            num_transformer_layer=2,
            input_feature_size=256,
            output_feature_size=768,
            cls_feature_size=512,
            num_relation_classes=56,
            pred_type="attention",
            loss_type="multi_label_ce",
        ),
        optim=dict(
            lr=1e-4,
            weight_decay=0.05,
            eps=1e-8,
            betas=(0.9, 0.999),
            max_norm=0.01,
            norm_type=2,
            lr_scheduler=dict(
                step=[6, 10],
                gamma=0.1,
            ),
        ),
        num_epoch=12,
        output_dir="./work_dirs",
        load_from=None,
    )

    config = Config(config)
    trainer = RamTrainer(config)
    last_model_path = trainer.train()

    # Evaluate the last checkpoint on the test split.
    config.dataset.is_train = False
    config.load_from = last_model_path
    predictor = RamPredictor(config)
    metric = predictor.eval()
    print(metric)