import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoConfig
from transformers.modeling_outputs import MaskedLMOutput
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn = None, None

def convert_hf_config_to_mamba(hf_config) -> MambaConfig:
    return MambaConfig(
        d_model=hf_config.d_model,
        d_intermediate=getattr(hf_config, "intermediate_size", 4 * hf_config.d_model),
        n_layer=getattr(hf_config, "n_layer", getattr(hf_config, "num_hidden_layers", 12)),
        vocab_size=hf_config.vocab_size,
        ssm_cfg=getattr(hf_config, "ssm_cfg", {}),
        attn_layer_idx=getattr(hf_config, "attn_layer_idx", []),
        attn_cfg=getattr(hf_config, "attn_cfg", {}),
        rms_norm=getattr(hf_config, "rms_norm", True),
        residual_in_fp32=getattr(hf_config, "residual_in_fp32", True),
        fused_add_norm=getattr(hf_config, "fused_add_norm", False),
        pad_vocab_size_multiple=getattr(hf_config, "pad_vocab_size_multiple", 8),
        tie_embeddings=getattr(hf_config, "tie_embeddings", False),
    )
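
# Illustrative usage (sketch only; the config path below is a placeholder, not from this repo):
#   hf_config = AutoConfig.from_pretrained("path/to/bimamba-config")
#   mamba_cfg = convert_hf_config_to_mamba(hf_config)
#   backbone  = MambaLMHeadModel(mamba_cfg)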

def patch_mixer_forward_to_accept_embeddings(model):
    """

    Injects a new forward method into a MixerModel instance,

    allowing it to accept either input_ids or inputs_embeds.

    """

    def new_forward(self, input_ids=None, inputs_embeds=None, inference_params=None, attention_mask=None, **mixer_kwargs):
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.embedding(input_ids)
        else:
            raise ValueError("You must provide either input_ids or inputs_embeds.")

        residual = None

        # hidden_states: (batch_size, seq_len, d_model)
        # attention_mask: (batch_size, seq_len) -- 1 for real tokens, 0 for padding.
        # Fall back to an all-ones mask so the patched forward also works without one.
        if attention_mask is None:
            attention_mask = torch.ones(hidden_states.shape[:2], device=hidden_states.device)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)  # (batch_size, seq_len, 1)

        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )

            # Zero out padded positions so padding cannot leak into later layers
            hidden_states = hidden_states * mask
            residual = residual * mask

        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm)
            )
        return hidden_states

    # Bind the new forward method to the instance
    model.backbone.forward = new_forward.__get__(model.backbone, model.backbone.__class__)

class BiMambaForMaskedLM(PreTrainedModel):
    config_class    = AutoConfig
    base_model_prefix = "bimamba"

    def __init__(self, config):
        super().__init__(config)  # standard HF PreTrainedModel init
        mamba_cfg = convert_hf_config_to_mamba(config)

        # Shared token embedding, two unidirectional Mamba stacks, and a fusion projection
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.mamba_forward   = MambaLMHeadModel(mamba_cfg)
        self.mamba_backward  = MambaLMHeadModel(mamba_cfg)
        self.lm_head_proj    = nn.Linear(config.d_model * 2, config.d_model, bias=False)

        # Patch both backbones so their mixer forward accepts inputs_embeds and an attention mask
        patch_mixer_forward_to_accept_embeddings(self.mamba_forward)
        patch_mixer_forward_to_accept_embeddings(self.mamba_backward)

        # self.post_init()  # wires up HF weight-tying & save/load

    # HF integration hooks: embedding accessors and gradient checkpointing.
    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, new_emb):
        self.token_embedding = new_emb

    def get_output_embeddings(self):
        return self.lm_head_proj

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        for backbone in (self.mamba_forward.backbone,
                         self.mamba_backward.backbone):
            for block in backbone.layers:
                block.gradient_checkpointing = True

    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        labels=None,
        return_dict=True,
    ):
        if inputs_embeds is None:
            input_ids = input_ids.long()
            inputs_embeds = self.token_embedding(input_ids)
        if attention_mask is None:
            # No mask supplied: treat every position as a real token.
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)

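        # Bidirectional trick: run one Mamba stack left-to-right, run the second one
        # over the time-reversed sequence, flip its outputs back, and concatenate so
        # every position sees both left and right context before the MLM head.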
        hid_fwd = self.mamba_forward.backbone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        rev_emb = torch.flip(inputs_embeds, dims=[1])
        rev_mask = torch.flip(attention_mask, dims=[1])
        hid_bwd = self.mamba_backward.backbone(inputs_embeds=rev_emb, attention_mask=rev_mask)
        hid_bwd = torch.flip(hid_bwd, dims=[1])

        combined = torch.cat([hid_fwd, hid_bwd], dim=-1)
        projected = self.lm_head_proj(combined)
        logits    = F.linear(projected, self.token_embedding.weight)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss    = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            out = (logits, combined)
            return (loss,) + out if loss is not None else out

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=projected,
        )
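

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the original training
    # pipeline. PretrainedConfig stores arbitrary kwargs as attributes, which covers
    # the fields this module reads (d_model, n_layer, vocab_size, pad_token_id); the
    # sizes below are placeholders. mamba_ssm's fused kernels require a CUDA device.
    from transformers import PretrainedConfig

    if torch.cuda.is_available():
        cfg = PretrainedConfig(d_model=256, n_layer=2, vocab_size=1000, pad_token_id=0)
        model = BiMambaForMaskedLM(cfg).cuda()

        input_ids = torch.randint(1, cfg.vocab_size, (2, 16), device="cuda")
        attention_mask = torch.ones_like(input_ids)
        labels = input_ids.clone()

        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))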