File size: 6,458 Bytes
937cfd9
 
 
 
 
 
 
4b0f2f8
937cfd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"


// Operator registration table for this extension.
//
// Every operator is registered in two steps:
//   1. ops.def(<schema>)  - declares the operator name and its schema string.
//   2. ops.impl(<name>, torch::kCUDA, &<name>_wrap) - binds the CUDA dispatch
//      key to the matching `*_wrap` function declared outside this block
//      (presumably in "torch_binding.h" - not visible here).
//
// Schema conventions used below:
//   * `Tensor! x` marks an argument the kernel writes in place. Every buffer a
//     kernel fills MUST carry it; otherwise functionalization / torch.compile
//     may treat the write as absent.
//   * Flag-like parameters (is_causal, return_lse, tensor_layout,
//     qk_quant_gran) are declared as `int` in the schema, not `bool`.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    // ---------------------------------------------------------------------
    // Attention kernels. All of them write the attention output into `o`
    // (annotated `Tensor!`) and additionally return a Tensor.
    // NOTE(review): the returned tensor is presumably the log-sum-exp that
    // `return_lse` asks for - confirm against the *_wrap implementations.
    // ---------------------------------------------------------------------

    // SM90
    ops.def("qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90", torch::kCUDA, &qk_int8_sv_f8_accum_f32_attn_inst_buf_sm90_wrap);

    ops.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, Tensor v_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_sm90_wrap);

    // Variants without the `_sm90` suffix: INT8 QK with FP8 SV, FP32 accumulation.
    ops.def("qk_int8_sv_f8_accum_f32_attn_inst_buf(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_attn_inst_buf_wrap);

    ops.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, Tensor v_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_wrap);

    ops.def("qk_int8_sv_f8_accum_f32_attn(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_attn", torch::kCUDA, &qk_int8_sv_f8_accum_f32_attn_wrap);

    ops.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, Tensor v_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_wrap);

    ops.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, Tensor v_scale, Tensor v_mean, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn", torch::kCUDA, &qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn_wrap);

    // INT8 QK with FP8 SV, FP16 accumulation.
    ops.def("qk_int8_sv_f8_accum_f16_attn_inst_buf(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f16_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f16_attn_inst_buf_wrap);

    ops.def("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, Tensor v_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf_wrap);

    // INT8 QK with FP16 SV.
    ops.def("qk_int8_sv_f16_accum_f32_attn(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f16_accum_f32_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f32_attn_wrap);

    ops.def("qk_int8_sv_f16_accum_f16_attn(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f16_accum_f16_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f16_attn_wrap);

    ops.def("qk_int8_sv_f16_accum_f16_attn_inst_buf(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f16_accum_f16_attn_inst_buf", torch::kCUDA, &qk_int8_sv_f16_accum_f16_attn_inst_buf_wrap);

    ops.def("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(Tensor q, Tensor k, Tensor v, Tensor! o, Tensor q_scale, Tensor k_scale, Tensor v_mean, int tensor_layout, int is_causal, int qk_quant_gran, float sm_scale, int return_lse) -> Tensor");
    ops.impl("qk_int8_sv_f16_accum_f16_fuse_v_mean_attn", torch::kCUDA, &qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_wrap);

    // ---------------------------------------------------------------------
    // Fused (available across supported archs)
    //
    // These operators return `()`; every result is delivered through a
    // caller-supplied buffer. Besides `output`, the quantization kernels are
    // expected to fill the `scale` buffer (and `mean` in
    // mean_scale_fuse_quant_cuda), so those arguments are annotated `Tensor!`.
    // NOTE(review): confirm against the kernel implementations which buffers
    // are written. Declaring a read-only tensor as mutable is merely
    // conservative, whereas omitting `!` on a written tensor is unsafe.
    // In the two *sub_mean* operators `mean` stays a plain input.
    // ---------------------------------------------------------------------
    ops.def("quant_per_block_int8_cuda(Tensor input, Tensor! output, Tensor! scale, float sm_scale, int block_size, int tensor_layout) -> ()");
    ops.impl("quant_per_block_int8_cuda", torch::kCUDA, &quant_per_block_int8_cuda_wrap);

    ops.def("quant_per_block_int8_fuse_sub_mean_cuda(Tensor input, Tensor mean, Tensor! output, Tensor! scale, int block_size, int tensor_layout) -> ()");
    ops.impl("quant_per_block_int8_fuse_sub_mean_cuda", torch::kCUDA, &quant_per_block_int8_fuse_sub_mean_cuda_wrap);

    ops.def("quant_per_warp_int8_cuda(Tensor input, Tensor! output, Tensor! scale, int block_size, int warp_block_size, int tensor_layout) -> ()");
    ops.impl("quant_per_warp_int8_cuda", torch::kCUDA, &quant_per_warp_int8_cuda_wrap);

    ops.def("sub_mean_cuda(Tensor input, Tensor mean, Tensor! output, int tensor_layout) -> ()");
    ops.impl("sub_mean_cuda", torch::kCUDA, &sub_mean_cuda_wrap);

    ops.def("transpose_pad_permute_cuda(Tensor input, Tensor! output, int tensor_layout) -> ()");
    ops.impl("transpose_pad_permute_cuda", torch::kCUDA, &transpose_pad_permute_cuda_wrap);

    ops.def("scale_fuse_quant_cuda(Tensor input, Tensor! output, Tensor! scale, int num_tokens, float scale_max, int tensor_layout) -> ()");
    ops.impl("scale_fuse_quant_cuda", torch::kCUDA, &scale_fuse_quant_cuda_wrap);

    ops.def("mean_scale_fuse_quant_cuda(Tensor input, Tensor! output, Tensor! mean, Tensor! scale, int num_tokens, float scale_max, int tensor_layout) -> ()");
    ops.impl("mean_scale_fuse_quant_cuda", torch::kCUDA, &mean_scale_fuse_quant_cuda_wrap);
}

// Invoke the REGISTER_EXTENSION macro with the same TORCH_EXTENSION_NAME that
// names the operator library registered in this file.
// NOTE(review): the macro is not defined in this file (presumably it comes from
// "registration.h") and most likely emits the Python module entry point for the
// shared library - confirm against that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME);