Could an EAGLE-3 draft model trained at 1.58 bits further speed up LFM2.5 inference?

#10
by Sourajit123 - opened

Hey Liquid AI! I really liked the LFM2.5-1.2B-Instruct model and was looking forward to doing some personal experimentation with it.
I was recently looking into speculative decoding and BitNet, and I was wondering whether an EAGLE-3 draft model trained at ~1.58 bits on logits and text generated by this model could offer further speedups without adding much compute cost. Also, what if we applied n-gram speculative decoding on top of that?
I don't really have enough compute (mostly the Kaggle free tier for now) to test this directly, and I would really appreciate some guidance, as I'm quite new to this stuff.
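For context, "1.58 bits" refers to ternary weights in {-1, 0, +1}, since log2(3) is about 1.585. The C++ sketch below shows absmean ternary quantization in the style described for BitNet b1.58: scale by the mean absolute weight, then round and clip to [-1, 1]. The names quantize_absmean and TernaryWeights are hypothetical. The sketch only illustrates the weight format and says nothing about how an EAGLE-3 head would actually be trained or served at this precision.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical container for a ternary-quantized weight tensor.
struct TernaryWeights {
    std::vector<std::int8_t> q;  // each entry is -1, 0 or +1
    float scale;                 // absmean scale gamma; w[i] is approx. q[i] * scale
};

// Absmean ternary quantization (BitNet b1.58 style):
// gamma = mean(|w|), q = clip(round(w / (gamma + eps)), -1, 1).
TernaryWeights quantize_absmean(const std::vector<float>& w, float eps = 1e-5f) {
    double sum_abs = 0.0;
    for (float x : w) sum_abs += std::fabs(x);
    const float gamma = w.empty() ? 0.0f : static_cast<float>(sum_abs / w.size());

    TernaryWeights out{{}, gamma};
    out.q.reserve(w.size());
    for (float x : w) {
        const float r = std::round(x / (gamma + eps));  // nearest integer
        out.q.push_back(static_cast<std::int8_t>(std::clamp(r, -1.0f, 1.0f)));
    }
    return out;
}
```

This toy stores one int8 per weight for clarity. Practical ternary formats typically pack several weights per byte to approach the ~1.58-bit figure.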

Liquid AI org

Hey, thanks @Sourajit123! An EAGLE-3 draft model would be neat, but it's really hard to get meaningful inference improvements at such a small scale (1.2B here). Draft models typically exploit memory-bandwidth bottlenecks in large models, whereas here we're compute-bound. As far as I know, however, this model is already compatible with n-gram speculative decoding, which could be interesting for deterministic workloads.
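Unlike EAGLE-3, n-gram ("prompt lookup") drafting needs no draft model. It searches for an earlier occurrence of the most recent n tokens in the context and proposes the tokens that followed that occurrence. This pays off when the output repeats itself or copies text from the prompt. Below is a minimal C++ sketch of the drafting step only. The alias token stands in for llama_token, and ngram_draft is a hypothetical helper, not the llama.cpp or llama-cpp-python implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using token = std::int32_t;  // stand-in for llama_token (assumption)

// Propose up to max_draft tokens by matching the last `ngram` tokens of ctx
// against an earlier occurrence and copying what followed it.
// Returns an empty draft when no earlier match exists.
std::vector<token> ngram_draft(const std::vector<token>& ctx,
                               std::size_t ngram, std::size_t max_draft) {
    std::vector<token> draft;
    if (ngram == 0 || ctx.size() <= ngram) return draft;

    const std::size_t tail = ctx.size() - ngram;  // start of the suffix n-gram
    // Scan candidate start positions from the most recent one backwards.
    for (std::size_t start = tail; start-- > 0;) {
        bool match = true;
        for (std::size_t k = 0; k < ngram; ++k) {
            if (ctx[start + k] != ctx[tail + k]) { match = false; break; }
        }
        if (!match) continue;
        // Copy the continuation that followed the earlier occurrence.
        for (std::size_t j = start + ngram; j < ctx.size() && draft.size() < max_draft; ++j) {
            draft.push_back(ctx[j]);
        }
        break;
    }
    return draft;
}
```

The target model still verifies every drafted token. A poor draft therefore wastes verification work but does not change the output.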

Hey, I tried running n-gram speculative decoding with llama.cpp, but it fails with an error I'm unable to fix. Could you point me in the right direction? I get "RuntimeError: llama_decode returned -1" when using LlamaPromptLookupDecoding from llama-cpp-python.

Also, I'm eagerly waiting for bigger MoE models 😁.

Liquid AI org

@Sourajit123, speculative decoding requires rollback support (discarding the state updates made by rejected draft tokens), which llama.cpp does not implement for hybrid models.
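What "rollback" means differs by layer type. An attention KV cache stores one entry per position, so entries for rejected drafts can simply be truncated. A recurrent or short-convolution state is a fixed-size buffer updated in place on every token, so the only way back is to restore a saved copy and replay the accepted tokens. The toy C++ sketch below illustrates that contrast. ToyKVCache, ToyRecurrentState and speculate_step are hypothetical, and the update rule is arbitrary.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using token = std::int32_t;  // stand-in for llama_token (assumption)

// Toy attention KV cache: one entry per position, so rollback is truncation.
struct ToyKVCache {
    std::vector<token> entries;
    void append(token t) { entries.push_back(t); }
    void rollback_to(std::size_t n_keep) {
        if (n_keep < entries.size()) entries.resize(n_keep);
    }
};

// Toy recurrent state: one running value overwritten on every token.
// Past values are not kept, so there is no position to truncate back to.
struct ToyRecurrentState {
    std::int64_t h = 0;
    void step(token t) { h = h * 31 + t; }  // arbitrary toy update rule
};

// One speculative step on the recurrent part: checkpoint, consume the sampled
// token plus all drafts, and if any draft is rejected, restore the checkpoint
// and replay only the sampled token and the accepted drafts.
void speculate_step(ToyRecurrentState& s, token sampled,
                    const std::vector<token>& draft, std::size_t n_accepted) {
    assert(n_accepted <= draft.size());
    const ToyRecurrentState checkpoint = s;  // save state before speculation
    s.step(sampled);
    for (token t : draft) s.step(t);         // verification consumes every draft
    if (n_accepted < draft.size()) {
        s = checkpoint;                      // restore
        s.step(sampled);                     // replay the sampled token...
        for (std::size_t i = 0; i < n_accepted; ++i) s.step(draft[i]);  // ...and accepted drafts
    }
}
```

The same checkpoint, restore and replay sequence appears in the patch below. It uses llama_state_seq_get_data_ext and llama_state_seq_set_data_ext for the recurrent state and llama_memory_seq_rm for the attention part.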

The patch below modifies upstream llama.cpp to enable rollback for LFM models (it applies on top of https://github.com/ggml-org/llama.cpp/commit/96441c955). A cleaner version will take some time to upstream.

rollback_lfm2.patch
commit 5b40de9d611c1e243195ff3375784977b5af9a1b
Author: Tarek Dakhran <[email protected]>
Date:   Sun Feb 8 01:34:12 2026 +0100

    Support speculative decoding for hybrid models

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index b71d496ee..84c79a76e 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -161,6 +161,11 @@ struct server_slot {
     int32_t n_draft_total = 0;      // Total draft tokens generated
     int32_t n_draft_accepted = 0;   // Draft tokens actually accepted
 
+    // Speculative decoding checkpoint for hybrid/recurrent models
+    bool                  spec_needs_checkpoint = false;
+    std::vector<uint8_t>  spec_checkpoint_data;
+    int32_t               spec_checkpoint_n_tokens = 0;
+
     void reset() {
         SLT_DBG(*this, "%s", "\n");
 
@@ -184,6 +189,10 @@ struct server_slot {
         n_draft_total = 0;
         n_draft_accepted = 0;
 
+        // clear speculative checkpoint
+        spec_checkpoint_data.clear();
+        spec_checkpoint_n_tokens = 0;
+
         task_prev = std::move(task);
         task.reset();
 
@@ -741,7 +750,10 @@ private:
         slots.clear();
 
         const bool can_spec = common_speculative_is_compat(ctx);
-        if (!can_spec) {
+        const bool can_spec_with_checkpoint = !can_spec &&
+            (llama_model_is_hybrid(model) || llama_model_is_recurrent(model));
+
+        if (!can_spec && !can_spec_with_checkpoint) {
             SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }
 
@@ -757,14 +769,16 @@ private:
             slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
             // try speculative decoding
-            if (can_spec) {
+            if (can_spec || can_spec_with_checkpoint) {
                 slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
                 if (slot.spec) {
                     if (mctx) {
                         SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
                         return false;
                     }
-                    SLT_INF(slot, "%s", "speculative decoding context initialized\n");
+                    slot.spec_needs_checkpoint = can_spec_with_checkpoint;
+                    SLT_INF(slot, "speculative decoding context initialized%s\n",
+                            can_spec_with_checkpoint ? " (checkpoint mode)" : "");
                 } else {
                     SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
                 }
@@ -2071,6 +2085,13 @@ private:
                     slot.drafted.clear();
                     slot.i_batch_dft.clear();
                 } else {
+                    // save checkpoint only when drafts will be used (hybrid/recurrent)
+                    if (slot.spec_needs_checkpoint) {
+                        slot.spec_checkpoint_n_tokens = slot.prompt.n_tokens() - 1; // before sampled token
+                        const size_t sz = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                        slot.spec_checkpoint_data.resize(sz);
+                        llama_state_seq_get_data_ext(ctx, slot.spec_checkpoint_data.data(), sz, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+                    }
                     // keep track of total number of drafted tokens tested
                     slot.n_draft_total += draft.size();
 
@@ -2806,14 +2827,57 @@ private:
                 // inform the speculative decoding about the number of accepted tokens
                 common_speculative_accept(slot.spec, ids.size() - 1);
 
-                // rollback to the state before sampling the draft tokens
-                slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+                const bool all_accepted = (ids.size() - 1 == n_draft);
+
+                if (slot.spec_needs_checkpoint && !all_accepted && !slot.spec_checkpoint_data.empty()) {
+                    // checkpoint-based rollback for hybrid/recurrent models (some drafts rejected)
+                    const llama_token sampled_orig = slot.sampled;
+                    const int32_t n_tokens_before = slot.spec_checkpoint_n_tokens;
+
+                    // 1. restore recurrent state to before speculation
+                    llama_state_seq_set_data_ext(ctx, slot.spec_checkpoint_data.data(),
+                        slot.spec_checkpoint_data.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
+
+                    // 2. rollback prompt tracking to before speculation
+                    slot.prompt.tokens.keep_first(n_tokens_before);
+
+                    // 3. remove memory entries from checkpoint pos onwards
+                    //    (recurrent has nothing there after restore, so hybrid seq_rm succeeds)
+                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, n_tokens_before, -1);
+
+                    // 4. build replay batch: sampled token + accepted draft tokens
+                    const int n_replay = (int)ids.size();
+                    llama_batch batch_replay = llama_batch_init(n_replay, 0, 1);
 
-                // add accepted tokens to the prompt
-                slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
-                slot.sampled = ids.back(); // last accepted token
+                    common_batch_add(batch_replay, sampled_orig,
+                        slot.prompt.tokens.pos_next(), { slot.id }, false);
+                    slot.prompt.tokens.push_back(sampled_orig);
+
+                    for (size_t i = 0; i < ids.size() - 1; i++) {
+                        common_batch_add(batch_replay, ids[i],
+                            slot.prompt.tokens.pos_next(), { slot.id }, false);
+                        slot.prompt.tokens.push_back(ids[i]);
+                    }
+
+                    // 5. decode replay batch to advance recurrent + attention state
+                    const int ret = llama_decode(ctx, batch_replay);
+                    llama_batch_free(batch_replay);
+
+                    if (ret != 0) {
+                        SLT_ERR(slot, "speculative checkpoint replay decode failed: %d\n", ret);
+                    }
+
+                    slot.sampled = ids.back();
+                } else {
+                    // standard rollback: pure attention models, or all drafts accepted (no rollback needed)
+                    slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
+                    slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
+                    slot.sampled = ids.back();
+
+                    llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
+                }
 
-                llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
+                slot.spec_checkpoint_data.clear();
 
                 for (size_t i = 0; i < ids.size(); ++i) {
                     completion_token_output result;
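For reference, the ids vector consumed by the patch holds the accepted draft tokens followed by one token chosen by the target model. That is why ids.size() - 1 is passed to common_speculative_accept and compared with n_draft. Below is a minimal greedy-verification sketch that produces a vector of that shape. The names verify_greedy and target_argmax are hypothetical, and the server's actual sampling path may differ, for example with non-greedy samplers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using token = std::int32_t;  // stand-in for llama_token (assumption)

// Greedy verification of a draft against the target model.
// target_argmax[i] is the target's greedy token at draft position i, taken from
// one batched forward pass; it has draft.size() + 1 entries (the last is the
// "bonus" token after a fully accepted draft).
// Returns the accepted draft prefix plus one target-chosen token, so
// ids.size() - 1 is the number of accepted drafts.
std::vector<token> verify_greedy(const std::vector<token>& draft,
                                 const std::vector<token>& target_argmax) {
    assert(target_argmax.size() == draft.size() + 1);
    std::vector<token> ids;
    for (std::size_t i = 0; i < draft.size(); ++i) {
        ids.push_back(target_argmax[i]);               // target's choice at position i
        if (target_argmax[i] != draft[i]) return ids;  // first mismatch: stop here
    }
    ids.push_back(target_argmax.back());               // every draft accepted
    return ids;
}
```

When fewer drafts are accepted than were proposed (all_accepted is false in the patch), a hybrid model's recurrent state has already consumed the rejected tokens. The patch's checkpoint restore and replay exist to handle exactly that case.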

Thanks! I'll give it a try.
