# Llama-3.1-8B-qpal-df-tcq-3.25bit / qpal_modelling_llama.py
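"""Custom modelling code for QPal-quantized Llama checkpoints.

Serializes the layout of the quantized Incoherent* modules into a plain
config dict (``export_qpal_quant_config``) and rebuilds those modules from
that dict when the model is re-instantiated (``apply_qpal_quantization``),
so quantized checkpoints can be loaded through ``QPalLlamaForCausalLM``.
"""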
import torch
import torch.nn as nn
from typing import Any, Dict, Tuple
from transformers.models.llama.modeling_llama import LlamaForCausalLM as _BaseLlamaForCausalLM
from lib.linear.incoherent_linear import (
IncoherentLinear,
IncoherentMLP,
IncoherentSdpaAttention,
)
from lib.linear.vq_linear import VQLinearPackTensorCore, VQLinearPackSIMT
from lib.linear.tcq_linear import QTIPLinearTCQ
from lib.linear.comb_linear import CombLinearTCQ, CombtLinearTCQ
DTYPE_MAP: Dict[str, torch.dtype] = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
"float64": torch.float64,
}
def _dtype_from_str(name: Any) -> torch.dtype:
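    """Resolve a dtype given as a string (e.g. "float16", "torch.bfloat16"); torch.dtype values pass through unchanged."""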
if isinstance(name, torch.dtype):
return name
key = str(name).replace("torch.", "")
if key not in DTYPE_MAP:
raise ValueError(f"Unsupported dtype: {name}")
return DTYPE_MAP[key]
def _dtype_to_str(dtype: torch.dtype) -> str:
for key, value in DTYPE_MAP.items():
if dtype == value:
return key
return str(dtype).replace("torch.", "")
def _get_child(module: Any, name: str) -> Any:
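    """Fetch a child module by name, treating all-digit names as container indices (e.g. into a ModuleList)."""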
if name.isdigit():
return module[int(name)]
return getattr(module, name)
def _resolve_parent(root: nn.Module, path: str) -> Tuple[nn.Module, str]:
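    """Walk a dotted module path (e.g. "model.layers.0.mlp") and return the parent module plus the final attribute name."""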
parts = path.split(".")
module = root
for name in parts[:-1]:
module = _get_child(module, name)
return module, parts[-1]
def _assign_child(module: Any, name: str, value: Any) -> None:
if name.isdigit():
module[int(name)] = value
else:
setattr(module, name, value)
def _serialize_quant_linear(linear: nn.Module) -> Dict[str, Any]:
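    """Record the constructor arguments of a quantized linear layer so it can be rebuilt later.

    Common fields (class name, shapes, bias, dtype) are always stored; the
    class-specific branches below add the codebook/trellis parameters each
    kernel needs.
    """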
meta: Dict[str, Any] = {
"linear_cls": linear.__class__.__name__,
"in_features": getattr(linear, "in_features", None),
"out_features": getattr(linear, "out_features", None),
"bias": getattr(linear, "bias", None) is not None,
}
if hasattr(linear, "dtype"):
meta["linear_dtype"] = _dtype_to_str(getattr(linear, "dtype"))
    if isinstance(linear, (VQLinearPackTensorCore, VQLinearPackSIMT)):
        meta.update({
            "lut_bits": linear.lut_bits,
            "vec_sz": linear.vec_sz,
        })
elif isinstance(linear, QTIPLinearTCQ):
meta.update({
"td_x": linear.td_x,
"td_y": linear.td_y,
"L": linear.L,
"KV": linear.KV,
"V": linear.V,
"tlut_bits": linear.tlut_bits,
})
elif isinstance(linear, CombLinearTCQ):
meta.update({
"td_x": linear.td_x,
"td_y": linear.td_y,
"out_part": list(linear.out_part),
"L": linear.L,
"KV": list(linear.KV),
"V": linear.V,
"tlut_bits": linear.tlut_bits,
})
elif isinstance(linear, CombtLinearTCQ):
meta.update({
"td_x": linear.td_x,
"td_y": linear.td_y,
"in_part": list(linear.in_part),
"L": linear.L,
"KV": list(linear.KV),
"V": linear.V,
"tlut_bits": linear.tlut_bits,
})
return meta
def _serialize_incoherent_linear(layer: IncoherentLinear) -> Dict[str, Any]:
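    """Describe an IncoherentLinear: rotation sizes, scale, rotation info, and its inner quantized linear, if any."""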
meta: Dict[str, Any] = {
"module_type": "IncoherentLinear",
"in_features": layer.in_features,
"out_features": layer.out_features,
"hadU": layer.hadU,
"hadV": layer.hadV,
"bias": layer.bias is not None,
"dtype": _dtype_to_str(layer.dtype),
"rot_info": layer.rot_info,
"scale": float(layer.scale),
}
if layer.linear is not None:
meta["linear"] = _serialize_quant_linear(layer.linear)
return meta
def _serialize_incoherent_mlp(layer: IncoherentMLP) -> Dict[str, Any]:
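    """Describe an IncoherentMLP, storing either the merged up/gate projection or the separate projections."""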
projections: Dict[str, Dict[str, Any]] = {}
if layer.merge_ug and layer.ug_proj is not None:
projections["ug_proj"] = _serialize_quant_linear(layer.ug_proj)
else:
if layer.up_proj is not None:
projections["up_proj"] = _serialize_quant_linear(layer.up_proj)
if layer.gate_proj is not None:
projections["gate_proj"] = _serialize_quant_linear(layer.gate_proj)
if layer.down_proj is not None:
projections["down_proj"] = _serialize_quant_linear(layer.down_proj)
return {
"module_type": "IncoherentMLP",
"dtype": _dtype_to_str(layer.dtype),
"merge_ug": layer.merge_ug,
"scale": float(layer.scale),
"projections": projections,
}
def _serialize_incoherent_attention(layer: IncoherentSdpaAttention) -> Dict[str, Any]:
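    """Describe an IncoherentSdpaAttention, including whichever merged or separate q/k/v/o projections are present."""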
projections: Dict[str, Dict[str, Any]] = {}
for name in [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"qk_proj",
"qv_proj",
"kv_proj",
"qkv_proj",
]:
proj = getattr(layer, name, None)
if proj is not None:
projections[name] = _serialize_quant_linear(proj)
return {
"module_type": "IncoherentSdpaAttention",
"dtype": _dtype_to_str(layer.dtype),
"scale": float(layer.scale),
"merge_qk": layer.merge_qk,
"merge_kv": layer.merge_kv,
"merge_qv": layer.merge_qv,
"merge_qkv": layer.merge_qkv,
"layer_idx": layer.layer_idx,
"projections": projections,
}
def export_qpal_quant_config(model: nn.Module) -> Dict[str, Any]:
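    """Collect per-module quantization metadata for every Incoherent* module in the model.

    The result maps dotted module paths to metadata dicts, roughly::

        {"modules": {"model.layers.0.self_attn": {"module_type": "IncoherentSdpaAttention", ...}}}

    (The path above is illustrative; actual keys depend on the model.)
    """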
modules: Dict[str, Any] = {}
for name, module in model.named_modules():
if isinstance(module, IncoherentLinear):
modules[name] = _serialize_incoherent_linear(module)
elif isinstance(module, IncoherentMLP):
modules[name] = _serialize_incoherent_mlp(module)
elif isinstance(module, IncoherentSdpaAttention):
modules[name] = _serialize_incoherent_attention(module)
return {"modules": modules}
def _build_quant_linear(meta: Dict[str, Any], dtype: torch.dtype, bias: bool) -> nn.Module:
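    """Reconstruct a quantized (or plain nn.Linear) layer from metadata produced by _serialize_quant_linear."""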
cls_name = meta["linear_cls"]
if cls_name == "VQLinearPackTensorCore":
return VQLinearPackTensorCore(
meta["in_features"],
meta["out_features"],
meta["lut_bits"],
meta.get("vec_sz", 2),
bias=bias,
dtype=_dtype_from_str(meta.get("linear_dtype", dtype)),
)
if cls_name == "VQLinearPackSIMT":
return VQLinearPackSIMT(
meta["in_features"],
meta["out_features"],
meta["lut_bits"],
meta.get("vec_sz", 1),
bias=bias,
dtype=_dtype_from_str(meta.get("linear_dtype", dtype)),
)
if cls_name == "QTIPLinearTCQ":
return QTIPLinearTCQ(
meta["in_features"],
meta["out_features"],
meta["td_x"],
meta["td_y"],
meta["L"],
meta["KV"],
meta["V"],
meta["tlut_bits"],
bias=bias,
dtype=_dtype_from_str(meta.get("linear_dtype", dtype)),
)
if cls_name == "CombLinearTCQ":
return CombLinearTCQ(
meta["in_features"],
meta["out_features"],
meta["td_x"],
meta["td_y"],
meta["out_part"],
meta["L"],
meta["KV"],
meta["V"],
meta["tlut_bits"],
bias=bias,
dtype=_dtype_from_str(meta.get("linear_dtype", dtype)),
)
if cls_name == "CombtLinearTCQ":
return CombtLinearTCQ(
meta["in_features"],
meta["out_features"],
meta["td_x"],
meta["td_y"],
meta["in_part"],
meta["L"],
meta["KV"],
meta["V"],
meta["tlut_bits"],
bias=bias,
dtype=_dtype_from_str(meta.get("linear_dtype", dtype)),
)
if cls_name == "Linear":
return nn.Linear(
meta["in_features"],
meta["out_features"],
bias=bias,
dtype=_dtype_from_str(meta.get("linear_dtype", dtype)),
)
raise ValueError(f"Unsupported quantized linear class: {cls_name}")
def _instantiate_incoherent_linear(meta: Dict[str, Any]) -> IncoherentLinear:
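    """Rebuild an IncoherentLinear from its metadata, attaching the inner quantized linear when one was recorded."""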
dtype = _dtype_from_str(meta.get("dtype", "float16"))
bias = bool(meta.get("bias", False))
hadU = meta.get("hadU", meta["in_features"])
hadV = meta.get("hadV", meta["out_features"])
linear_meta = meta.get("linear")
linear_cls = linear_meta["linear_cls"] if linear_meta else None
use_linear = linear_cls is None
layer = IncoherentLinear(
meta["in_features"],
meta["out_features"],
hadU,
hadV,
bias=bias,
dtype=dtype,
use_linear=use_linear,
)
if linear_meta is not None:
layer.linear = _build_quant_linear(
linear_meta,
dtype,
bias=linear_meta.get("bias", False),
)
layer.rot_info = meta.get("rot_info", "all")
layer.scale = float(meta.get("scale", layer.scale))
layer.apply_rot_info()
return layer
def _instantiate_incoherent_mlp(meta: Dict[str, Any], config) -> IncoherentMLP:
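    """Rebuild an IncoherentMLP from its metadata, taking hidden sizes and activation from the model config."""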
dtype = _dtype_from_str(meta.get("dtype", "float16"))
mlp = IncoherentMLP(
config.hidden_size,
config.intermediate_size,
config.hidden_act,
merge_ug=meta.get("merge_ug", False),
dtype=dtype,
)
mlp.scale = float(meta.get("scale", mlp.scale))
projections = meta.get("projections", {})
if mlp.merge_ug:
proj_meta = projections.get("ug_proj")
if proj_meta is not None:
mlp.ug_proj = _build_quant_linear(
proj_meta,
_dtype_from_str(proj_meta.get("linear_dtype", dtype)),
bias=proj_meta.get("bias", False),
)
else:
up_meta = projections.get("up_proj")
gate_meta = projections.get("gate_proj")
if up_meta is not None:
mlp.up_proj = _build_quant_linear(
up_meta,
_dtype_from_str(up_meta.get("linear_dtype", dtype)),
bias=up_meta.get("bias", False),
)
if gate_meta is not None:
mlp.gate_proj = _build_quant_linear(
gate_meta,
_dtype_from_str(gate_meta.get("linear_dtype", dtype)),
bias=gate_meta.get("bias", False),
)
down_meta = projections.get("down_proj")
if down_meta is not None:
mlp.down_proj = _build_quant_linear(
down_meta,
_dtype_from_str(down_meta.get("linear_dtype", dtype)),
bias=down_meta.get("bias", False),
)
return mlp
def _instantiate_incoherent_attention(meta: Dict[str, Any], config) -> IncoherentSdpaAttention:
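    """Rebuild an IncoherentSdpaAttention from its metadata, restoring any merged projection layout."""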
dtype = _dtype_from_str(meta.get("dtype", "float16"))
attn = IncoherentSdpaAttention(
config,
merge_qk=meta.get("merge_qk", False),
merge_kv=meta.get("merge_kv", False),
merge_qv=meta.get("merge_qv", False),
merge_qkv=meta.get("merge_qkv", False),
layer_idx=meta.get("layer_idx"),
dtype=dtype,
)
attn.scale = float(meta.get("scale", attn.scale))
projections = meta.get("projections", {})
for name, proj_meta in projections.items():
setattr(
attn,
name,
_build_quant_linear(
proj_meta,
_dtype_from_str(proj_meta.get("linear_dtype", dtype)),
bias=proj_meta.get("bias", False),
),
)
return attn
def apply_qpal_quantization(model: nn.Module, quant_config: Dict[str, Any]) -> None:
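    """Replace the modules listed in quant_config with freshly built Incoherent* modules, in place."""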
modules = quant_config.get("modules", {})
for path, meta in modules.items():
parent, attr = _resolve_parent(model, path)
module_type = meta.get("module_type", "IncoherentLinear")
if module_type == "IncoherentLinear":
replacement = _instantiate_incoherent_linear(meta)
elif module_type == "IncoherentMLP":
replacement = _instantiate_incoherent_mlp(meta, model.config)
elif module_type == "IncoherentSdpaAttention":
replacement = _instantiate_incoherent_attention(meta, model.config)
else:
raise ValueError(f"Unsupported module_type: {module_type} for {path}")
_assign_child(parent, attr, replacement)
class QPalLlamaForCausalLM(_BaseLlamaForCausalLM):
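    """LlamaForCausalLM whose layers are swapped for QPal quantized modules at construction time.

    Expects the checkpoint config to carry a ``qpal_quant_config`` entry as
    produced by ``export_qpal_quant_config``; without it the model behaves
    like the base class.
    """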
def __init__(self, config):
super().__init__(config)
quant_config = getattr(config, "qpal_quant_config", None)
if quant_config:
apply_qpal_quantization(self, quant_config)
__all__ = [
"QPalLlamaForCausalLM",
"export_qpal_quant_config",
"apply_qpal_quantization",
]
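# Minimal usage sketch (not part of the module API). The checkpoint path is a
# placeholder, and loading via from_pretrained with trust_remote_code assumes
# this file is wired up as the repo's custom modelling code; adjust to your
# repository layout.
#
#   from transformers import AutoConfig
#
#   config = AutoConfig.from_pretrained("path/to/qpal-llama", trust_remote_code=True)  # hypothetical path
#   model = QPalLlamaForCausalLM.from_pretrained(
#       "path/to/qpal-llama",
#       config=config,
#       torch_dtype=torch.float16,
#       trust_remote_code=True,
#   )
#   quant_config = export_qpal_quant_config(model)  # same structure as config.qpal_quant_config
#   apply_qpal_quantization(model, quant_config)    # rebuilds the quantized modules in place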