🚀 PyTorch — Beyond Quantization: Membawa Inferensi Sparse ke PyTorch


📌 Problem Statement
1. Model bahasa besar (LLM) sangat revolusioner, namun biaya komputasi dan konsumsi energi untuk inferensi tetap sangat tinggi.
2. Strategi optimasi yang umum—seperti quantization—sudah mendekati batas efektivitasnya, terutama untuk perangkat edge yang daya/hard-resource terbatas.
3. Diperlukan langkah selanjutnya selain quantization: yaitu sparsity (pengurangan bobot atau neuron aktif) agar inferensi menjadi jauh lebih efisien.

🛠️ Methodology / Solusi / Hypothesis
1. PyTorch & tim riset di NimbleEdge sedang membangun framework unified untuk inferensi sparse di PyTorch.
2. Pendekatan utama:
a. Observasi bahwa pada beberapa LLM lama seperti OPT (Meta), lebih dari 95%–99% bobot MLP blok tidak aktif untuk input rata-rata—menunjukkan potensi sparsity besar.
b. Teknik “Deja Vu” untuk prediksi neuron aktif: misalnya dekomposisi rendah-rank untuk gate matrix → hanya 4-10% hidden size yang dieksekusi, latency kecil, akurasi tetap tinggi.
c. Memperkenalkan thresholding modern (CATS, CETT) untuk model dengan aktivasi panjang (SiLU/GeLU) yang mengurangi sparsity bila dibandingkan model dengan ReLU.
d. Implementasi operator caching bobot dalam PyTorch yang memuat hanya perbedaan antara mask aktif berturut-turut → mengurangi operasi memory-bound seperti index_select. Eksperimen menunjukkan kecepatan hingga ~6.7× lebih cepat untuk operasi index_select.
3. Hipotesis: Dengan menggabungkan sparsity prediktif + caching bobot + kernel hardware-aware, inferensi LLM bisa menjadi sangat efisien—memungkinkan deployment edge, bukan hanya data-centre.

📊 Findings / Results / Impact
1. Teknik Deja Vu pada model OPT menunjukkan 2× hingga 6× speed-up dalam inferensi dengan penurunan akurasi yang sangat kecil.
2. Melalui operator caching bobot dan strategi mask yang efisien, PyTorch tim melaporkan index_select time dari ~29.89 ms ke ~4.46 ms dalam eksperimen — sekitar 6.7× peningkatan.
3. Walau aktivasi panjang SiLU/GeLU menurunkan sparsity jika thresholding kasar digunakan, solusi seperti CETT berhasil mendapatkan >60% sparsity pada model seperti Llama-2 7B/Mistral 7B tanpa fine-tuning besar-besaran.
4. Implikasi nyata: inferensi LLM yang sebelumnya hanya layak untuk pusat data besar, mulai bisa dipertimbangkan di perangkat edge atau setup dengan daya terbatas.

🧩 How to Use (Contoh Praktis)
1. Jika Anda menggunakan PyTorch untuk inferensi LLM dan ingin eksplorasi sparsity:
import torch
from torch import nn
model = … # load LLM model
# setelah terdapat modul MLP, Anda bisa menerapkan sparse mask
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and “mlp” in name:
mask = compute_sparse_mask(module.weight, sparsity=0.5) # contoh
module.weight.data *= mask
# lalu lakukan inference seperti biasa
outputs = model(input_ids)
(Catatan: ini eksperimental — gunakan bersama profiling dan monitor akurasi.)
2. Untuk caching bobot prediktif:
Identifikasi modul yang sering aktif (“neuron hot”) pada input Anda.
Implementasikan caching bobot sehingga hanya bobot yang berubah antar token yang dipanggil ulang.
Gunakan library sparse pendukung (lihat repositori NimbleEdge/sparse_transformers terkait).
3. Jalankan profiling: bandingkan inference time, memory usage antara versi full vs sparse pada model dan hardware Anda.
# dengan tensor-profiling di PyTorch
torch.profiler.profile(…, record_shapes=True, with_stack=True) as prof:
model(input_ids)
print(prof.key_averages().table(sort_by=”self_cpu_time_total”))

✅ Key Takeaways
1. Sparsity adalah langkah berikutnya setelah quantization untuk inferensi LLM yang efisien.
2. Implementasi yang baik tidak hanya menghapus bobot, tetapi meliputi prediksi neuron aktif, caching bobot, dan kernel yang dioptimalkan.
3. Framework seperti PyTorch mulai menyediakan dukungan untuk inferensi sparse production-grade — artinya edge deployment LLM makin realistis.
4. Tim AI/ML sebaiknya mulai menguji sparsity sebagai bagian dari optimasi inferensi—termasuk pengaruh latency, memory, dan akurasi.
5. Untuk organisasi yang ingin menurunkan biaya inferensi dan memperluas jangkauan model ke perangkat yang lebih ringan, sparsity harus menjadi bagian dari strategi.

Sumber:
https://pytorch.org/blog/beyond-quantization-bringing-sparse-inference-to-pytorch/

🔥 #PyTorch #SparseInference #LLM #ModelOptimization #EdgeAI #Sparsity #Quantization #NimbleEdge #EfficientAI

Leave a Comment