File size: 3,915 Bytes
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ade9bc3
58220b6
fc04dc0
58220b6
 
 
5ef1601
58220b6
 
 
5ef1601
 
 
 
 
 
 
 
 
 
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ef1601
 
 
 
 
 
58220b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ade9bc3
 
 
 
58220b6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#include "llama-kv-cache-unified.h"

#include <vector>

//
// llama_kv_cache_unified_iswa
//

// utilizes two instances of llama_kv_cache_unified
//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers

class llama_kv_cache_unified_iswa : public llama_memory_i {
public:
    llama_kv_cache_unified_iswa(
            const llama_model & model,
                    ggml_type   type_k,
                    ggml_type   type_v,
                         bool   v_trans,
                         bool   offload,
                         bool   swa_full,
                     uint32_t   kv_size,
                     uint32_t   n_seq_max,
                     uint32_t   n_ubatch,
                     uint32_t   n_pad);

    ~llama_kv_cache_unified_iswa() = default;

    //
    // llama_memory_i
    //

    llama_memory_state_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_state_ptr init_full() override;

    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id)                                                          override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;

    //
    // llama_kv_cache_unified_iswa specific API
    //

    llama_kv_cache_unified * get_base() const;
    llama_kv_cache_unified * get_swa () const;

private:
    const llama_hparams & hparams;

    std::unique_ptr<llama_kv_cache_unified> kv_base;
    std::unique_ptr<llama_kv_cache_unified> kv_swa;
};

class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
public:
    // used for errors
    llama_kv_cache_unified_iswa_state(llama_memory_status status);

    // used to create a full-cache state
    llama_kv_cache_unified_iswa_state(
            llama_kv_cache_unified_iswa * kv);

    // used to create an update state
    llama_kv_cache_unified_iswa_state(
            llama_kv_cache_unified_iswa * kv,
            llama_context * lctx,
            bool optimize);

    // used to create a state from a batch
    llama_kv_cache_unified_iswa_state(
            llama_kv_cache_unified_iswa * kv,
            std::vector<uint32_t> heads_base,
            std::vector<uint32_t> heads_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_state();

    //
    // llama_memory_state_i
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_unified_iswa_state specific API
    //

    const llama_kv_cache_unified_state * get_base() const;
    const llama_kv_cache_unified_state * get_swa()  const;

private:
    //llama_kv_cache_unified_iswa * kv;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_state_ptr state_base;
    const llama_memory_state_ptr state_swa;

    const llama_memory_status status;
};