Gemma 4 Clinical Trial Endpoint Extractor
A fine-tuned Gemma 4 E4B-it model (LoRA adapter) for extracting structured endpoint information from clinical trial text. The model takes unstructured endpoint descriptions and outputs structured JSON following a predefined schema.
⚠️ Loading fix (this revision)
The earlier snapshot of this adapter was saved with multimodal-wrapper layer paths and also carried trained LoRAs for the vision and audio towers. With recent `peft` (≥ 0.18) this caused `PeftModel.from_pretrained` to log a `Found missing adapter keys` warning and silently fall back to a default-initialized LoRA. The base instruction-tuned model would still produce JSON that looked schema-conformant (the system prompt is detailed enough), so the regression was easy to miss in spot checks.

This revision rewrites `adapter_model.safetensors` and `adapter_config.json` so that `PeftModel.from_pretrained(base, "Shubh-0789/gemma4-clinical-endpoint-extractor")` loads with 0 warnings and applies the trained weights to all 42 language-model decoder layers as expected. The training data, hyperparameters and behaviour are unchanged. See Quick start (fixed adapter) below for the up-to-date load snippet, and What changed in this release further down for technical detail.
Quick start (fixed adapter)
```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gemma4 import modeling_gemma4
from peft import PeftModel

# Required compatibility patch: Gemma 4's `Gemma4ClippableLinear` wraps
# nn.Linear and so cannot be wrapped by PEFT directly. Replace it with an
# nn.Linear-derived class BEFORE loading the base model. Behavior is
# preserved for layers where `use_clipped_linears=False` (the default).
class PatchedClippableLinear(nn.Linear):
    def __init__(self, config, in_features, out_features):
        nn.Linear.__init__(self, in_features, out_features, bias=False)
        self.use_clipped_linears = getattr(config, "use_clipped_linears", False)
        if self.use_clipped_linears:
            self.register_buffer("input_min", torch.tensor(-float("inf")))
            self.register_buffer("input_max", torch.tensor(float("inf")))
            self.register_buffer("output_min", torch.tensor(-float("inf")))
            self.register_buffer("output_max", torch.tensor(float("inf")))

    def forward(self, x):
        if self.use_clipped_linears:
            x = torch.clamp(x, self.input_min, self.input_max)
        out = nn.Linear.forward(self, x)
        if self.use_clipped_linears:
            out = torch.clamp(out, self.output_min, self.output_max)
        return out

modeling_gemma4.Gemma4ClippableLinear = PatchedClippableLinear

# bf16 base + LoRA fits in ~16 GB; switch to BitsAndBytesConfig for 4-bit if
# you're VRAM-constrained.
base = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E4B-it",
    dtype=torch.bfloat16,
    device_map="auto",
)
tok = AutoTokenizer.from_pretrained("google/gemma-4-E4B-it")
model = PeftModel.from_pretrained(base, "Shubh-0789/gemma4-clinical-endpoint-extractor")
model.eval()

system_prompt = (
    "You are a clinical trial endpoint extraction system. "
    "Extract structured endpoint information from clinical trial text. "
    "Return ONLY valid JSON."
)
endpoint_text = (
    "Progression-Free Survival (PFS) assessed by RECIST v1.1 using CT scan at 24 weeks"
)
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content":
        "Extract all endpoint fields from this clinical trial text. "
        f"Return ONLY a JSON.\n\nText:\n{endpoint_text}\n\nOutput (JSON only):"},
]

prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tok(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
# Decode only the newly generated tokens (skip the prompt).
print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```
For the best extraction quality, use the full system prompt this LoRA was trained with (10-field schema + measurement_method/evaluation_criteria taxonomy rules + few-shot examples). See Training Configuration below.
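The quick start prints the decoded string as-is. Below is a minimal sketch for turning that string into a Python dict. It assumes (the card does not confirm this) that the answer is a single JSON object that may occasionally be surrounded by stray text such as a markdown code fence, so it parses the span from the first `{` to the last `}`. `parse_endpoint_json` is a hypothetical helper, not part of this repository.

```python
import json


def parse_endpoint_json(generated: str) -> dict:
    """Parse the first-'{'-to-last-'}' span of a generation as JSON.

    Hypothetical helper: assumes the answer is one JSON object, possibly
    wrapped in extra text (e.g. a markdown fence) by the model.
    """
    start = generated.find("{")
    end = generated.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no JSON object found in model output")
    # json.JSONDecodeError (a ValueError subclass) propagates if the span is malformed.
    return json.loads(generated[start:end + 1])


raw = 'Here you go:\n{"endpoints": [{"endpoint_name_standardized": "Body height", "is_composite": false, "components": []}]}'
result = parse_endpoint_json(raw)
print(result["endpoints"][0]["endpoint_name_standardized"])  # Body height
```

In a pipeline you would pass the decoded string from `tok.decode(...)` instead of `raw`.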
Verifying the LoRA is actually applied
If you're integrating this into a pipeline, here is a quick health check:
```python
import torch

# After loading `model` as above, lora_B for at least one trained layer
# should be non-zero. (Default-init lora_B is zero, so any non-zero value
# means the trained adapter is in effect.)
key = "base_model.model.model.language_model.layers.10.self_attn.q_proj.lora_B.default.weight"
assert dict(model.named_parameters())[key].detach().abs().max() > 0
```
What changed in this release
Symptom you may have seen: the adapter loaded "successfully" but PeftModel.from_pretrained printed something like UserWarning: Found missing adapter keys while loading the checkpoint: ['base_model.model.model.audio_tower.layers.0.…', 'base_model.model.model.vision_tower.encoder.layers.0.…', …], and downstream extractions were noticeably less consistent than the README examples — because the LoRA effectively wasn't being applied.
Root cause. The previous snapshot's adapter_model.safetensors contained 884 trained tensors covering three towers of the Gemma 4 multimodal architecture: the language model (588 keys, 42 layers × 7 modules × 2 lora A/B), the audio tower (72 keys, 12 layers), and the vision tower (224 keys, 16 layers). The audio/vision-tower LoRAs were almost certainly an artifact of the trainer auto-targeting q_proj / k_proj / … substring matches across the whole Gemma4ForConditionalGeneration, not deliberate. When loading against a fresh Gemma4ForConditionalGeneration for text-only inference, PEFT also injected LoRA modules under audio_tower / vision_tower (because the base has them too) and then mismatched the saved keys against the model's expected key names — silently producing a default-initialized adapter instead of the trained one.
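The key counts quoted above are internally consistent. A quick arithmetic check, using only the numbers in this section and the fact that each LoRA-wrapped module contributes one `lora_A` and one `lora_B` tensor:

```python
# Language model: 42 decoder layers x 7 target modules x 2 tensors (lora_A, lora_B).
lm_keys = 42 * 7 * 2
audio_keys = 72    # audio tower, 12 layers (as reported above)
vision_keys = 224  # vision tower, 16 layers (as reported above)

total_keys = lm_keys + audio_keys + vision_keys
dropped_keys = audio_keys + vision_keys  # removed in this revision

# Tensors per layer in each multimodal tower.
audio_per_layer = audio_keys // 12
vision_per_layer = vision_keys // 16

print(lm_keys, total_keys, dropped_keys)    # 588 884 296
print(audio_per_layer, vision_per_layer)    # 6 14
```

The 296 dropped keys match the figure in the fix list, and 14 tensors per vision layer equals 7 modules × 2, the same target-module count as the language model.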
Fix in this revision (no behavior change to text inference):
- `adapter_model.safetensors`: dropped the 296 `audio_tower`/`vision_tower` LoRA keys; kept all 588 `language_model.layers.{0..41}` LoRA keys exactly as trained.
- `adapter_config.json`: added an `exclude_modules` regex (`.*\.(audio_tower|vision_tower)\..*`) so PEFT skips the multimodal towers during injection.
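To see which names the `exclude_modules` pattern catches, here is a small pure-Python sketch. It assumes full-match semantics (`re.fullmatch`) against dotted module or key names, which is how we understand PEFT to treat a string pattern. The sample key names are illustrative (hypothetical, modeled on the language-model key used in the health check), and `split_adapter_keys` is a hypothetical helper that mirrors the idea of the safetensors clean-up, not the script actually used for this revision.

```python
import re

# Pattern added to adapter_config.json's exclude_modules in this revision.
EXCLUDE = re.compile(r".*\.(audio_tower|vision_tower)\..*")


def split_adapter_keys(keys):
    """Partition checkpoint keys into (kept, dropped) using the exclude pattern.

    Hypothetical helper: illustrates the filtering idea only.
    """
    kept, dropped = [], []
    for key in keys:
        if EXCLUDE.fullmatch(key):
            dropped.append(key)
        else:
            kept.append(key)
    return kept, dropped


# Illustrative key names (hypothetical; a real checkpoint has hundreds).
sample_keys = [
    "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_B.weight",
    "base_model.model.model.audio_tower.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.vision_tower.encoder.layers.0.self_attn.q_proj.lora_A.weight",
]

kept, dropped = split_adapter_keys(sample_keys)
print(len(kept), len(dropped))  # 2 2
```

Note that the pattern requires a `.` before the tower name, so it targets towers nested under a parent module (as in `model.vision_tower...`) rather than a top-level `vision_tower`.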
After the fix, PeftModel.from_pretrained(...) loads with no warnings and the 516 trained LM tensors that map onto the standard text decoder match the file bit-for-bit (verified with safetensors.torch.load_file + torch.allclose).
Verified with: transformers==5.7.0, peft==0.19.1, torch==2.8.0+cu128 (CUDA 12.8, RTX 5090).
If you were pinning the previous snapshot, its commit hash is still reachable via `revision=`, but its loading path is broken on peft >= 0.18. We recommend updating to this snapshot and re-running any extractions you have already produced and cached.
Model Details
| Property | Value |
|---|---|
| Base Model | google/gemma-4-E4B-it |
| Method | QLoRA (4-bit NF4 quantization + LoRA rank 16) |
| Trainable Parameters | 42.4M (0.85% of 5B total) |
| Training Data | 1,558 clinical trial endpoint samples |
| Data Sources | ClinicalTrials.gov, EU CTR (EudraCT), ChiCTR (Chinese Clinical Trials) |
| Final Eval Loss | 0.3006 |
| Final Token Accuracy | 94.07% |
| Training Time | ~74 minutes on NVIDIA RTX A6000 (48GB) |
| License | Apache 2.0 |
Training Results
| Epoch | Eval Loss | Token Accuracy |
|---|---|---|
| 1 | 0.3493 | 93.23% |
| 2 | 0.3024 | 94.01% |
| 3 | 0.3006 | 94.07% |
Output Schema
The model outputs JSON with the following structure:
```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "string | null",
      "measurement_of": "string | null",
      "measurement_type": "continuous | binary | time-to-event | ordinal | null",
      "metric_type": "string | null",
      "timeframe": "string | null",
      "measurement_method": "string | null",
      "evaluation_criteria": "string | null",
      "unit": "string | null",
      "population": "string | null",
      "is_composite": "boolean",
      "components": "string[]"
    }
  ]
}
```
Field Definitions
- endpoint_name_standardized: Normalized endpoint name (e.g., ORR -> Objective Response Rate)
- measurement_of: Underlying clinical concept being measured
- measurement_type: One of `continuous`, `binary`, `time-to-event`, `ordinal`
- metric_type: Statistical metric (mean, median, proportion, hazard ratio, etc.)
- timeframe: Assessment window (e.g., "24 weeks", "Up to 5 years")
- measurement_method: Physical tool/technique (CT scan, MRI, blood test) - NOT scoring systems
- evaluation_criteria: Scoring system/guideline (RECIST v1.1, WHO criteria) - NOT measurement tools
- unit: Measurement unit (%, mL, ms, ng/mL, etc.)
- population: Study population if specified (ITT, safety population, etc.)
- is_composite: Whether endpoint combines multiple events
- components: List of component events for composite endpoints
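A minimal sketch of a post-hoc check against the schema and field definitions above. It is not part of the model or the author's pipeline; `validate_endpoints` is a hypothetical helper, and it only enforces what the schema states: the eleven keys, the four `measurement_type` values or null, nullable strings for the text fields, a boolean `is_composite` and a list `components`.

```python
# Field names and allowed values copied from the Output Schema / Field Definitions.
REQUIRED_FIELDS = {
    "endpoint_name_standardized", "measurement_of", "measurement_type",
    "metric_type", "timeframe", "measurement_method", "evaluation_criteria",
    "unit", "population", "is_composite", "components",
}
NULLABLE_STRING_FIELDS = REQUIRED_FIELDS - {"measurement_type", "is_composite", "components"}
# Tuple (not set) so membership tests never need to hash unexpected values.
ALLOWED_MEASUREMENT_TYPES = ("continuous", "binary", "time-to-event", "ordinal", None)


def validate_endpoints(payload: dict) -> list[str]:
    """Return a list of schema problems; an empty list means the payload passed."""
    endpoints = payload.get("endpoints")
    if not isinstance(endpoints, list):
        return ["'endpoints' must be a list"]
    problems = []
    for i, ep in enumerate(endpoints):
        if not isinstance(ep, dict):
            problems.append(f"endpoint {i}: not a JSON object")
            continue
        missing = REQUIRED_FIELDS - set(ep)
        if missing:
            problems.append(f"endpoint {i}: missing {sorted(missing)}")
        for field in sorted(NULLABLE_STRING_FIELDS & set(ep)):
            if ep[field] is not None and not isinstance(ep[field], str):
                problems.append(f"endpoint {i}: {field} must be a string or null")
        # A missing measurement_type yields None here and is reported as missing above.
        if ep.get("measurement_type") not in ALLOWED_MEASUREMENT_TYPES:
            problems.append(f"endpoint {i}: unexpected measurement_type {ep['measurement_type']!r}")
        if "is_composite" in ep and not isinstance(ep["is_composite"], bool):
            problems.append(f"endpoint {i}: is_composite must be a boolean")
        if "components" in ep and not isinstance(ep["components"], list):
            problems.append(f"endpoint {i}: components must be a list")
    return problems


# The "Body height" example output from this card, as a Python dict.
example = {"endpoints": [{
    "endpoint_name_standardized": "Body height", "measurement_of": "height",
    "measurement_type": "continuous", "metric_type": None, "timeframe": None,
    "measurement_method": "Stadiometer", "evaluation_criteria": None, "unit": "cm",
    "population": None, "is_composite": False, "components": [],
}]}
print(validate_endpoints(example))  # []
```

Returning a list of messages rather than raising lets a pipeline log every problem in one pass and decide separately whether to retry generation.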
Usage with 4-bit (memory-constrained)
If 16 GB of VRAM isn't available, load the base in 4-bit. (The compatibility patch above is still required.)
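```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Apply the PatchedClippableLinear patch from Quick start before this point.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E4B-it",
    quantization_config=quantization_config,
    device_map="auto",
    dtype=torch.bfloat16,
)
# Attach the adapter exactly as in the bf16 path.
model = PeftModel.from_pretrained(base, "Shubh-0789/gemma4-clinical-endpoint-extractor")
model.eval()
```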
Example Outputs
Single Endpoint
Input: Number of participants achieving CR and PR by 4 months as per IWG 2023 criteria | [Time Frame: Through 4 months after starting treatment]
Output:
```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "Number of participants achieving CR and PR",
      "measurement_of": "Objective Response Rate",
      "measurement_type": "binary",
      "metric_type": "count",
      "timeframe": "Through 4 months after starting treatment",
      "measurement_method": null,
      "evaluation_criteria": "IWG 2023 criteria",
      "unit": "number",
      "population": null,
      "is_composite": true,
      "components": ["CR", "PR"]
    }
  ]
}
```
Composite Biomarker Endpoint
Input: Biomarkers (Phase II) | Integrated biomarker endpoints include: PTEN immunohistochemistry, estrogen receptor and progesterone receptor, whole exome sequencing, ribonucleic acid sequencing. | [Time Frame: Up to 5 years]
Output:
```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "Integrated Biomarker Association with PFS",
      "measurement_of": "prognostic association of integrated biomarkers with PFS",
      "measurement_type": "continuous",
      "metric_type": "hazard ratio",
      "timeframe": "Up to 5 years",
      "measurement_method": "Integrated biomarker analysis including PTEN immunohistochemistry, estrogen receptor and progesterone receptor, whole exome sequencing, ribonucleic acid sequencing",
      "evaluation_criteria": "Proportional hazards models",
      "unit": null,
      "population": null,
      "is_composite": true,
      "components": ["PTEN immunohistochemistry", "estrogen receptor", "progesterone receptor", "whole exome sequencing", "ribonucleic acid sequencing"]
    }
  ]
}
```
Minimal Input
Input: Body height
Output:
```json
{
  "endpoints": [
    {
      "endpoint_name_standardized": "Body height",
      "measurement_of": "height",
      "measurement_type": "continuous",
      "metric_type": null,
      "timeframe": null,
      "measurement_method": "Stadiometer",
      "evaluation_criteria": null,
      "unit": "cm",
      "population": null,
      "is_composite": false,
      "components": []
    }
  ]
}
```
Training Configuration
| Parameter | Value |
|---|---|
| LoRA Rank (r) | 16 |
| LoRA Alpha | 16 |
| LoRA Dropout | 0.05 |
| Target Modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| Learning Rate | 5e-5 |
| LR Scheduler | Cosine |
| Epochs | 3 |
| Batch Size | 1 (x4 gradient accumulation) |
| Max Sequence Length | 2048 |
| Optimizer | AdamW 8-bit |
| Quantization | 4-bit NF4 with double quantization |
| Gradient Checkpointing | Enabled |
| Warmup Steps | 50 |
| Total Steps | 1,170 |
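The reported 1,170 total steps follow from the table above. A quick check, assuming one optimizer step per 4 samples (batch size 1 × 4 gradient-accumulation steps) with the per-epoch step count rounded up so the final partial accumulation counts as a step; the card does not state the rounding, but rounding down would give 1,167. The LoRA scaling factor alpha / r is also shown, assuming standard (non-rsLoRA) scaling, which the card does not specify.

```python
import math

train_samples = 1558
micro_batch = 1
grad_accum = 4
epochs = 3

samples_per_step = micro_batch * grad_accum  # effective batch size = 4
steps_per_epoch = math.ceil(train_samples / samples_per_step)
total_steps = steps_per_epoch * epochs

lora_r, lora_alpha = 16, 16
scaling = lora_alpha / lora_r  # standard LoRA update scaling alpha / r

print(steps_per_epoch, total_steps, scaling)  # 390 1170 1.0
```

With alpha equal to r, the LoRA update is added to the frozen weights without extra scaling.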
Training Data
The model was trained on 1,558 clinical trial endpoint samples (with 195 validation) sourced from three international registries:
| Source | Endpoints |
|---|---|
| ClinicalTrials.gov (USA) | ~1,100 |
| ChiCTR (China) | ~500 |
| EU CTR (Europe) | ~350 |
Ground-truth labels were generated with Qwen 3.6-plus. The dataset covers:
- Single and multiple endpoint extraction
- Composite endpoint detection (19% of samples)
- Various measurement types (continuous, binary, time-to-event, ordinal)
- Multiple therapeutic areas (oncology, cardiology, neurology, etc.)
Hardware Requirements
| Setup | VRAM |
|---|---|
| bf16 inference (recommended) | ~16 GB |
| QLoRA inference (4-bit) | ~10 GB |
| QLoRA training | ~28 GB |
Limitations
- Ground truth labels were generated by an LLM (Qwen 3.6-plus), not manually annotated
- May occasionally hallucinate a second endpoint for single-endpoint inputs
- measurement_method may be inferred even when not explicitly stated in the text
- Primarily trained on English-language endpoint descriptions
- The compatibility patch in Quick start is required for `transformers >= 5.x`; without it, PEFT cannot wrap Gemma 4's `Gemma4ClippableLinear` modules.
Citation
If you use this model, please cite:
```bibtex
@misc{gemma4-clinical-endpoint-extractor,
  title={Gemma 4 Clinical Trial Endpoint Extractor},
  author={Shubhanshu Yadav},
  year={2026},
  publisher={Hugging Face},
  url={https://huggingface.co/Shubh-0789/gemma4-clinical-endpoint-extractor}
}
```