gemma4-prometheus-merged

Prometheus-steered merged model: google/gemma-4-31B-it with Prometheus adversarial-steering baked in and all adapter weights merged into the base model.

Related repositories

| Repo | Description |
| --- | --- |
| groxaxo/gemma4-prometheus-gptq-4bit | GPTQ-4bit quantized version of this model |
| groxaxo/gemma4-prometheus-workflow | Reproducible scripts, config, and checkpoint journal |
| groxaxo/gemma4-prometheus-fixes | All local patches applied to make this work |
| google/gemma-4-31B-it | Original base model |

What is this?

  1. Downloaded google/gemma-4-31B-it (31 B parameters, BF16).
  2. Ran Prometheus adversarial-steering optimization over 1 trial with 6 behaviors.
  3. Merged the best steering vectors back into the base model weights.
  4. Saved as a standalone, loadable model (no Prometheus runtime needed).
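Step 3 can be illustrated with a toy sketch. Assuming the steering here amounts to adding a constant vector to a layer's output (our reading of the setup, not confirmed by the Prometheus code; the layer and vector below are stand-ins), runtime steering is equivalent to folding that vector into the layer's bias, which is why the merge needs no runtime hooks:

```python
import torch
import torch.nn as nn

# Toy stand-in for one transformer projection. Adding a constant steering
# vector v to the layer's output is equivalent to folding v into the bias.
hidden = 8
layer = nn.Linear(hidden, hidden, bias=True)
v = torch.randn(hidden)              # hypothetical steering vector

x = torch.randn(3, hidden)
steered = layer(x) + v               # runtime steering (hook-based)

with torch.no_grad():
    layer.bias += v                  # "merge": bake v into the weights
merged = layer(x)                    # same output, no hook required

assert torch.allclose(steered, merged, atol=1e-6)
```

The merged checkpoint therefore loads like any ordinary model, as step 4 states.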

Size: ~58 GiB (two BF16 shards).


How to run

Minimum requirements

  • 3 × RTX 3090 (24 GB each) or any combination totaling ≥ 65 GiB VRAM
  • On 2 GPUs (≤ 48 GiB): load with BnB 8-bit (see below)
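These requirements follow from simple weight-size arithmetic (weights only; KV cache and activations add overhead, which is why ≥ 65 GiB is quoted rather than 58):

```python
# Back-of-envelope for the VRAM figures above (weights only).
params = 31e9

bf16_bytes = params * 2            # 2 bytes/param in BF16
int8_bytes = params * 1            # ~1 byte/param with BnB 8-bit

gib = 1024 ** 3
print(f"BF16 weights:  ~{bf16_bytes / gib:.0f} GiB")   # ~58 GiB -> needs 3 x 24 GB
print(f"8-bit weights: ~{int8_bytes / gib:.0f} GiB")   # ~29 GiB -> fits 2 x 24 GB
```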

BF16 on 3 GPUs

```python
from transformers import AutoModelForImageTextToText, AutoTokenizer
import torch

model_id = "groxaxo/gemma4-prometheus-merged"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.eval()

messages = [{"role": "user", "content": "Explain gradient descent."}]
text = tokenizer.apply_chat_template(
    messages, tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)
ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
    out = model.generate(ids, max_new_tokens=512, do_sample=False,
                         pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True))
```

BnB 8-bit on 2 GPUs (48 GiB)

```python
from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
import torch

model_id = "groxaxo/gemma4-prometheus-merged"

bnb_cfg = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_cfg,
    device_map="auto",
    max_memory={0: "23GiB", 1: "23GiB"},
)
model.eval()
```

Important: disable thinking tokens

Always pass enable_thinking=False to apply_chat_template; otherwise the model may emit <start_of_turn>think tokens and spend part of its token budget on chain-of-thought.


Evaluation results

All tests were run on 2 × RTX 3090 (24 GB each) in pipeline-parallel mode (device_map="auto" via 🤗 Accelerate). True tensor parallelism (TP=2) requires vLLM, which does not yet support the gemma4 architecture natively.

Coherence test (GPTQ-4bit model, same backbone)

5 diverse ML questions answered; 5/5 passed (≥ 15 words, on-topic).

| Prompt | Response excerpt | OK |
| --- | --- | --- |
| Explain how neural networks learn from data. | "…a neural network learns by trial and error. It makes a guess, finds out how wrong…" | ✅ |
| What is the difference between supervised and unsupervised learning? | "…In supervised learning, the data is 'labeled'…In unsupervised learning, the data is 'unlabeled'…" | ✅ |
| Describe the concept of gradient descent in machine learning. | "…Gradient Descent is an optimization algorithm used to minimize a function…" | ✅ |
| What are transformers in NLP? | "…a Transformer is a deep learning architecture…focusing on the most important parts…" | ✅ |
| Explain quantization for neural network models. | "…quantization is the process of reducing the precision of the numbers…" | ✅ |

Context length

| KV Cache | Max Tokens | Bottleneck | Notes |
| --- | --- | --- | --- |
| FP16 | 6 144 | Attention compute O(n²) | Without flash-attn, attention matrix = 32 heads × n² × 2 B |
| FP8 (software) | 6 144 | Same: attention compute | FP8 saves KV storage, not the attention matrix |
| FP16 + flash-attn (estimated) | ~113 000 | KV cache | Recommended: pip install flash-attn --no-build-isolation |
| FP8 + flash-attn (estimated) | ~226 000 | KV cache | Capped by max_position_embeddings = 262 144 |

Action: installing flash-attn would increase usable context ~18×.
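The ~6k cap follows directly from the formula in the table above: without flash-attn, each layer materializes an attention matrix of 32 heads × n² × 2 bytes, which grows quadratically. Flash-attention never materializes this matrix, which is why the limit then shifts to the KV cache:

```python
# Why ~6k tokens is the practical cap without flash-attn: the materialized
# per-layer attention matrix is heads * n^2 * bytes_per_element, O(n^2) in
# sequence length n.
def attn_matrix_bytes(n_tokens, n_heads=32, bytes_per_el=2):
    return n_heads * n_tokens ** 2 * bytes_per_el

gib = 1024 ** 3
for n in (6_144, 113_000):
    print(f"n = {n:>7}: {attn_matrix_bytes(n) / gib:,.1f} GiB per layer")
```

At n = 6 144 this is already ~2.3 GiB per layer; at the flash-attn context sizes it would be hundreds of GiB, hence infeasible without the fused kernel.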

Perplexity (WikiText-2, sliding window stride=512, 4096 tokens)

| Model | Perplexity | Notes |
| --- | --- | --- |
| Merged (BnB-8bit reference) | 1782.3 | Chat model tested on raw text; high PPL is expected |
| GPTQ-4bit | 1815.8 | +1.9% vs merged reference |

Chat-tuned models have high raw-text perplexity. The ΔPPL between variants is the meaningful signal: +1.9% degradation from 4-bit quantization.
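The stride-512 sliding-window indexing behind these numbers can be sketched as follows (the helper name is ours; in the real evaluation the overlapping prefix tokens of each window have their labels set to -100 so they serve as context only and each token is scored exactly once):

```python
# Sketch of stride-512 sliding-window evaluation indexing over a tokenized
# corpus: each window covers up to max_len tokens but only its last trg_len
# tokens contribute to the loss.
def sliding_windows(seq_len, max_len=4096, stride=512):
    windows = []
    prev_end = 0
    for begin in range(0, seq_len, stride):
        end = min(begin + max_len, seq_len)
        trg_len = end - prev_end          # tokens actually scored this step
        windows.append((begin, end, trg_len))
        prev_end = end
        if end == seq_len:
            break
    return windows

# Every token is scored exactly once across windows:
wins = sliding_windows(seq_len=10_000)
assert sum(t for _, _, t in wins) == 10_000
```

Perplexity is then exp of the token-weighted mean negative log-likelihood over all scored tokens.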

KL divergence (GPTQ-4bit vs merged reference)

| Metric | Value |
| --- | --- |
| Direction | KL(merged_bnb8 ‖ gptq_4bit) |
| Mean KL | 4.77 nats |
| Std KL | 3.65 nats |
| Prompts | 8 ML-domain questions |
| Top-k tokens | 1000 |

Mean KL of ~4.77 nats reflects expected 4-bit quantization error relative to an 8-bit reference. Note: part of this KL is attributable to bnb-8bit noise in the reference; true KL vs FP16 merged would be somewhat lower.
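One common way to compute a top-k-restricted KL like the one above is to take the reference model's k most likely tokens, renormalize both next-token distributions over that subset, and average KL per position; the exact variant used here may differ, so treat this as an illustrative sketch on toy logits:

```python
import torch
import torch.nn.functional as F

# KL(ref || quant) restricted to the reference model's top-k tokens,
# renormalized over that subset. Returns a mean value in nats.
def topk_kl(ref_logits, q_logits, k=1000):
    idx = ref_logits.topk(k, dim=-1).indices
    p = F.log_softmax(ref_logits.gather(-1, idx), dim=-1)
    q = F.log_softmax(q_logits.gather(-1, idx), dim=-1)
    return (p.exp() * (p - q)).sum(-1).mean()

torch.manual_seed(0)
a = torch.randn(4, 32_000)           # toy "reference" next-token logits
assert topk_kl(a, a).abs() < 1e-5    # identical models -> KL = 0
assert topk_kl(a, a + torch.randn_like(a)) > 0
```

In the real measurement, `ref_logits` and `q_logits` would come from the BnB-8bit merged model and the GPTQ-4bit model on the same 8 prompts.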


Architecture notes

  • 60 transformer layers alternating:
    • Sliding-window attention (window=1024, 16 KV heads, head_dim=256)
    • Full (global) attention (4 KV heads, head_dim=512)
  • GQA with 32 query heads, 16/4 KV heads
  • VLM wrapper (model.language_model); the vision tower is present, but text-only inference works
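The GQA shapes above can be sketched in isolation (toy batch and sequence sizes; the real implementation lives inside the attention modules):

```python
import torch

# GQA shape sketch for the config above: 32 query heads share a smaller
# set of KV heads (16 on sliding-window layers, 4 on global layers).
def expand_kv(kv, n_q_heads=32):
    # kv: (batch, n_kv_heads, seq, head_dim) -> (batch, n_q_heads, seq, head_dim)
    b, n_kv, s, d = kv.shape
    return kv.repeat_interleave(n_q_heads // n_kv, dim=1)

sliding_k = torch.randn(1, 16, 128, 256)  # sliding-window layer: 16 KV heads, head_dim 256
global_k = torch.randn(1, 4, 128, 512)    # global layer: 4 KV heads, head_dim 512

assert expand_kv(sliding_k).shape == (1, 32, 128, 256)  # each KV head serves 2 query heads
assert expand_kv(global_k).shape == (1, 32, 128, 512)   # each KV head serves 8 query heads
```

This sharing is also why the global layers are so cheap in KV-cache terms: only 4 KV heads per layer are stored.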

Patches applied

All source patches are documented at groxaxo/gemma4-prometheus-fixes.

Key fixes:

  1. Prometheus PEFT adapter targeting: resolved exact module paths via named_modules() traversal to prevent over-matching vision layers.
  2. Prometheus steering FP16: defaulted the steering-vector compute dtype to FP16 (not FP32) to prevent VRAM OOM on quantized layers.
  3. gptqmodel Gemma4 support: added a Gemma4QModel definition with layer_modules_strict=False.
  4. gptqmodel rotary embedding: per-layer position_embeddings regeneration with the correct layer_type (sliding vs global).
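Fix 1 can be illustrated with a minimal toy model (the class and module names below are stand-ins, not the real gemma4 module tree): suffix-matching a target like "q_proj" hits both the text stack and the vision tower, whereas resolving exact paths via named_modules() lets you keep only the language-model modules.

```python
import torch.nn as nn

# Toy VLM with identically named projections in both towers.
class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.ModuleDict({"q_proj": nn.Linear(8, 8)})
        self.vision_tower = nn.ModuleDict({"q_proj": nn.Linear(8, 8)})

model = ToyVLM()

# Naive suffix matching ("q_proj") hits both towers:
all_hits = [n for n, m in model.named_modules()
            if isinstance(m, nn.Linear) and n.endswith("q_proj")]
# Exact-path resolution keeps only the text stack:
text_hits = [n for n in all_hits if n.startswith("language_model.")]

assert len(all_hits) == 2
assert text_hits == ["language_model.q_proj"]
```

Passing the resolved full paths (rather than bare suffixes) as PEFT target modules is what prevents adapters from attaching to vision layers.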
