gemma3-torchao-quant-sparse
A performance-optimized adaptation of the google/gemma-3-4b-it multimodal LLM that combines memory- and compute-efficiency techniques while preserving high-quality generation for image-text tasks.
Key Features
1. INT8 Weight-Only Quantization
- Uses torchao's Int8WeightOnlyConfig for weight-only INT8 quantization.
- Reduces VRAM usage significantly and speeds up inference.
- Maintains model output fidelity while lowering memory footprint.
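As a rough illustration of how this can be wired up (the model class, dtype, and device choices below are illustrative rather than the exact pipeline used for this checkpoint), torchao's config-based API applies the quantization in place:

```python
import torch
from transformers import Gemma3ForConditionalGeneration
from torchao.quantization import quantize_, Int8WeightOnlyConfig

# Load the base multimodal model (dtype/device choices are illustrative).
model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

# Convert the weights of the model's linear layers to INT8, weight-only.
quantize_(model, Int8WeightOnlyConfig())
```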
2. Sparsity Techniques
2.1 Magnitude-Based Pruning
- Zeroes out the smallest weights in linear layers based on absolute magnitude.
- Safely reduces the effective (non-zero) parameter count without altering layer shapes.
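A minimal sketch of this step (the helper name and the default sparsity ratio are placeholders, not the exact code behind this model):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def magnitude_prune(model: nn.Module, sparsity_ratio: float = 0.02) -> None:
    """Zero the smallest-magnitude weights of every nn.Linear, keeping shapes intact."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight
            k = int(weight.numel() * sparsity_ratio)
            if k == 0:
                continue
            # The k-th smallest absolute value serves as the pruning threshold.
            threshold = weight.abs().flatten().float().kthvalue(k).values
            weight.mul_((weight.abs() > threshold).to(weight.dtype))
```

For example, `magnitude_prune(model, sparsity_ratio=0.02)` zeroes roughly 2% of each linear layer's weights and leaves the rest untouched.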
2.2 Gradual Pruning
- Incrementally increases sparsity over multiple steps (e.g., 1000 steps).
- Prevents sudden degradation in output quality.
- Ideal for safe exploration of sparsity ratios.
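A sketch of such a schedule, reusing the magnitude_prune helper from the previous sketch (the linear ramp and the default target are assumptions):

```python
def gradual_prune(model, target_sparsity: float = 0.5, steps: int = 1000) -> None:
    """Ramp sparsity linearly from 0 to target_sparsity over `steps` pruning passes."""
    for step in range(1, steps + 1):
        current_sparsity = target_sparsity * step / steps
        magnitude_prune(model, sparsity_ratio=current_sparsity)
        # Already-zeroed weights stay below every later threshold, so pruning is
        # cumulative; quality checks can be interleaved here to stop early.
```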
2.3 Layer-Norm-Based Pruning
- Prunes the weights with the lowest L2 norms, computed per linear layer.
- Extremely fast because it does not require forward passes.
- Preserves output quality and keeps layer shapes intact, ensuring seamless integration with the model.
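A minimal sketch of this idea, read here as row-wise pruning (the row-wise granularity and the default ratio are assumptions, not the exact implementation):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def l2_norm_prune(model: nn.Module, sparsity_ratio: float = 0.02) -> None:
    """Zero the weight rows with the smallest L2 norms in each nn.Linear.

    Rows are only zeroed, never removed, so layer shapes stay intact and
    no forward passes are needed.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            weight = module.weight
            row_norms = weight.float().norm(p=2, dim=1)  # one norm per output row
            num_rows = int(weight.shape[0] * sparsity_ratio)
            if num_rows == 0:
                continue
            _, prune_idx = torch.topk(row_norms, num_rows, largest=False)
            weight[prune_idx] = 0
```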
2.4 Flexible Sparsity Ratios
- Supports low ratios (1–2%) for minimal impact and high ratios (up to 80%) for aggressive optimization.
2.5 Filter Map
- Excludes critical layers such as embeddings, normalization layers, and output heads from pruning.
- Ensures that pruning does not break the model or degrade output quality.
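A sketch of the kind of filter this refers to; the name patterns below follow common Hugging Face/Gemma module naming and are assumptions:

```python
import torch.nn as nn

# Substrings of module names that must never be pruned (patterns are illustrative).
EXCLUDED_PATTERNS = ("embed_tokens", "norm", "lm_head")

def is_prunable(name: str, module: nn.Module) -> bool:
    """Return True only for linear layers that are safe to prune."""
    if not isinstance(module, nn.Linear):
        return False
    return not any(pattern in name for pattern in EXCLUDED_PATTERNS)

# Usage: restrict any of the pruning helpers above to the filtered set.
# prunable = {name: m for name, m in model.named_modules() if is_prunable(name, m)}
```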
3. Selective Torch Compile
- Only critical layers are compiled with torch.compile for faster execution.
- Reduces compilation overhead while improving inference speed.
- Additional layers can be compiled selectively if required.
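A sketch of selective compilation; targeting the decoder blocks (matched by class name) is an assumption about which layers count as critical:

```python
import torch
import torch.nn as nn

def compile_decoder_layers(model: nn.Module) -> None:
    """Compile only the transformer decoder blocks; everything else stays eager."""
    for module in model.modules():
        # Class-name matching is illustrative; it targets blocks such as
        # Gemma3DecoderLayer without compiling the vision tower or embeddings.
        if module.__class__.__name__.endswith("DecoderLayer"):
            module.forward = torch.compile(module.forward, mode="reduce-overhead")
```

Compiling at this granularity keeps the first-call compilation time short while still covering the layers that dominate inference.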
Benefits
- Significantly lower VRAM usage (~5–6 GB for the 4B model at 500–750 tokens).
- Faster inference with Torch Compile + INT8 quantization.
- Safe sparsity techniques allow memory-efficient experiments without breaking the model.
- Maintains high-quality multimodal generation (image-to-text and text-to-text).