# gemma4-prometheus-merged

Prometheus-steered merged model: `google/gemma-4-31B-it` with Prometheus adversarial-steering baked in and all adapter weights merged into the base model.
## Related repositories
| Repo | Description |
|---|---|
| groxaxo/gemma4-prometheus-gptq-4bit | GPTQ-4bit quantized version of this model |
| groxaxo/gemma4-prometheus-workflow | Reproducible scripts, config, and checkpoint journal |
| groxaxo/gemma4-prometheus-fixes | All local patches applied to make this work |
| google/gemma-4-31B-it | Original base model |
## What is this?

- Downloaded `google/gemma-4-31B-it` (31B parameters, BF16).
- Ran Prometheus adversarial-steering optimization over 1 trial with 6 behaviors.
- Merged the best steering vectors back into the base model weights.
- Saved as a standalone, loadable model (no Prometheus runtime needed).
Size: ~58 GiB (two BF16 shards).
## How to run

### Minimum requirements

- 3 × RTX 3090 (24 GB each) or any combination totalling ≥ 65 GiB VRAM
- On 2 GPUs (≤ 48 GiB): load with BnB 8-bit (see below)

### BF16 on 3 GPUs
```python
from transformers import AutoModelForImageTextToText, AutoTokenizer
import torch

model_id = "groxaxo/gemma4-prometheus-merged"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

messages = [{"role": "user", "content": "Explain gradient descent."}]
text = tokenizer.apply_chat_template(
    messages, tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
    out = model.generate(ids, max_new_tokens=512, do_sample=False,
                         pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True))
```
### BnB 8-bit on 2 GPUs (48 GiB)
```python
from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
import torch

model_id = "groxaxo/gemma4-prometheus-merged"
bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,
    device_map="auto",
    max_memory={0: "23GiB", 1: "23GiB"},
)
model.eval()
```
### Important: disable thinking tokens

Always pass `enable_thinking=False` to `apply_chat_template`, otherwise the model
may emit `<start_of_turn>think` tokens and spend tokens on chain-of-thought.
## Evaluation results

All tests run on 2 × RTX 3090 (24 GB each) in pipeline-parallel mode
(`device_map="auto"` via 🤗 Accelerate). True tensor parallelism (TP=2) requires
vLLM, which does not yet support the gemma4 architecture natively.

### Coherence test (GPTQ-4bit model, same backbone)

5 diverse ML questions answered. 5/5 passed (≥ 15 words, on-topic).
| Prompt | Response excerpt | OK |
|---|---|---|
| Explain how neural networks learn from data. | "…a neural network learns by trial and error. It makes a guess, finds out how wrong…" | ✅ |
| What is the difference between supervised and unsupervised learning? | "…In supervised learning, the data is 'labeled'…In unsupervised learning, the data is 'unlabeled'…" | ✅ |
| Describe the concept of gradient descent in machine learning. | "…Gradient Descent is an optimization algorithm used to minimize a function…" | ✅ |
| What are transformers in NLP? | "…a Transformer is a deep learning architecture…focusing on the most important parts…" | ✅ |
| Explain quantization for neural network models. | "…quantization is the process of reducing the precision of the numbers…" | ✅ |
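The length half of the pass criterion above can be sketched in a couple of lines (`passes_length` is a hypothetical helper name; the on-topic half of the check is not reproduced here):

```python
def passes_length(response: str, min_words: int = 15) -> bool:
    """Length part of the coherence check: response must contain
    at least `min_words` whitespace-separated words."""
    return len(response.split()) >= min_words
```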
### Context length

| KV Cache | Max Tokens | Bottleneck | Notes |
|---|---|---|---|
| FP16 | 6 144 | Attention compute O(n²) | Without flash-attn, attention matrix = 32 heads × n² × 2 B |
| FP8 (software) | 6 144 | Attention compute (same) | FP8 shrinks KV storage, not the attention matrix |
| FP16 + flash-attn (estimated) | ~113 000 | KV cache | Recommended: `pip install flash-attn --no-build-isolation` |
| FP8 + flash-attn (estimated) | ~226 000 | KV cache | Capped by `max_position_embeddings = 262 144` |
Action: installing `flash-attn` would increase usable context ~18×.
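The O(n²) bottleneck in the table can be checked with quick arithmetic. A sketch, assuming the table's formula (32 heads × n² × 2 bytes for one layer's materialized attention matrix in FP16):

```python
def attn_matrix_bytes(n_tokens: int, n_heads: int = 32, bytes_per_el: int = 2) -> int:
    """Memory for one layer's materialized attention score matrix
    without flash-attn: n_heads matrices of shape (n, n) in FP16."""
    return n_heads * n_tokens * n_tokens * bytes_per_el

# At the observed 6144-token ceiling this is already 2.25 GiB per layer,
# and doubling the context quadruples it. Flash-attn never materializes
# the matrix, which is why the bottleneck then moves to the KV cache.
print(attn_matrix_bytes(6144) / 2**30)  # 2.25
```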
### Perplexity (WikiText-2, sliding window, stride=512, 4096 tokens)
| Model | Perplexity | Notes |
|---|---|---|
| Merged (BnB-8bit reference) | 1782.3 | Chat model tested on raw text, so high PPL is expected |
| GPTQ-4bit | 1815.8 | +1.9% vs merged reference |

Chat-tuned models have high raw-text perplexity. The ΔPPL between variants is the meaningful signal: +1.9% degradation from 4-bit quantization.
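The sliding-window evaluation is essentially window bookkeeping. A minimal sketch, assuming the standard stride-based recipe where each window scores only the tokens not covered by the previous one (`sliding_windows` is an illustrative helper, not the actual eval script):

```python
def sliding_windows(seq_len: int, max_len: int = 4096, stride: int = 512):
    """Yield (begin, end, n_target) spans for stride-based perplexity.

    Each window feeds max_len tokens to the model but only scores its
    last n_target tokens; the earlier tokens serve as context, so no
    token is ever scored with less context than the stride allows.
    """
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_len, seq_len)
        n_target = end - prev_end  # tokens newly scored in this window
        yield begin, end, n_target
        prev_end = end
        if end == seq_len:
            break
```
Per-window negative log-likelihoods over the `n_target` tokens are then summed and divided by `seq_len` before exponentiating to get the reported perplexity.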
### KL divergence (GPTQ-4bit vs merged reference)
| Metric | Value |
|---|---|
| Direction | KL(merged_bnb8 ‖ gptq_4bit) |
| Mean KL | 4.77 nats |
| Std KL | 3.65 nats |
| Prompts | 8 ML-domain questions |
| Top-k tokens | 1000 |
Mean KL of ~4.77 nats reflects expected 4-bit quantization error relative to an 8-bit reference. Note: part of this KL is attributable to bnb-8bit noise in the reference; true KL vs FP16 merged would be somewhat lower.
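The metric can be sketched in pure Python on a single next-token distribution (an assumption-laden toy: the card's harness presumably batches this over real model logits, and the exact top-k renormalization scheme is not documented here):

```python
import math

def topk_kl(logits_p, logits_q, k: int = 1000) -> float:
    """KL(P || Q) in nats, restricted to P's top-k tokens.

    Both distributions are renormalized over the shared top-k support
    before computing the divergence (one plausible convention).
    """
    def softmax(xs):
        m = max(xs)
        es = [math.exp(x - m) for x in xs]
        z = sum(es)
        return [e / z for e in es]

    p, q = softmax(logits_p), softmax(logits_q)
    top = sorted(range(len(p)), key=lambda i: p[i], reverse=True)[:k]
    zp = sum(p[i] for i in top)
    zq = sum(q[i] for i in top)
    return sum((p[i] / zp) * math.log((p[i] / zp) / (q[i] / zq)) for i in top)
```
Identical logits give 0 nats; the reported mean of ~4.77 nats is this quantity averaged over positions and prompts.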
## Architecture notes

- 60 transformer layers alternating:
  - Sliding-window attention (window=1024, 16 KV heads, head_dim=256)
  - Full (global) attention (4 KV heads, head_dim=512)
- GQA with 32 query heads, 16/4 KV heads
- VLM wrapper (`model.language_model`): vision tower present, but text-only inference works
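The per-token KV footprint implied by these head counts can be worked out directly (a sketch assuming an even 30/30 split between the two layer types and an FP16 cache):

```python
def kv_bytes_per_token(n_kv_heads: int, head_dim: int, bytes_per_el: int = 2) -> int:
    """One layer's K+V cache cost per token: 2 tensors of (n_kv_heads, head_dim)."""
    return 2 * n_kv_heads * head_dim * bytes_per_el

sliding_layer = kv_bytes_per_token(16, 256)  # 16 KiB/token, but capped at window=1024
global_layer = kv_bytes_per_token(4, 512)    # 8 KiB/token, grows with full context
```
Because sliding-window layers only ever cache the last 1024 tokens, it is the global-attention layers whose KV cache scales with context length, consistent with the KV cache being the bottleneck once flash-attn removes the attention-matrix cost.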
## Patches applied

All source patches are documented at groxaxo/gemma4-prometheus-fixes.

Key fixes:

- Prometheus PEFT adapter targeting: resolved exact module paths via `named_modules()` traversal to prevent over-matching vision layers.
- Prometheus steering FP16: defaulted the steering-vector compute dtype to FP16 (not FP32) to prevent VRAM OOM on quantized layers.
- gptqmodel Gemma4 support: added a `Gemma4QModel` definition with `layer_modules_strict=False`.
- gptqmodel rotary embedding: per-layer `position_embeddings` regeneration with the correct `layer_type` (sliding vs global).
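The adapter-targeting fix can be illustrated with a toy version (`resolve_target_modules` and the `vision_tower` prefix are illustrative assumptions; the real patch walks `model.named_modules()` on the loaded model):

```python
from fnmatch import fnmatch

def resolve_target_modules(module_names, patterns, exclude_prefix="vision_tower"):
    """Expand adapter target patterns to exact module paths while skipping
    the vision tower. Without the prefix filter, a pattern like '*.q_proj'
    would also match vision-encoder projections (the over-matching bug)."""
    return [
        name for name in module_names
        if not name.startswith(exclude_prefix)
        and any(fnmatch(name, pat) for pat in patterns)
    ]
```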
## Citation / acknowledgements

- Base model: google/gemma-4-31B-it
- Steering framework: Prometheus (local)
- Quantization: gptqmodel