Gemma 4 Clinical Trial Endpoint Extractor

A fine-tuned Gemma 4 E4B-it model (LoRA adapter) for extracting structured endpoint information from clinical trial text. The model takes unstructured endpoint descriptions and outputs structured JSON following a predefined schema.

⚠️ Loading fix (this revision)

The earlier snapshot of this adapter was saved with multimodal-wrapper layer paths and also carried trained LoRAs for the vision and audio towers. With recent peft (≥ 0.18), this caused PeftModel.from_pretrained to log a "Found missing adapter keys" warning and silently fall back to a default-initialized LoRA. The base instruction-tuned model would still produce JSON that looked schema-conformant (the system prompt is detailed enough), so the regression was easy to miss in spot checks.

This revision rewrites adapter_model.safetensors and adapter_config.json so PeftModel.from_pretrained(base, "Shubh-0789/gemma4-clinical-endpoint-extractor") loads with 0 warnings and applies the trained weights to all 42 language-model decoder layers as expected. The training data, hyperparameters, and behavior are unchanged. See Quick start (fixed adapter) below for the up-to-date load snippet, and What changed in this release further down for technical detail.

Quick start (fixed adapter)

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gemma4 import modeling_gemma4
from peft import PeftModel

# Required compatibility patch: Gemma 4's `Gemma4ClippableLinear` wraps
# nn.Linear and so cannot be wrapped by PEFT directly. Replace it with an
# nn.Linear-derived class BEFORE loading the base model. Behavior is
# preserved for layers where `use_clipped_linears=False` (the default).
class PatchedClippableLinear(nn.Linear):
    def __init__(self, config, in_features, out_features):
        nn.Linear.__init__(self, in_features, out_features, bias=False)
        self.use_clipped_linears = getattr(config, "use_clipped_linears", False)
        if self.use_clipped_linears:
            self.register_buffer("input_min", torch.tensor(-float("inf")))
            self.register_buffer("input_max", torch.tensor(float("inf")))
            self.register_buffer("output_min", torch.tensor(-float("inf")))
            self.register_buffer("output_max", torch.tensor(float("inf")))
    def forward(self, x):
        if self.use_clipped_linears:
            x = torch.clamp(x, self.input_min, self.input_max)
        out = nn.Linear.forward(self, x)
        if self.use_clipped_linears:
            out = torch.clamp(out, self.output_min, self.output_max)
        return out

modeling_gemma4.Gemma4ClippableLinear = PatchedClippableLinear

# bf16 base + LoRA fits in ~16 GB; switch to BitsAndBytesConfig for 4-bit if
# you're VRAM-constrained.
base = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E4B-it",
    dtype=torch.bfloat16,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained("google/gemma-4-E4B-it")
model = PeftModel.from_pretrained(base, "Shubh-0789/gemma4-clinical-endpoint-extractor")
model.eval()

system_prompt = (
    "You are a clinical trial endpoint extraction system. "
    "Extract structured endpoint information from clinical trial text. "
    "Return ONLY valid JSON."
)
endpoint_text = (
    "Progression-Free Survival (PFS) assessed by RECIST v1.1 using CT scan at 24 weeks"
)
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content":
        f"Extract all endpoint fields from this clinical trial text. "
        f"Return ONLY a JSON.\n\nText:\n{endpoint_text}\n\nOutput (JSON only):"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))

For the best extraction quality, use the full system prompt this LoRA was trained with (10-field schema + measurement_method/evaluation_criteria taxonomy rules + few-shot examples). See Training Configuration below.
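
The decoded text from the snippet above is a plain string, so a pipeline usually has to turn it into a Python object before using it. The following is a minimal sketch, assuming the model may occasionally surround the JSON with extra text; `parse_endpoint_json` is a hypothetical helper and not part of this repository.

```python
import json


def parse_endpoint_json(generated_text: str) -> dict:
    """Parse decoded model output into a dict.

    Hypothetical helper: tolerates stray text around the JSON object by
    slicing from the first '{' to the last '}' before parsing.
    """
    start = generated_text.find("{")
    end = generated_text.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in model output")
    return json.loads(generated_text[start : end + 1])


# Illustrative string, not real model output.
raw = (
    "Here is the JSON:\n"
    '{"endpoints": [{"endpoint_name_standardized": "Progression-Free Survival"}]}\n'
)
parsed = parse_endpoint_json(raw)
print(parsed["endpoints"][0]["endpoint_name_standardized"])
```

If the output is truncated (for example because max_new_tokens was too small), `json.loads` raises `json.JSONDecodeError`, which is a subclass of `ValueError`, so a single `except ValueError` covers both failure modes.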

Verifying the LoRA is actually applied

If you're integrating this into a pipeline, a quick health check:

import torch
# After loading `model` as above, lora_B for at least one trained layer
# should be non-zero. (Default-init lora_B is zero, so any non-zero value
# means the trained adapter is in effect.)
key = "base_model.model.model.language_model.layers.10.self_attn.q_proj.lora_B.default.weight"
assert dict(model.named_parameters())[key].detach().abs().max() > 0

What changed in this release

Symptom you may have seen: the adapter loaded "successfully" but PeftModel.from_pretrained printed something like UserWarning: Found missing adapter keys while loading the checkpoint: ['base_model.model.model.audio_tower.layers.0.…', 'base_model.model.model.vision_tower.encoder.layers.0.…', …], and downstream extractions were noticeably less consistent than the README examples — because the LoRA effectively wasn't being applied.

Root cause. The previous snapshot's adapter_model.safetensors contained 884 trained tensors covering three parts of the Gemma 4 multimodal architecture: the language model (588 keys, 42 layers × 7 modules × 2 LoRA A/B matrices), the audio tower (72 keys, 12 layers), and the vision tower (224 keys, 16 layers). The audio- and vision-tower LoRAs were almost certainly not deliberate: they look like an artifact of the trainer matching target names such as q_proj and k_proj by substring across the whole Gemma4ForConditionalGeneration. When loading against a fresh Gemma4ForConditionalGeneration for text-only inference, PEFT also injected LoRA modules under audio_tower / vision_tower (because the base has them too) and then failed to match the saved keys against the key names the model expected, silently producing a default-initialized adapter instead of the trained one.
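
The key counts quoted above can be checked with a few lines of arithmetic. The sketch below also shows one way to bucket adapter key names by tower. The sample key names are illustrative only (modelled on the warning text and the health-check key), and `tower_of` is a hypothetical helper.

```python
from collections import Counter

# Key counts as stated in the root-cause paragraph.
LM_KEYS = 42 * 7 * 2   # 42 layers x 7 target modules x 2 (lora_A / lora_B)
AUDIO_KEYS = 72
VISION_KEYS = 224

TOTAL_KEYS = LM_KEYS + AUDIO_KEYS + VISION_KEYS
NON_LM_KEYS = AUDIO_KEYS + VISION_KEYS


def tower_of(key: str) -> str:
    """Return which part of the model an adapter key belongs to."""
    for tower in ("language_model", "audio_tower", "vision_tower"):
        if f".{tower}." in key:
            return tower
    return "other"


# Illustrative key names; the real file holds many more.
sample_keys = [
    "base_model.model.model.language_model.layers.10.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.language_model.layers.10.self_attn.q_proj.lora_B.weight",
    "base_model.model.model.audio_tower.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.vision_tower.encoder.layers.0.self_attn.q_proj.lora_A.weight",
]
counts = Counter(tower_of(k) for k in sample_keys)
print(TOTAL_KEYS, NON_LM_KEYS, dict(counts))
```

To run the same tally on a real adapter file, the key list could be taken from the dict returned by safetensors.torch.load_file.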

Fix in this revision (no behavior change to text inference):

  • adapter_model.safetensors: dropped the 296 audio_tower / vision_tower LoRA keys; kept all 588 language_model.layers.{0..41} LoRA keys exactly as trained.
  • adapter_config.json: added an exclude_modules regex (.*\.(audio_tower|vision_tower)\..*) so PEFT skips the multimodal towers during injection.
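
To see what the exclude_modules pattern does, it can be tried against a few module names with Python's `re` module. This is a sketch under the assumption that PEFT treats a string exclude_modules as a regex that must match the whole module name; the module names below are illustrative, and `is_excluded` is a hypothetical helper.

```python
import re

# Pattern copied from adapter_config.json as described above.
EXCLUDE_PATTERN = r".*\.(audio_tower|vision_tower)\..*"


def is_excluded(module_name: str) -> bool:
    """True if the module name sits under the audio or vision tower."""
    return re.fullmatch(EXCLUDE_PATTERN, module_name) is not None


# Illustrative module names, not an exhaustive list from the real model.
names = [
    "model.language_model.layers.0.self_attn.q_proj",
    "model.audio_tower.layers.0.self_attn.q_proj",
    "model.vision_tower.encoder.layers.0.self_attn.q_proj",
]
kept = [n for n in names if not is_excluded(n)]
print(kept)
```

Only the language-model projection survives the filter, which is the behavior the fix relies on.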

After the fix, PeftModel.from_pretrained(...) loads with no warnings and the 516 trained LM tensors that map onto the standard text decoder match the file bit-for-bit (verified with safetensors.torch.load_file + torch.allclose).

Verified with: transformers==5.7.0, peft==0.19.1, torch==2.8.0+cu128 (CUDA 12.8, RTX 5090).

If you were pinning the previous snapshot, the old commit hash is still reachable via revision= — but its loading path was broken on peft >= 0.18, so we recommend updating to this snapshot and re-running any cached extractions you've already produced.

Model Details

| Property | Value |
| --- | --- |
| Base Model | google/gemma-4-E4B-it |
| Method | QLoRA (4-bit NF4 quantization + LoRA rank 16) |
| Trainable Parameters | 42.4M (0.85% of 5B total) |
| Training Data | 1,558 clinical trial endpoint samples |
| Data Sources | ClinicalTrials.gov, EU CTR (EudraCT), ChiCTR (Chinese Clinical Trials) |
| Final Eval Loss | 0.3006 |
| Final Token Accuracy | 94.07% |
| Training Time | ~74 minutes on NVIDIA RTX A6000 (48GB) |
| License | Apache 2.0 |

Training Results

| Epoch | Eval Loss | Token Accuracy |
| --- | --- | --- |
| 1 | 0.3493 | 93.23% |
| 2 | 0.3024 | 94.01% |
| 3 | 0.3006 | 94.07% |

Output Schema

The model outputs JSON with the following structure:

{
  "endpoints": [
    {
      "endpoint_name_standardized": "string | null",
      "measurement_of": "string | null",
      "measurement_type": "continuous | binary | time-to-event | ordinal | null",
      "metric_type": "string | null",
      "timeframe": "string | null",
      "measurement_method": "string | null",
      "evaluation_criteria": "string | null",
      "unit": "string | null",
      "population": "string | null",
      "is_composite": "boolean",
      "components": "[]"
    }
  ]
}

Field Definitions

  • endpoint_name_standardized: Normalized endpoint name (e.g., ORR -> Objective Response Rate)
  • measurement_of: Underlying clinical concept being measured
  • measurement_type: One of continuous, binary, time-to-event, ordinal
  • metric_type: Statistical metric (mean, median, proportion, hazard ratio, etc.)
  • timeframe: Assessment window (e.g., "24 weeks", "Up to 5 years")
  • measurement_method: Physical tool/technique (CT scan, MRI, blood test) - NOT scoring systems
  • evaluation_criteria: Scoring system/guideline (RECIST v1.1, WHO criteria) - NOT measurement tools
  • unit: Measurement unit (%, mL, ms, ng/mL, etc.)
  • population: Study population if specified (ITT, safety population, etc.)
  • is_composite: Whether endpoint combines multiple events
  • components: List of component events for composite endpoints
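
Because the model is only trained toward this schema, not constrained to it, a light structural check on each parsed endpoint can catch malformed outputs early. The following is a minimal sketch: `validate_endpoint` is a hypothetical helper, and the rule that a non-composite endpoint should have an empty components list is an assumption drawn from the examples, not a documented guarantee.

```python
REQUIRED_FIELDS = (
    "endpoint_name_standardized", "measurement_of", "measurement_type",
    "metric_type", "timeframe", "measurement_method", "evaluation_criteria",
    "unit", "population", "is_composite", "components",
)
MEASUREMENT_TYPES = {"continuous", "binary", "time-to-event", "ordinal", None}


def validate_endpoint(ep: dict) -> list:
    """Return a list of human-readable problems; empty means the entry looks fine."""
    problems = []
    missing = [f for f in REQUIRED_FIELDS if f not in ep]
    if missing:
        problems.append("missing fields: " + ", ".join(missing))
    if ep.get("measurement_type") not in MEASUREMENT_TYPES:
        problems.append("unexpected measurement_type")
    if not isinstance(ep.get("is_composite"), bool):
        problems.append("is_composite is not a boolean")
    components = ep.get("components")
    if not isinstance(components, list):
        problems.append("components is not a list")
    elif components and ep.get("is_composite") is False:
        # Assumption: non-composite endpoints carry no components.
        problems.append("components listed for a non-composite endpoint")
    return problems


good = {
    "endpoint_name_standardized": "Body height", "measurement_of": "height",
    "measurement_type": "continuous", "metric_type": None, "timeframe": None,
    "measurement_method": "Stadiometer", "evaluation_criteria": None,
    "unit": "cm", "population": None, "is_composite": False, "components": [],
}
bad = {"measurement_type": "nominal", "is_composite": "no"}
good_problems = validate_endpoint(good)
bad_problems = validate_endpoint(bad)
print(good_problems, bad_problems)
```

A stricter pipeline could replace this with a JSON Schema or a typed model, but a plain function keeps the check dependency-free.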

Usage with 4-bit (memory-constrained)

If 16 GB of VRAM isn't available, load the base in 4-bit. (The compatibility patch above is still required.)

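import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Apply the Gemma4ClippableLinear patch from Quick start before this point.
# Weights are stored in 4-bit NF4 with double quantization; compute runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E4B-it",
    quantization_config=quantization_config,
    device_map="auto",
    dtype=torch.bfloat16,
)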

Example Outputs

Single Endpoint

Input: Number of participants achieving CR and PR by 4 months as per IWG 2023 criteria | [Time Frame: Through 4 months after starting treatment]

Output:

{
  "endpoints": [
    {
      "endpoint_name_standardized": "Number of participants achieving CR and PR",
      "measurement_of": "Objective Response Rate",
      "measurement_type": "binary",
      "metric_type": "count",
      "timeframe": "Through 4 months after starting treatment",
      "measurement_method": null,
      "evaluation_criteria": "IWG 2023 criteria",
      "unit": "number",
      "population": null,
      "is_composite": true,
      "components": ["CR", "PR"]
    }
  ]
}

Composite Biomarker Endpoint

Input: Biomarkers (Phase II) | Integrated biomarker endpoints include: PTEN immunohistochemistry, estrogen receptor and progesterone receptor, whole exome sequencing, ribonucleic acid sequencing. | [Time Frame: Up to 5 years]

Output:

{
  "endpoints": [
    {
      "endpoint_name_standardized": "Integrated Biomarker Association with PFS",
      "measurement_of": "prognostic association of integrated biomarkers with PFS",
      "measurement_type": "continuous",
      "metric_type": "hazard ratio",
      "timeframe": "Up to 5 years",
      "measurement_method": "Integrated biomarker analysis including PTEN immunohistochemistry, estrogen receptor and progesterone receptor, whole exome sequencing, ribonucleic acid sequencing",
      "evaluation_criteria": "Proportional hazards models",
      "unit": null,
      "population": null,
      "is_composite": true,
      "components": ["PTEN immunohistochemistry", "estrogen receptor", "progesterone receptor", "whole exome sequencing", "ribonucleic acid sequencing"]
    }
  ]
}

Minimal Input

Input: Body height

Output:

{
  "endpoints": [
    {
      "endpoint_name_standardized": "Body height",
      "measurement_of": "height",
      "measurement_type": "continuous",
      "metric_type": null,
      "timeframe": null,
      "measurement_method": "Stadiometer",
      "evaluation_criteria": null,
      "unit": "cm",
      "population": null,
      "is_composite": false,
      "components": []
    }
  ]
}

Training Configuration

| Parameter | Value |
| --- | --- |
| LoRA Rank (r) | 16 |
| LoRA Alpha | 16 |
| LoRA Dropout | 0.05 |
| Target Modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| Learning Rate | 5e-5 |
| LR Scheduler | Cosine |
| Epochs | 3 |
| Batch Size | 1 (x4 gradient accumulation) |
| Max Sequence Length | 2048 |
| Optimizer | AdamW 8-bit |
| Quantization | 4-bit NF4 with double quantization |
| Gradient Checkpointing | Enabled |
| Warmup Steps | 50 |
| Total Steps | 1,170 |
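
The step count follows from the other values in this configuration. As a rough check, assuming the trainer rounds a partial final accumulation window up to a full optimizer step and that the standard LoRA scaling of alpha / r applies:

```python
import math

# Values taken from this model card.
train_samples = 1558
per_device_batch = 1
grad_accum = 4
epochs = 3
warmup_steps = 50
lora_r, lora_alpha = 16, 16

effective_batch = per_device_batch * grad_accum
# Assumption: the last partial accumulation window counts as a step.
steps_per_epoch = math.ceil(train_samples / effective_batch)
total_steps = steps_per_epoch * epochs
warmup_fraction = warmup_steps / total_steps
# Standard LoRA scaling factor applied to the low-rank update.
lora_scaling = lora_alpha / lora_r

print(effective_batch, steps_per_epoch, total_steps)
print(f"warmup: {warmup_fraction:.1%}, LoRA scaling: {lora_scaling}")
```

Under these assumptions the arithmetic reproduces the 1,170 total steps listed above.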

Training Data

The model was trained on 1,558 clinical trial endpoint samples (with 195 validation) sourced from three international registries:

| Source | Endpoints |
| --- | --- |
| ClinicalTrials.gov (USA) | ~1,100 |
| ChiCTR (China) | ~500 |
| EU CTR (Europe) | ~350 |

Ground-truth labels were generated using Qwen 3.6-plus and the dataset covers:

  • Single and multiple endpoint extraction
  • Composite endpoint detection (19% of samples)
  • Various measurement types (continuous, binary, time-to-event, ordinal)
  • Multiple therapeutic areas (oncology, cardiology, neurology, etc.)

Hardware Requirements

| Setup | VRAM |
| --- | --- |
| bf16 inference (recommended) | ~16 GB |
| QLoRA inference (4-bit) | ~10 GB |
| QLoRA training | ~28 GB |

Limitations

  • Ground truth labels were generated by an LLM (Qwen 3.6-plus), not manually annotated
  • May occasionally hallucinate a second endpoint for single-endpoint inputs
  • measurement_method may be inferred even when not explicitly stated in the text
  • Primarily trained on English-language endpoint descriptions
  • The compatibility patch in Quick start is required with transformers 5.x and later; without it, PEFT cannot wrap Gemma 4's Gemma4ClippableLinear modules.

Citation

If you use this model, please cite:

@misc{gemma4-clinical-endpoint-extractor,
  title={Gemma 4 Clinical Trial Endpoint Extractor},
  author={Shubhanshu Yadav},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Shubh-0789/gemma4-clinical-endpoint-extractor}
}