| |
| |
|
|
| import dropout_layer_norm |
| import torch |
| from torch.nn import init |
|
|
|
|
def maybe_align(x, alignment_in_bytes=16):
    """Return *x* if its base pointer is aligned to ``alignment_in_bytes``,
    otherwise return an aligned copy.

    Assumes the last dim of *x* is already divisible by ``alignment_in_bytes``
    (only the storage base address is checked).
    """
    if x.data_ptr() % alignment_in_bytes == 0:
        return x
    # A fresh allocation restores alignment.
    return x.clone()
|
|
|
|
def _dropout_add_layer_norm_forward(
    x0,
    residual,
    gamma,
    beta,
    rowscale,
    colscale,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes."""
    hidden_size = gamma.numel()
    # Flatten everything to 2D (rows, hidden) views for the fused kernel.
    x0_2d = x0.view(-1, hidden_size)
    residual_2d = None if residual is None else residual.view(-1, hidden_size)
    rowscale_1d = None if rowscale is None else rowscale.view(-1)
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0_2d,
        residual_2d,
        gamma,
        beta,
        rowscale_1d,
        colscale,
        None,  # x0_subset (subset mode not used on this path)
        None,  # out_subset
        dropout_p,
        epsilon,
        1.0,  # rowscale_const
        0,  # out_numrows
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # Fall back to the 2D input when the kernel returns no saved x.
    if xmat is None:
        xmat = x0_2d
    return zmat, xmat, dmask, mu, rsigma
|
|
|
|
def _dropout_add_layer_norm_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    rowscale,
    colscale,
    dropout_p,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    # 2D (rows, hidden) views for the fused backward kernel.
    x_2d = x.view(-1, hidden_size)
    dz_2d = dz.view(x_2d.shape)
    dx_2d = None if dx is None else dx.view(x_2d.shape)
    x0_2d = None if x0 is None else x0.view(-1, hidden_size)
    rowscale_1d = None if rowscale is None else rowscale.view(-1)
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dz_2d,
        dx_2d,
        x_2d,
        x0_2d,
        dmask,
        mu,
        rsigma,
        gamma,
        rowscale_1d,
        colscale,
        None,  # x0_subset (subset mode not used on this path)
        None,  # out_subset
        dropout_p,
        1.0,  # rowscale_const
        0,  # x0_numrows
        has_residual,
        is_rms_norm,
    )
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    # The kernel appends dcolscale to its outputs when colscale was given.
    return dx0mat, dresidualmat, dgamma, dbeta, rest[0]
|
|
|
|
def _dropout_add_layer_norm_subset_forward(
    x0,
    residual,
    gamma,
    beta,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    epsilon,
    rowscale_const,
    out_numrows,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes."""
    hidden_size = gamma.numel()
    # 2D (rows, hidden) views for the fused kernel; index tensors become 1D.
    x0_2d = x0.view(-1, hidden_size)
    residual_2d = None if residual is None else residual.view(-1, hidden_size)
    x0_idx = None if x0_subset is None else x0_subset.view(-1)
    out_idx = None if out_subset is None else out_subset.view(-1)
    zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
        x0_2d,
        residual_2d,
        gamma,
        beta,
        None,  # rowscale
        colscale,
        x0_idx,
        out_idx,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    # Fall back to the 2D input when the kernel returns no saved x.
    if xmat is None:
        xmat = x0_2d
    return zmat, xmat, dmask, mu, rsigma
|
|
|
|
def _dropout_add_layer_norm_subset_backward(
    dz,
    dx,
    x,
    x0,
    dmask,
    mu,
    rsigma,
    gamma,
    colscale,
    x0_subset,
    out_subset,
    dropout_p,
    rowscale_const,
    x0_numrows,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    x0 must not be None if we have colscale.
    """
    hidden_size = gamma.numel()
    x_2d = x.view(-1, hidden_size)
    # dz is flattened independently of x_2d.shape — in subset mode its row
    # count can differ from x's.
    dz_2d = dz.view(-1, hidden_size)
    dx_2d = None if dx is None else dx.view(x_2d.shape)
    x0_2d = None if x0 is None else x0.view(-1, hidden_size)
    x0_idx = None if x0_subset is None else x0_subset.view(-1)
    out_idx = None if out_subset is None else out_subset.view(-1)
    if colscale is not None:
        assert x0 is not None, "x0 is required to compute the gradient of colscale"
    dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd(
        dz_2d,
        dx_2d,
        x_2d,
        x0_2d,
        dmask,
        mu,
        rsigma,
        gamma,
        None,  # rowscale
        colscale,
        x0_idx,
        out_idx,
        dropout_p,
        rowscale_const,
        x0_numrows,
        has_residual,
        is_rms_norm,
    )
    if colscale is None:
        return dx0mat, dresidualmat, dgamma, dbeta
    # The kernel appends dcolscale to its outputs when colscale was given.
    return dx0mat, dresidualmat, dgamma, dbeta, rest[0]
|
|
|
|
def _dropout_add_layer_norm_parallel_residual_forward(
    x0,
    x1,
    residual,
    gamma0,
    beta0,
    gamma1,
    beta1,
    dropout_p,
    epsilon,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes."""
    hidden_size = gamma0.numel()
    # 2D (rows, hidden) views for the fused kernel.
    x0_2d = x0.view(-1, hidden_size)
    x1_2d = None if x1 is None else x1.view(-1, hidden_size)
    residual_2d = None if residual is None else residual.view(-1, hidden_size)
    out = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd(
        x0_2d,
        x1_2d,
        residual_2d,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        None,
        residual_in_fp32,
        is_rms_norm,
    )
    z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = out
    # Fall back to the 2D input when the kernel returns no saved x.
    if xmat is None:
        xmat = x0_2d
    return z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma
|
|
|
|
def _dropout_add_layer_norm_parallel_residual_backward(
    dz0,
    dz1,
    dx,
    x,
    dmask0,
    dmask1,
    mu,
    rsigma,
    gamma0,
    gamma1,
    dropout_p,
    has_x1,
    has_residual,
    is_rms_norm=False,
):
    """Assume that arguments are contiguous and aligned to 16 bytes.
    dx == None means that it was a post-norm architecture
    (x = drop(x0) + residual was not returned in the fwd).
    """
    hidden_size = gamma0.numel()
    # 2D (rows, hidden) views for the fused backward kernel.
    x_2d = x.view(-1, hidden_size)
    dz0_2d = dz0.view(x_2d.shape)
    dz1_2d = None if dz1 is None else dz1.view(x_2d.shape)
    dx_2d = None if dx is None else dx.view(x_2d.shape)
    out = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd(
        dz0_2d,
        dz1_2d,
        dx_2d,
        x_2d,
        dmask0,
        dmask1,
        mu,
        rsigma,
        gamma0,
        gamma1,
        dropout_p,
        has_x1,
        has_residual,
        is_rms_norm,
    )
    # Only the first seven outputs are returned; ignore any extras.
    dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = out[:7]
    return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1
|
|
|
|
class DropoutAddLayerNormFn(torch.autograd.Function):
    """Autograd Function wiring the fused dropout + residual-add + layer-norm
    (or RMS-norm) kernels into PyTorch's autograd graph.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        rowscale,
        colscale,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the fused forward kernel.

        Returns the normalized output z; additionally the pre-norm sum x when
        ``prenorm``, and the dropout mask when ``return_dmask``.
        """
        # The fused kernel requires contiguous, 16-byte-aligned tensors.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
            x0,
            residual,
            gamma,
            beta,
            rowscale,
            colscale,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        # x0 is only needed in backward to compute the colscale gradient.
        x0_saved = x0 if colscale is not None else None
        ctx.save_for_backward(
            xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        if not return_dmask:
            return (
                zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape))
            )
        else:
            # With dropout_p == 0 hand back an all-ones mask instead of the
            # kernel's dmask so callers always receive a usable mask.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            # The mask is data, not a differentiable output.
            ctx.mark_non_differentiable(dmask)
            return (
                (zmat.view(x0.shape), dmask)
                if not prenorm
                else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)
            )

    @staticmethod
    def backward(ctx, dz, *args):
        """Backward pass. When forward ran with prenorm=True, args[0] is the
        gradient w.r.t. the returned pre-norm x."""
        dz = maybe_align(dz.contiguous(), 16)
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        # Same order as save_for_backward in forward.
        x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors

        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            rowscale,
            colscale,
            dropout_p,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        # rest[0] is dcolscale, present only when colscale was given.
        dcolscale = rest[0] if colscale is not None else None
        # One entry per forward input (12 after ctx); non-tensor /
        # non-differentiable inputs get None.
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            None,  # rowscale
            dcolscale,
            None,  # dropout_p
            None,  # epsilon
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
|
|
|
|
class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
    """Autograd Function for the fused dropout + add + layer-norm kernel in
    "subset" mode (x0_subset / out_subset select rows for input and output;
    see the fused kernel's API for the exact semantics).
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        residual,
        gamma,
        beta,
        colscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the fused subset forward kernel.

        Returns the normalized output z; additionally the pre-norm x when
        ``prenorm``, and the dropout mask when ``return_dmask``.
        """
        # The fused kernel requires contiguous, 16-byte-aligned tensors.
        x0 = maybe_align(x0.contiguous(), 16)
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma = maybe_align(gamma.contiguous(), 16)
        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
        zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
            x0,
            residual,
            gamma,
            beta,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            epsilon,
            rowscale_const,
            out_numrows,
            residual_in_fp32,
            is_rms_norm,
        )
        # x0 is only needed in backward to compute the colscale gradient.
        x0_saved = x0 if colscale is not None else None
        # In subset mode the row counts of x and z can differ from x0's, so
        # restore only the trailing dims and leave the leading dim flexible.
        x_shape = (-1, *x0.shape[1:])
        ctx.save_for_backward(
            xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset
        )
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.rowscale_const = rowscale_const
        ctx.x0_numrows = x0.shape[:-1].numel()
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        ctx.has_beta = beta is not None
        z_shape = (-1, *x0.shape[1:])
        if not return_dmask:
            # BUGFIX: view xmat with the flexible x_shape (as done when saving
            # above and in the return_dmask branch below), not x0.shape —
            # with row subsets xmat's row count can differ from x0's.
            return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x_shape))
        else:
            z = zmat.view(z_shape)
            # With dropout_p == 0 hand back an all-ones mask so callers
            # always receive a usable mask.
            dmask = (
                dmask.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            # The mask is data, not a differentiable output.
            ctx.mark_non_differentiable(dmask)
            return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)

    @staticmethod
    def backward(ctx, dz, *args):
        """Backward pass. When forward ran with prenorm=True, args[0] is the
        gradient w.r.t. the returned pre-norm x."""
        dz = maybe_align(dz.contiguous(), 16)
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        # Same order as save_for_backward in forward.
        x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors

        dropout_p = ctx.dropout_p
        has_residual = ctx.has_residual
        dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward(
            dz,
            dx,
            x,
            x0,
            dmask,
            mu,
            rsigma,
            gamma,
            colscale,
            x0_subset,
            out_subset,
            dropout_p,
            ctx.rowscale_const,
            ctx.x0_numrows,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(-1, *x.shape[1:])
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        # rest[0] is dcolscale, present only when colscale was given.
        dcolscale = rest[0] if colscale is not None else None
        # One entry per forward input (15 after ctx); non-tensor /
        # non-differentiable inputs get None.
        return (
            dx0,
            dresidual,
            dgamma,
            dbeta if ctx.has_beta else None,
            dcolscale,
            None,  # x0_subset
            None,  # out_subset
            None,  # dropout_p
            None,  # epsilon
            None,  # rowscale_const
            None,  # out_numrows
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
|
|
|
|
class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
    """Autograd Function for the fused kernel's parallel-residual variant:
    two input streams (x0, x1) are combined with one residual stream and
    normalized by two (gamma, beta) parameter sets, producing z0 and z1.
    """

    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        residual,
        gamma0,
        beta0,
        gamma1,
        beta1,
        dropout_p,
        epsilon,
        residual_in_fp32=False,
        prenorm=False,
        is_rms_norm=False,
        return_dmask=False,
    ):
        """Run the fused parallel-residual forward kernel.

        Returns (z0, z1); additionally the pre-norm x when ``prenorm``, and
        the two dropout masks when ``return_dmask``.
        """
        # The fused kernel requires contiguous, 16-byte-aligned tensors.
        x0 = maybe_align(x0.contiguous(), 16)
        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
        gamma0 = maybe_align(gamma0.contiguous(), 16)
        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
        (
            z0mat,
            z1mat,
            xmat,
            dmask0,
            dmask1,
            mu,
            rsigma,
        ) = _dropout_add_layer_norm_parallel_residual_forward(
            x0,
            x1,
            residual,
            gamma0,
            beta0,
            gamma1,
            beta1,
            dropout_p,
            epsilon,
            residual_in_fp32,
            is_rms_norm,
        )
        ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma)
        ctx.prenorm = prenorm
        ctx.dropout_p = dropout_p
        ctx.has_x1 = x1 is not None
        ctx.has_residual = residual is not None
        ctx.is_rms_norm = is_rms_norm
        # NOTE(review): both beta gradients below are gated on beta0's
        # presence — assumes beta0/beta1 are given both-or-neither; confirm.
        ctx.has_beta = beta0 is not None
        z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None)
        if not return_dmask:
            return z if not prenorm else (*z, xmat.view(x0.shape))
        else:
            # With dropout_p == 0 (or no x1 for dmask1) hand back all-ones
            # masks so callers always receive usable masks.
            dmask0 = (
                dmask0.view(x0.shape)
                if dropout_p > 0.0
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            dmask1 = (
                dmask1.view(x0.shape)
                if dropout_p > 0.0 and x1 is not None
                else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)
            )
            # The masks are data, not differentiable outputs.
            ctx.mark_non_differentiable(dmask0)
            ctx.mark_non_differentiable(dmask1)
            return (
                (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1)
            )

    @staticmethod
    def backward(ctx, dz0, dz1, *args):
        """Backward pass. When forward ran with prenorm=True, args[0] is the
        gradient w.r.t. the returned pre-norm x."""
        dz0 = maybe_align(dz0.contiguous(), 16)
        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
        # Same order as save_for_backward in forward.
        x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
        dropout_p = ctx.dropout_p
        has_x1 = ctx.has_x1
        has_residual = ctx.has_residual
        (
            dx0mat,
            dx1mat,
            dresidualmat,
            dgamma0,
            dbeta0,
            dgamma1,
            dbeta1,
        ) = _dropout_add_layer_norm_parallel_residual_backward(
            dz0,
            dz1,
            dx,
            x,
            dmask0,
            dmask1,
            mu,
            rsigma,
            gamma0,
            gamma1,
            dropout_p,
            has_x1,
            has_residual,
            ctx.is_rms_norm,
        )
        dx0 = dx0mat.view(x.shape)
        dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
        dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None
        # One entry per forward input (13 after ctx); non-tensor /
        # non-differentiable inputs get None.
        return (
            dx0,
            dx1,
            dresidual,
            dgamma0,
            dbeta0 if ctx.has_beta else None,
            dgamma1,
            dbeta1 if ctx.has_beta else None,
            None,  # dropout_p
            None,  # epsilon
            None,  # residual_in_fp32
            None,  # prenorm
            None,  # is_rms_norm
            None,  # return_dmask
        )
