diff --git a/examples/apple/coreml/llama/export_static_llm_coreml.py b/examples/apple/coreml/llama/export_static_llm_coreml.py index a3fd8201414..658717c25da 100644 --- a/examples/apple/coreml/llama/export_static_llm_coreml.py +++ b/examples/apple/coreml/llama/export_static_llm_coreml.py @@ -27,6 +27,8 @@ import torch.nn as nn import torch.utils._pytree as pytree +from typing import Optional + from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition import CoreMLPartitioner from executorch.examples.apple.coreml.llama.utils import ( @@ -89,20 +91,58 @@ def forward(self, *args, **kwargs): return out -def remove_graph_break_(edge_manager): +def remove_graph_break_(edge_manager, method_names=None): from executorch.exir.dialects._ops import ops as exir_ops - for n in edge_manager.exported_program().graph_module.graph.nodes: - if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor: - n.replace_all_uses_with(n.args[0]) - edge_manager.exported_program().graph_module.graph.eliminate_dead_code() - + if method_names is None: + method_names = [None] # Default behavior for single method -def load_model(checkpoint_path: str, params_path: str, max_context_len: int): - """Load the model from checkpoint with static_mha attention type.""" + for method_name in method_names: + if method_name is None: + ep = edge_manager.exported_program() + else: + ep = edge_manager.exported_program(method_name) + + for n in ep.graph_module.graph.nodes: + if n.target == exir_ops.edge.executorch_utils.graph_break.Tensor: + n.replace_all_uses_with(n.args[0]) + ep.graph_module.graph.eliminate_dead_code() + + +def load_model( + checkpoint_path: str, + params_path: str, + max_context_len: int, + adapter_checkpoint_path: Optional[str] = None, + adapter_config_path: Optional[str] = None, +): + """Load the model from checkpoint with static_mha attention type. 
+ + Args: + checkpoint_path: Path to model checkpoint (.pth) + params_path: Path to params.json + max_context_len: Maximum context length + adapter_checkpoint_path: Optional path to LoRA adapter weights (adapter_model.safetensors) + adapter_config_path: Optional path to adapter config (adapter_config.json) + """ with open(params_path, "r") as f: params = json.loads(f.read()) + assert (adapter_config_path is None and adapter_checkpoint_path is None) or (adapter_config_path is not None and adapter_checkpoint_path is not None) + + # Load adapter config if provided + adapter_config = None + if adapter_config_path is not None: + with open(adapter_config_path, "r") as f: + adapter_config = json.loads(f.read()) + print(f"Loaded adapter config: rank={adapter_config.get('r')}, alpha={adapter_config.get('lora_alpha')}") + print(f"Target modules: {adapter_config.get('target_modules')}") + + # Merge adapter config into params + params["r"] = adapter_config.get("r") + params["lora_alpha"] = adapter_config.get("lora_alpha") + params["target_modules"] = adapter_config.get("target_modules") + # TODO: to support lookahead decoding, the static model outputs # full logits, but if we are not using lookahead decoding, we can have a # more efficient model by setting generate_full_logits=False and supplying the last @@ -124,8 +164,23 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int): if "model" in checkpoint: checkpoint = checkpoint["model"] + # Load and merge adapter weights if provided + if adapter_checkpoint_path is not None: + from safetensors.torch import load_file + from executorch.examples.models.llama.convert_weights import unsloth_to_meta + + adapter_weights = load_file(adapter_checkpoint_path) + # Convert adapter weight keys to Meta format + adapter_weights = unsloth_to_meta(adapter_weights) + print(f"Loaded {len(adapter_weights)} adapter weights") + + # Merge adapter weights into checkpoint + checkpoint.update(adapter_weights) + # Rename attention weight keys for static attention + # This handles both base weights and LoRA weights for i in range(len(model.layers)): + # Base weights if f"layers.{i}.attention.wq.weight" in checkpoint: checkpoint[f"layers.{i}.attention.wqs.0.weight"] = checkpoint.pop( f"layers.{i}.attention.wq.weight" @@ -139,6 +194,21 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int): f"layers.{i}.attention.wv.weight" ) + # LoRA weights (lora_a and lora_b) + for lora_suffix in ["lora_a.weight", "lora_b.weight"]: + if f"layers.{i}.attention.wq.{lora_suffix}" in checkpoint: + checkpoint[f"layers.{i}.attention.wqs.0.{lora_suffix}"] = checkpoint.pop( + f"layers.{i}.attention.wq.{lora_suffix}" + ) + if f"layers.{i}.attention.wk.{lora_suffix}" in checkpoint: + checkpoint[f"layers.{i}.attention.wks.0.{lora_suffix}"] = checkpoint.pop( + f"layers.{i}.attention.wk.{lora_suffix}" + ) + if f"layers.{i}.attention.wv.{lora_suffix}" in checkpoint: + checkpoint[f"layers.{i}.attention.wvs.0.{lora_suffix}"] = checkpoint.pop( + f"layers.{i}.attention.wv.{lora_suffix}" + ) + missing, unexpected = model.load_state_dict( checkpoint, strict=False, @@ -152,6 +222,124 @@ def load_model(checkpoint_path: str, params_path: str, max_context_len: int): return model, args +def prepare_model( + model: nn.Module, + model_args: ModelArgs, + float_dtype: torch.dtype, + target_split_size: int, + max_splits: int, + embedding_quantize: str, + linear_quantize: str, + no_graph_breaks: bool, +): + """Apply dtype, splitting, quantization, and graph breaks to a model. 
+ + Args: + model: The model to prepare + model_args: Model arguments + float_dtype: Target dtype (torch.float16 or torch.float32) + target_split_size: Target size for linear layer splitting + max_splits: Maximum number of splits for linear layers + embedding_quantize: Embedding quantization string (e.g., "8,0") + linear_quantize: Linear quantization type ("b4w" or "c4w") + no_graph_breaks: If True, skip adding graph breaks + + Returns: + The prepared model + """ + # Set dtype + model = model.to(float_dtype).eval() + + # Apply linear splitting (before quantization) + if target_split_size is not None: + replace_linear_with_split_linear( + model, + out_target_split_size=target_split_size, + out_max_splits=max_splits, + in_target_split_size=1, + in_max_splits=1, + ) + + def make_linear_filter_fn(group_size=0): + """Create a filter function for linear quantization. + + Args: + group_size: Group size for quantization. 0 means per-axis (no constraint). + """ + def filter_fn(m, fqn): + # Check if it's a regular nn.Linear + is_linear = isinstance(m, nn.Linear) + + # Check if it's a LoRALinear (which has a base weight parameter to quantize) + is_lora_linear = False + try: + from executorch.examples.models.llama.lora import LoRALinear + is_lora_linear = isinstance(m, LoRALinear) + except ImportError: + pass + + if not (is_linear or is_lora_linear): + return False + + # For per-axis (group_size=0), no shape constraint + if group_size == 0: + return True + + # Check if the weight shape is compatible with group size + return m.weight.shape[1] % group_size == 0 + + return filter_fn + + # Apply embedding quantization + if embedding_quantize: + bitwidth, group_size = embedding_quantize.split(",") + bitwidth = int(bitwidth) + group_size = int(group_size) + assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + + if group_size == 0: + granularity = PerAxis(0) + else: + granularity = PerGroup(group_size) + weight_dtype = getattr(torch, f"int{bitwidth}") + + quantize_( + model, + IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), + lambda m, fqn: isinstance(m, torch.nn.Embedding), + ) + + # Apply linear quantization + if linear_quantize == "b4w": + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerGroup(32), + ), + filter_fn=make_linear_filter_fn(group_size=32), + ) + elif linear_quantize == "c4w": + quantize_( + model, + IntxWeightOnlyConfig( + weight_dtype=torch.int4, + granularity=PerAxis(0), + ), + filter_fn=make_linear_filter_fn(group_size=0), + ) + + # Add graph breaks between transformer blocks + if not no_graph_breaks: + n_layers = len(model.layers) + model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) + model.layers[n_layers - 1] = BlockWithGraphBreak( + model.layers[n_layers - 1], break_before=False + ) + + return model + + def _get_metadata(model_args, example_inputs, input_len, cache_len, float_dtype): """ Generate metadata methods for the C++ runner. 
@@ -263,6 +451,25 @@ def main(): help="Output filename for the .pte model", ) + # LoRA adapter options + parser.add_argument( + "--adapter_checkpoint", + type=str, + default=None, + help="Path to LoRA adapter weights (adapter_model.safetensors)", + ) + parser.add_argument( + "--adapter_config", + type=str, + default=None, + help="Path to adapter config (adapter_config.json)", + ) + parser.add_argument( + "--multimethod", + action="store_true", + help="Export both base and LoRA models as separate methods ('base' and 'lora') in one PTE file", + ) + # Model configuration parser.add_argument( "--max_context_len", @@ -340,110 +547,159 @@ def main(): print(f"\tMax splits: {args.max_splits}") # Load model - print(f"\nLoading model from {args.checkpoint}...") - model, model_args = load_model( - args.checkpoint, - args.params, - args.max_context_len, - ) - print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") - - # Set dtype float_dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype] - model = model.to(float_dtype).eval() - # Apply linear splitting (before quantization) - if args.target_split_size is not None: - print(f"\nSplitting linear layers with target size {args.target_split_size}...") - replace_linear_with_split_linear( - model, - out_target_split_size=args.target_split_size, - out_max_splits=args.max_splits, - in_target_split_size=1, - in_max_splits=1, - ) + if args.multimethod: + # Multimethod export: create both base and LoRA models + if not args.adapter_checkpoint or not args.adapter_config: + raise ValueError("--multimethod requires --adapter_checkpoint and --adapter_config") - # Apply embedding quantization - if args.embedding_quantize: - bitwidth, group_size = args.embedding_quantize.split(",") - bitwidth = int(bitwidth) - group_size = int(group_size) - assert bitwidth in [4, 8], "CoreML only supports 4-bit and 8-bit quantization" + print(f"\n[Multimethod Export] Loading base model from {args.checkpoint}...") + base_model, model_args = load_model( + args.checkpoint, + args.params, + args.max_context_len, + ) + print(f"Base model loaded: {model_args.n_layers} layers, {model_args.dim} dim") + + print(f"\n[Multimethod Export] Loading LoRA model from {args.checkpoint}...") + print(f" with adapter from {args.adapter_checkpoint}") + lora_model, _ = load_model( + args.checkpoint, + args.params, + args.max_context_len, + args.adapter_checkpoint, + args.adapter_config, + ) + print("LoRA model loaded") + + # Prepare both models + print("\n[Multimethod Export] Preparing base model...") + base_model = prepare_model( + base_model, + model_args, + float_dtype, + args.target_split_size, + args.max_splits, + args.embedding_quantize, + args.linear_quantize, + args.no_graph_breaks, + ) - print(f"\nQuantizing embeddings: {bitwidth}-bit, group_size={group_size}...") - if group_size == 0: - granularity = PerAxis(0) - else: - granularity = PerGroup(group_size) - weight_dtype = getattr(torch, f"int{bitwidth}") + print("\n[Multimethod Export] Preparing LoRA model...") + lora_model = prepare_model( + lora_model, + model_args, + float_dtype, + args.target_split_size, + args.max_splits, + args.embedding_quantize, + args.linear_quantize, + args.no_graph_breaks, + ) - quantize_( - model, - IntxWeightOnlyConfig(weight_dtype=weight_dtype, granularity=granularity), - lambda m, fqn: isinstance(m, torch.nn.Embedding), + # Create IO manager and example inputs (shared for both models) + mgr = StaticAttentionIOManager( + model_args, + input_len=args.input_len, + cache_lens=cache_len, + 
batch_size=1, + dtype=float_dtype, + style="smart_mask", + mask_val=float("-inf"), + ) + example_inputs = ( + torch.zeros(1, args.input_len, dtype=torch.int32), + { + "masks": mgr.masks, + "freqs_cos_override": mgr.freqs_cos[: args.input_len], + "freqs_sin_override": mgr.freqs_sin[: args.input_len], + "in_cache_state": (mgr.k_caches, mgr.v_caches), + }, ) - # Apply linear quantization - if args.linear_quantize == "b4w": - print("\nQuantizing linear layers: 4-bit blockwise (group_size=32)...") - quantize_( - model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerGroup(32), - ), + # Test eager execution for both models + print("\n[Multimethod Export] Testing eager execution...") + with torch.no_grad(): + base_model(*example_inputs) + lora_model(*example_inputs) + print("Eager execution successful for both models!") + + # Export both models + print("\n[Multimethod Export] Exporting base model...") + base_ep = torch.export.export(base_model, example_inputs, strict=False) + print("Base model export successful!") + + print("\n[Multimethod Export] Exporting LoRA model...") + lora_ep = torch.export.export(lora_model, example_inputs, strict=False) + print("LoRA model export successful!") + + # Use dictionary of exported programs for multimethod + exported_programs = { + "base": base_ep, + "lora": lora_ep, + } + else: + # Single method export (original behavior) + print(f"\nLoading model from {args.checkpoint}...") + if args.adapter_checkpoint: + print(f"Loading LoRA adapter from {args.adapter_checkpoint}...") + model, model_args = load_model( + args.checkpoint, + args.params, + args.max_context_len, + args.adapter_checkpoint, + args.adapter_config, ) - elif args.linear_quantize == "c4w": - print("\nQuantizing linear layers: 4-bit channelwise...") - quantize_( + print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim") + + # Prepare model + print("\nPreparing model...") + model = prepare_model( model, - IntxWeightOnlyConfig( - weight_dtype=torch.int4, - granularity=PerAxis(0), - ), + model_args, + float_dtype, + args.target_split_size, + args.max_splits, + args.embedding_quantize, + args.linear_quantize, + args.no_graph_breaks, ) - # Add graph breaks between transformer blocks - # Keeping model pieces smaller helps with ANE performance - if not args.no_graph_breaks: - print("\nAdding graph breaks between before/after the transformer blocks...") - n_layers = len(model.layers) - model.layers[0] = BlockWithGraphBreak(model.layers[0], break_before=True) - model.layers[n_layers - 1] = BlockWithGraphBreak( - model.layers[n_layers - 1], break_before=False + # Create IO manager and example inputs + mgr = StaticAttentionIOManager( + model_args, + input_len=args.input_len, + cache_lens=cache_len, + batch_size=1, + dtype=float_dtype, + style="smart_mask", + mask_val=float("-inf"), + ) + example_inputs = ( + torch.zeros(1, args.input_len, dtype=torch.int32), + { + "masks": mgr.masks, + "freqs_cos_override": mgr.freqs_cos[: args.input_len], + "freqs_sin_override": mgr.freqs_sin[: args.input_len], + "in_cache_state": (mgr.k_caches, mgr.v_caches), + }, ) - # Create IO manager and example inputs - mgr = StaticAttentionIOManager( - model_args, - input_len=args.input_len, - cache_lens=cache_len, - batch_size=1, - dtype=float_dtype, - style="smart_mask", # Use smart_mask to match C++ StaticTransformerRunner - mask_val=float("-inf"), - ) - example_inputs = ( - torch.zeros(1, args.input_len, dtype=torch.int32), - { - "masks": mgr.masks, - "freqs_cos_override": mgr.freqs_cos[: 
args.input_len], - "freqs_sin_override": mgr.freqs_sin[: args.input_len], - "in_cache_state": (mgr.k_caches, mgr.v_caches), - }, - ) + # Test eager execution + print("\nTesting eager execution...") + with torch.no_grad(): + model(*example_inputs) + print("Eager execution successful!") - # Test eager execution - print("\nTesting eager execution...") - with torch.no_grad(): - model(*example_inputs) - print("Eager execution successful!") + # Export the model + print("\nExporting model...") + ep = torch.export.export(model, example_inputs, strict=False) + print("Export successful!") + print(ep) - # Export the model - print("\nExporting model...") - ep = torch.export.export(model, example_inputs) - print("Export successful!") - print(ep) + # Use single exported program + exported_programs = ep # Generate metadata for C++ runner print("\nGenerating metadata for C++ runner...") @@ -472,18 +728,27 @@ def main(): print("\nLowering to edge...") edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) edge_manager = to_edge_transform_and_lower( - ep, + exported_programs, partitioner=[partitioner], constant_methods=constant_methods, compile_config=edge_compile_config, ) - print("\nDelegated program:") - print(format_delegated_graph(edge_manager.exported_program().graph_module)) + if args.multimethod: + print("\nDelegated programs:") + for method_name in ["base", "lora"]: + print(f"\n--- {method_name} ---") + print(format_delegated_graph(edge_manager.exported_program(method_name).graph_module)) + else: + print("\nDelegated program:") + print(format_delegated_graph(edge_manager.exported_program().graph_module)) # Convert to ExecuTorch print("\nConverting to ExecuTorch...") - remove_graph_break_(edge_manager) + if args.multimethod: + remove_graph_break_(edge_manager, method_names=["base", "lora"]) + else: + remove_graph_break_(edge_manager) executorch_program = edge_manager.to_executorch( ExecutorchBackendConfig( extract_delegate_segments=True, @@ -498,6 +763,8 @@ def main(): # Save the program filename = save_pte_program(executorch_program, args.output) print(f"\nSaved ExecuTorch program to {filename}") + if args.multimethod: + print("Methods available: 'base', 'lora'") if __name__ == "__main__": diff --git a/examples/apple/coreml/llama/run_static_llm.py b/examples/apple/coreml/llama/run_static_llm.py index 2cd526aec42..17fe3ef7fa1 100644 --- a/examples/apple/coreml/llama/run_static_llm.py +++ b/examples/apple/coreml/llama/run_static_llm.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. """ -Run script for static attention Llama models exported with coreml_static_llama.py. +Run script for static attention LLM models exported with export_static_llm_coreml.py. 
Usage: python run_static_llm.py \ @@ -21,7 +21,6 @@ import time from typing import Any, Dict, List, Tuple -import sentencepiece as spm import torch import torch.utils._pytree as pytree @@ -29,50 +28,14 @@ from executorch.examples.models.llama.runner.generation import next_token from executorch.examples.models.llama.static_attention import StaticAttentionIOManager from executorch.runtime import Runtime +from pytorch_tokenizers import get_tokenizer -class Tokenizer: - """Wrapper to support both SentencePiece and Tiktoken tokenizers.""" - - def __init__(self, model_path: str): - try: - print("Trying to load sentencepiece") - sp = spm.SentencePieceProcessor() - sp.load(model_path) - self.tokenizer = sp - self._is_sentencepiece = True - except Exception: - print("Trying to load tiktoken") - from executorch.examples.models.llama.tokenizer import tiktoken - - self.tokenizer = tiktoken.Tokenizer(model_path) - self._is_sentencepiece = False - - def encode(self, text: str, bos: bool = True, eos: bool = False) -> List[int]: - if self._is_sentencepiece: - bos_string = "" if bos else "" - eos_string = "" if eos else "" - return self.tokenizer.encode(f"{bos_string}{text}{eos_string}") - return self.tokenizer.encode(text, bos=bos, eos=eos) - - def decode(self, tokens: List[int]) -> str: - if self._is_sentencepiece: - return self.tokenizer.decode(tokens) - return self.tokenizer.decode(tokens) - - def decode_token(self, token: int) -> str: - if self._is_sentencepiece: - return self.tokenizer.decode([token]) - try: - return self.tokenizer.decode_token(token) - except UnicodeDecodeError: - return f"<{token}>" - - @property - def stop_tokens(self) -> List[int]: - if self._is_sentencepiece: - return [self.tokenizer.eos_id()] - return self.tokenizer.stop_tokens +def get_stop_tokens(tokenizer) -> List[int]: + """Get stop tokens from tokenizer, falling back to eos_id if not available.""" + if hasattr(tokenizer, "stop_tokens"): + return tokenizer.stop_tokens + return [tokenizer.eos_id] def create_pte_wrapper( @@ -131,6 +94,12 @@ def main(): required=True, help="Path to exported .pte model", ) + parser.add_argument( + "--method", + type=str, + default="forward", + help="Method name to run (default: 'forward', use 'base' or 'lora' for multimethod models)", + ) parser.add_argument( "-p", "--params", @@ -143,6 +112,12 @@ def main(): required=True, help="Path to tokenizer model", ) + parser.add_argument( + "--tokenizer_config", + type=str, + default=None, + help="Path to tokenizer config (required for HuggingFace tokenizers)", + ) parser.add_argument( "--prompt", type=str, @@ -206,7 +181,8 @@ def main(): args = parser.parse_args() # Load tokenizer - tokenizer = Tokenizer(args.tokenizer) + tokenizer = get_tokenizer(args.tokenizer, args.tokenizer_config) + stop_tokens = get_stop_tokens(tokenizer) # Load model params with open(args.params, "r") as f: @@ -238,7 +214,8 @@ def main(): print(f"Loading model from {args.model}...") runtime = Runtime.get() program = runtime.load_program(args.model) - method = program.load_method("forward") + print(f"Loading method '{args.method}'...") + method = program.load_method(args.method) metadata = method.metadata print( @@ -291,7 +268,7 @@ def main(): ngram_size=args.ngram_size, window_size=args.window_size, n_verifications=args.n_verifications, - stop_tokens=tokenizer.stop_tokens, + stop_tokens=stop_tokens, ) else: # Use standard autoregressive decoding @@ -299,12 +276,12 @@ def main(): model_fn, first_token, n=args.max_new_tokens - 1, # -1 because first_token counts - 
stop_tokens=tokenizer.stop_tokens, + stop_tokens=stop_tokens, ) # Print generated tokens (skip first as it's the init_token we already printed) for token in generated_tokens[1:]: - if token in tokenizer.stop_tokens: + if token in stop_tokens: break print(tokenizer.decode_token(token), end="", flush=True) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index a0e9eb70498..33b15cd181d 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -117,7 +117,7 @@ class ModelArgs: lora_args: Optional[dict] = None # LoRA arguments to set up a LoRA inference model. - # These arguments come directly from a torchtune adapter_config.json file. + # These arguments come directly from a torchtune/unsloth adapter_config.json file. r: Optional[int] = None # Rank. lora_alpha: Optional[int] = None # Alpha. # Modules that we can apply lora adapters to. diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py index 9eef4413a63..1c355c98a77 100644 --- a/examples/models/llama/static_attention.py +++ b/examples/models/llama/static_attention.py @@ -13,6 +13,7 @@ ForwardOptions, register_attention, ) +from executorch.examples.models.llama.lora import LoRALinear from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope @@ -784,22 +785,43 @@ def __init__( # Possibly disable in future, depending on bug fixes in Core ML runtime self.decompose_sdpa_in_mha: bool = kwargs.get("decompose_sdpa_in_mha", False) + # LoRA configuration + self.target_modules = config.target_modules + self.lora_rank = config.r + self.lora_alpha = config.lora_alpha + if self.target_modules: + assert self.lora_rank is not None and self.lora_alpha is not None + + def _make_linear(in_dim: int, out_dim: int, bias: bool, lora_target: str) -> nn.Module: + """Create a linear layer with optional LoRA support.""" + if self.target_modules is not None and lora_target in self.target_modules: + # assert self.lora_rank is not None and self.lora_alpha is not None + return LoRALinear( + in_dim=in_dim, + out_dim=out_dim, + rank=self.lora_rank, + alpha=self.lora_alpha, + dropout=0.0, + use_bias=bias, + ) + return nn.Linear(in_dim, out_dim, bias=bias) + if self.split_mha: self.wqs = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) + _make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "q_proj") for _ in range(self.n_heads) ] ) self.wks = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) + _make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "k_proj") for _ in range(self.n_kv_heads) ] ) self.wvs = nn.ModuleList( [ - nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias) + _make_linear(self.dim, self.head_dim, self.attention_qkv_bias, "v_proj") for _ in range(self.n_kv_heads) ] ) @@ -813,28 +835,31 @@ def __init__( else: self.wqs = nn.ModuleList( [ - nn.Linear( + _make_linear( self.dim, self.head_dim * self.n_heads, - bias=self.attention_qkv_bias, + self.attention_qkv_bias, + "q_proj", ) ] ) self.wks = nn.ModuleList( [ - nn.Linear( + _make_linear( self.dim, self.head_dim * self.n_kv_heads, - bias=self.attention_qkv_bias, + self.attention_qkv_bias, + "k_proj", ) ] ) self.wvs = nn.ModuleList( [ - nn.Linear( + _make_linear( self.dim, self.head_dim * self.n_kv_heads, - bias=self.attention_qkv_bias, + self.attention_qkv_bias, + "v_proj", ) ] ) @@ -842,7 +867,7 @@ def __init__( self.k_caches = 
nn.ModuleList([StaticKCache(layer_id, 0)]) self.v_caches = nn.ModuleList([StaticVCache(layer_id, 0)]) - self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + self.wo = _make_linear(self.n_heads * self.head_dim, self.dim, False, "o_proj") self.rope = _Rope(rope.params) self.layer_id = layer_id
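
For reference, a minimal sketch (not part of the patch) of how a multimethod PTE produced by the new `--multimethod` flag can be inspected from Python. It uses only the `Runtime` calls that the updated `run_static_llm.py` already relies on; the output filename is a hypothetical assumption, and the method names come from this diff.

```python
# Illustrative sketch only: assumes a model was exported with something like
#   python export_static_llm_coreml.py --multimethod \
#       --adapter_checkpoint adapter_model.safetensors \
#       --adapter_config adapter_config.json \
#       --output llama_multimethod.pte
# "llama_multimethod.pte" is a placeholder filename.
from executorch.runtime import Runtime

runtime = Runtime.get()
program = runtime.load_program("llama_multimethod.pte")

# The multimethod export registers the base and LoRA variants under separate
# method names, selected at runtime via run_static_llm.py's --method flag.
for name in ("base", "lora"):
    method = program.load_method(name)
    print(f"method '{name}': {method.metadata}")
```

A single-method export (no `--multimethod`) should still be runnable with the run script's default `--method forward`, as the new `--method` argument's help text indicates.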