import torch
import torch.nn as nn
from typing import Any, Dict, Tuple

from transformers.models.llama.modeling_llama import LlamaForCausalLM as _BaseLlamaForCausalLM

from lib.linear.incoherent_linear import (
    IncoherentLinear,
    IncoherentMLP,
    IncoherentSdpaAttention,
)
from lib.linear.vq_linear import VQLinearPackTensorCore, VQLinearPackSIMT
from lib.linear.tcq_linear import QTIPLinearTCQ
from lib.linear.comb_linear import CombLinearTCQ, CombtLinearTCQ

DTYPE_MAP: Dict[str, torch.dtype] = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64,
}
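

# Config dtypes travel as plain strings ("float16", "bfloat16", ...) so the
# exported dict stays JSON-serializable; these two helpers convert both ways.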
def _dtype_from_str(name: Any) -> torch.dtype:
    if isinstance(name, torch.dtype):
        return name
    key = str(name).replace("torch.", "")
    if key not in DTYPE_MAP:
        raise ValueError(f"Unsupported dtype: {name}")
    return DTYPE_MAP[key]


def _dtype_to_str(dtype: torch.dtype) -> str:
    for key, value in DTYPE_MAP.items():
        if dtype == value:
            return key
    return str(dtype).replace("torch.", "")
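

# Dotted-path helpers for walking an nn.Module tree (e.g. "model.layers.0.mlp");
# purely numeric path components index into containers such as nn.ModuleList.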
def _get_child(module: Any, name: str) -> Any:
    if name.isdigit():
        return module[int(name)]
    return getattr(module, name)


def _resolve_parent(root: nn.Module, path: str) -> Tuple[nn.Module, str]:
    parts = path.split(".")
    module = root
    for name in parts[:-1]:
        module = _get_child(module, name)
    return module, parts[-1]


def _assign_child(module: Any, name: str, value: Any) -> None:
    if name.isdigit():
        module[int(name)] = value
    else:
        setattr(module, name, value)
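

# Serializers: each quantized linear is reduced to the constructor arguments
# needed to rebuild an empty shell of the same class. Only metadata is recorded
# here; the packed weights themselves are expected to travel in the state dict.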
def _serialize_quant_linear(linear: nn.Module) -> Dict[str, Any]:
    meta: Dict[str, Any] = {
        "linear_cls": linear.__class__.__name__,
        "in_features": getattr(linear, "in_features", None),
        "out_features": getattr(linear, "out_features", None),
        "bias": getattr(linear, "bias", None) is not None,
    }
    if hasattr(linear, "dtype"):
        meta["linear_dtype"] = _dtype_to_str(linear.dtype)
    if isinstance(linear, (VQLinearPackTensorCore, VQLinearPackSIMT)):
        meta.update({
            "lut_bits": linear.lut_bits,
            "vec_sz": linear.vec_sz,
        })
    elif isinstance(linear, QTIPLinearTCQ):
        meta.update({
            "td_x": linear.td_x,
            "td_y": linear.td_y,
            "L": linear.L,
            "KV": linear.KV,
            "V": linear.V,
            "tlut_bits": linear.tlut_bits,
        })
    elif isinstance(linear, CombLinearTCQ):
        meta.update({
            "td_x": linear.td_x,
            "td_y": linear.td_y,
            "out_part": list(linear.out_part),
            "L": linear.L,
            "KV": list(linear.KV),
            "V": linear.V,
            "tlut_bits": linear.tlut_bits,
        })
    elif isinstance(linear, CombtLinearTCQ):
        meta.update({
            "td_x": linear.td_x,
            "td_y": linear.td_y,
            "in_part": list(linear.in_part),
            "L": linear.L,
            "KV": list(linear.KV),
            "V": linear.V,
            "tlut_bits": linear.tlut_bits,
        })
    return meta


def _serialize_incoherent_linear(layer: IncoherentLinear) -> Dict[str, Any]:
    meta: Dict[str, Any] = {
        "module_type": "IncoherentLinear",
        "in_features": layer.in_features,
        "out_features": layer.out_features,
        "hadU": layer.hadU,
        "hadV": layer.hadV,
        "bias": layer.bias is not None,
        "dtype": _dtype_to_str(layer.dtype),
        "rot_info": layer.rot_info,
        "scale": float(layer.scale),
    }
    if layer.linear is not None:
        meta["linear"] = _serialize_quant_linear(layer.linear)
    return meta
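

# Merged projections (ug_proj here, qkv_proj etc. for attention) are recorded
# under their fused names so the rebuild path restores the same fused layout.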
def _serialize_incoherent_mlp(layer: IncoherentMLP) -> Dict[str, Any]:
    projections: Dict[str, Dict[str, Any]] = {}
    if layer.merge_ug and layer.ug_proj is not None:
        projections["ug_proj"] = _serialize_quant_linear(layer.ug_proj)
    else:
        if layer.up_proj is not None:
            projections["up_proj"] = _serialize_quant_linear(layer.up_proj)
        if layer.gate_proj is not None:
            projections["gate_proj"] = _serialize_quant_linear(layer.gate_proj)
    if layer.down_proj is not None:
        projections["down_proj"] = _serialize_quant_linear(layer.down_proj)
    return {
        "module_type": "IncoherentMLP",
        "dtype": _dtype_to_str(layer.dtype),
        "merge_ug": layer.merge_ug,
        "scale": float(layer.scale),
        "projections": projections,
    }


def _serialize_incoherent_attention(layer: IncoherentSdpaAttention) -> Dict[str, Any]:
    projections: Dict[str, Dict[str, Any]] = {}
    for name in [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "qk_proj",
        "qv_proj",
        "kv_proj",
        "qkv_proj",
    ]:
        proj = getattr(layer, name, None)
        if proj is not None:
            projections[name] = _serialize_quant_linear(proj)
    return {
        "module_type": "IncoherentSdpaAttention",
        "dtype": _dtype_to_str(layer.dtype),
        "scale": float(layer.scale),
        "merge_qk": layer.merge_qk,
        "merge_kv": layer.merge_kv,
        "merge_qv": layer.merge_qv,
        "merge_qkv": layer.merge_qkv,
        "layer_idx": layer.layer_idx,
        "projections": projections,
    }
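

# Walks the model once and records every incoherent module under its dotted
# path. Shape of the result, with an illustrative (hypothetical) entry:
#
#   {"modules": {"model.layers.0.mlp": {"module_type": "IncoherentMLP", ...}}}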
def export_qpal_quant_config(model: nn.Module) -> Dict[str, Any]:
    modules: Dict[str, Any] = {}
    for name, module in model.named_modules():
        if isinstance(module, IncoherentLinear):
            modules[name] = _serialize_incoherent_linear(module)
        elif isinstance(module, IncoherentMLP):
            modules[name] = _serialize_incoherent_mlp(module)
        elif isinstance(module, IncoherentSdpaAttention):
            modules[name] = _serialize_incoherent_attention(module)
    return {"modules": modules}
def _build_quant_linear(meta: Dict[str, Any], dtype: torch.dtype, bias: bool) -> nn.Module:
    cls_name = meta["linear_cls"]
    linear_dtype = _dtype_from_str(meta.get("linear_dtype", dtype))
    if cls_name == "VQLinearPackTensorCore":
        return VQLinearPackTensorCore(
            meta["in_features"],
            meta["out_features"],
            meta["lut_bits"],
            meta.get("vec_sz", 2),
            bias=bias,
            dtype=linear_dtype,
        )
    if cls_name == "VQLinearPackSIMT":
        return VQLinearPackSIMT(
            meta["in_features"],
            meta["out_features"],
            meta["lut_bits"],
            meta.get("vec_sz", 1),
            bias=bias,
            dtype=linear_dtype,
        )
    if cls_name == "QTIPLinearTCQ":
        return QTIPLinearTCQ(
            meta["in_features"],
            meta["out_features"],
            meta["td_x"],
            meta["td_y"],
            meta["L"],
            meta["KV"],
            meta["V"],
            meta["tlut_bits"],
            bias=bias,
            dtype=linear_dtype,
        )
    if cls_name == "CombLinearTCQ":
        return CombLinearTCQ(
            meta["in_features"],
            meta["out_features"],
            meta["td_x"],
            meta["td_y"],
            meta["out_part"],
            meta["L"],
            meta["KV"],
            meta["V"],
            meta["tlut_bits"],
            bias=bias,
            dtype=linear_dtype,
        )
    if cls_name == "CombtLinearTCQ":
        return CombtLinearTCQ(
            meta["in_features"],
            meta["out_features"],
            meta["td_x"],
            meta["td_y"],
            meta["in_part"],
            meta["L"],
            meta["KV"],
            meta["V"],
            meta["tlut_bits"],
            bias=bias,
            dtype=linear_dtype,
        )
    if cls_name == "Linear":
        return nn.Linear(
            meta["in_features"],
            meta["out_features"],
            bias=bias,
            dtype=linear_dtype,
        )
    raise ValueError(f"Unsupported quantized linear class: {cls_name}")
def _instantiate_incoherent_linear(meta: Dict[str, Any]) -> IncoherentLinear:
    dtype = _dtype_from_str(meta.get("dtype", "float16"))
    bias = bool(meta.get("bias", False))
    hadU = meta.get("hadU", meta["in_features"])
    hadV = meta.get("hadV", meta["out_features"])
    linear_meta = meta.get("linear")
    layer = IncoherentLinear(
        meta["in_features"],
        meta["out_features"],
        hadU,
        hadV,
        bias=bias,
        dtype=dtype,
        use_linear=linear_meta is None,
    )
    if linear_meta is not None:
        layer.linear = _build_quant_linear(
            linear_meta,
            dtype,
            bias=linear_meta.get("bias", False),
        )
    layer.rot_info = meta.get("rot_info", "all")
    layer.scale = float(meta.get("scale", layer.scale))
    layer.apply_rot_info()
    return layer
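

# For MLP and attention, the structural geometry (hidden sizes, head layout,
# activation) comes from the model config; only quantization-specific settings
# come from the serialized meta.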
def _instantiate_incoherent_mlp(meta: Dict[str, Any], config) -> IncoherentMLP:
    dtype = _dtype_from_str(meta.get("dtype", "float16"))
    mlp = IncoherentMLP(
        config.hidden_size,
        config.intermediate_size,
        config.hidden_act,
        merge_ug=meta.get("merge_ug", False),
        dtype=dtype,
    )
    mlp.scale = float(meta.get("scale", mlp.scale))
    projections = meta.get("projections", {})
    if mlp.merge_ug:
        proj_meta = projections.get("ug_proj")
        if proj_meta is not None:
            mlp.ug_proj = _build_quant_linear(proj_meta, dtype, bias=proj_meta.get("bias", False))
    else:
        up_meta = projections.get("up_proj")
        if up_meta is not None:
            mlp.up_proj = _build_quant_linear(up_meta, dtype, bias=up_meta.get("bias", False))
        gate_meta = projections.get("gate_proj")
        if gate_meta is not None:
            mlp.gate_proj = _build_quant_linear(gate_meta, dtype, bias=gate_meta.get("bias", False))
    down_meta = projections.get("down_proj")
    if down_meta is not None:
        mlp.down_proj = _build_quant_linear(down_meta, dtype, bias=down_meta.get("bias", False))
    return mlp


def _instantiate_incoherent_attention(meta: Dict[str, Any], config) -> IncoherentSdpaAttention:
    dtype = _dtype_from_str(meta.get("dtype", "float16"))
    attn = IncoherentSdpaAttention(
        config,
        merge_qk=meta.get("merge_qk", False),
        merge_kv=meta.get("merge_kv", False),
        merge_qv=meta.get("merge_qv", False),
        merge_qkv=meta.get("merge_qkv", False),
        layer_idx=meta.get("layer_idx"),
        dtype=dtype,
    )
    attn.scale = float(meta.get("scale", attn.scale))
    for name, proj_meta in meta.get("projections", {}).items():
        setattr(attn, name, _build_quant_linear(proj_meta, dtype, bias=proj_meta.get("bias", False)))
    return attn
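

# In-place surgery: every path in quant_config["modules"] is swapped for a
# freshly built incoherent module. This is meant to run before weights load,
# so the swapped modules expose the parameter shapes the checkpoint expects.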
def apply_qpal_quantization(model: nn.Module, quant_config: Dict[str, Any]) -> None:
    modules = quant_config.get("modules", {})
    for path, meta in modules.items():
        parent, attr = _resolve_parent(model, path)
        module_type = meta.get("module_type", "IncoherentLinear")
        if module_type == "IncoherentLinear":
            replacement = _instantiate_incoherent_linear(meta)
        elif module_type == "IncoherentMLP":
            replacement = _instantiate_incoherent_mlp(meta, model.config)
        elif module_type == "IncoherentSdpaAttention":
            replacement = _instantiate_incoherent_attention(meta, model.config)
        else:
            raise ValueError(f"Unsupported module_type: {module_type} for {path}")
        _assign_child(parent, attr, replacement)


class QPalLlamaForCausalLM(_BaseLlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        quant_config = getattr(config, "qpal_quant_config", None)
        if quant_config:
            apply_qpal_quantization(self, quant_config)
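

# Minimal round-trip sketch (paths and the persistence of the extra config
# attribute are assumptions, not guarantees):
#
#   model.config.qpal_quant_config = export_qpal_quant_config(model)
#   model.save_pretrained("out/qpal-llama")
#   restored = QPalLlamaForCausalLM.from_pretrained("out/qpal-llama")
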
__all__ = [
    "QPalLlamaForCausalLM",
    "export_qpal_quant_config",
    "apply_qpal_quantization",
]