|
|
|
|
def layer_norm(x, weight, bias, epsilon):
    """Plain layer norm via the fused kernel: dropout disabled, no residual."""
    # Positional args: x0, residual, gamma, beta, rowscale, colscale,
    # dropout_p, epsilon, residual_in_fp32.
    return DropoutAddLayerNormFn.apply(
        x, None, weight, bias, None, None, 0.0, epsilon, False
    )
|
|
|
|
def dropout_add_layer_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Fused dropout + residual add + layer norm.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,  # is_rms_norm
        return_dropout_mask,
    )
    return DropoutAddLayerNormFn.apply(*fn_args)
|
|
|
|
def dropout_add_layer_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Row-subset variant of dropout_add_layer_norm.

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    fn_args = (
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        False,  # is_rms_norm
        return_dropout_mask,
    )
    return DropoutAddLayerNormSubsetFn.apply(*fn_args)
|
|
|
|
def dropout_add_layer_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """Parallel-residual variant of dropout_add_layer_norm (two streams,
    two parameter sets).

    residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    fn_args = (
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        False,  # is_rms_norm
        return_dropout_mask,
    )
    return DropoutAddLayerNormParallelResidualFn.apply(*fn_args)
|
|
|
|
class DropoutAddLayerNorm(torch.nn.Module):
    """Module wrapper for the fused dropout + residual-add + LayerNorm op."""

    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.prenorm = prenorm
        self.p = p  # dropout probability (applied only in training mode)
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        factory_kwargs = {"device": device, "dtype": dtype}
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        # Identity affine transform: scale one, shift zero.
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, x0, residual=None):
        # Disable dropout at eval time.
        effective_p = self.p if self.training else 0.0
        return dropout_add_layer_norm(
            x0,
            residual,
            self.weight,
            self.bias,
            effective_p,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
|
|