diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index bd4a5fb071e..a81ff080289 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -503,6 +503,42 @@ def annotate_scalar_tensor(node: Node, quantization_config: QuantizationConfig)
 @register_annotator([torch.ops.aten.tanh.default])
 def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
+    qconfig_16a = get_16a16w_qnn_ptq_config()
+    qmax, qmin = (
+        qconfig_16a.output_activation.quant_max,
+        qconfig_16a.output_activation.quant_min,
+    )
+    if (
+        quantization_config.output_activation.quant_max == qmax
+        and quantization_config.output_activation.quant_min == qmin
+        and _is_float_tensor(node)
+    ):
+        scale = 1 / 32768.0
+        zero_point = 32768
+        if isinstance(
+            quantization_config.output_activation.observer_or_fake_quant_ctr,
+            torch.ao.quantization.fake_quantize.FakeQuantizeBase,
+        ):
+            observer_ctr = FixedQParamsFakeQuantize
+        else:
+            observer_ctr = FixedQParamsObserver
+        observer = observer_ctr.with_args(
+            scale=scale,
+            zero_point=zero_point,
+            dtype=quantization_config.output_activation.dtype,
+            qscheme=torch.per_tensor_affine,
+            quant_max=quantization_config.output_activation.quant_max,
+            quant_min=quantization_config.output_activation.quant_min,
+        )
+
+        annotate_output_qspec = QuantizationSpec(
+            dtype=quantization_config.output_activation.dtype,
+            quant_max=quantization_config.output_activation.quant_max,
+            quant_min=quantization_config.output_activation.quant_min,
+            observer_or_fake_quant_ctr=observer,
+            qscheme=torch.per_tensor_affine,
+        )
+        node.meta["quantization_annotation"].output_qspec = annotate_output_qspec


 @register_annotator([torch.ops.aten.full_like.default, torch.ops.aten.full.default])
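
The new branch above pins tanh to fixed quantization parameters whenever the 16-bit activation config is detected: tanh is bounded in (-1, 1), so scale = 1/32768 with zero_point = 32768 places that range exactly on an unsigned 16-bit grid instead of relying on observed statistics. A minimal sketch in plain PyTorch (the uint16 range and the round-trip check are illustrative, not code from this PR) of why those two numbers suffice:

```python
# Minimal sketch (not part of the PR): check that the fixed tanh qparams above
# cover the full tanh output range on an unsigned 16-bit grid.
import torch

scale, zero_point = 1 / 32768.0, 32768
quant_min, quant_max = 0, 65535  # assumed uint16 activation range

x = torch.linspace(-4.0, 4.0, steps=1001)
y = torch.tanh(x)  # bounded in (-1, 1)

# Affine quantize / dequantize with the fixed parameters.
q = torch.clamp(torch.round(y / scale) + zero_point, quant_min, quant_max)
y_hat = (q - zero_point) * scale

# -1.0 maps to 0, 0.0 maps to 32768, +1.0 would map to 65536 and gets clamped,
# so every reachable tanh value round-trips within half a quantization step.
assert (y - y_hat).abs().max() <= scale / 2
```

Fixed parameters also keep the tanh output scale identical across calibration runs, which is presumably why this change accompanies the tanh-based soft-capping added later in this patch.
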
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 0843c6f1133..f06cc727f98 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -6071,6 +6071,9 @@ def setUp(self):
             "gemma-2b": TestExampleLLMScript.LlmSpecs(
                 SM8650=32, SM8750=36, ppl=35, pte_size=2_700_000_000
             ),  # 2.7 GB
+            "gemma2-2b": TestExampleLLMScript.LlmSpecs(
+                SM8650=32, SM8750=36, ppl=14, pte_size=2_860_000_000
+            ),  # 2.86 GB
             "gemma3-1b": TestExampleLLMScript.LlmSpecs(
                 SM8650=70, SM8750=100, ppl=23, pte_size=1_200_000_000
             ),  # 1.2 GB
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index a0e9eb70498..83e3df1e2f8 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -134,6 +134,9 @@ class ModelArgs:
     model_architecture: Optional[str] = (
         None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
     )
+    # gemma2 attn and output soft capping
+    final_logit_softcapping: Optional[float] = None
+    attn_logit_softcapping: Optional[float] = None

     def __post_init__(self):
         if self.n_kv_heads is None:
diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md
index 33868eda6d1..80748642db2 100644
--- a/examples/qualcomm/oss_scripts/llama/README.md
+++ b/examples/qualcomm/oss_scripts/llama/README.md
@@ -11,6 +11,7 @@ This file provides you the instructions to run LLM Decoder model with different
 1. LLAMA3.2 3B
 1. Codegen2 1B
 1. Gemma 2B
+1. Gemma2 2B
 1. Gemma3 1B
 1. GLM 1.5B
 1. Granite3.3 2B
@@ -136,6 +137,11 @@ Default example using hybrid mode
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
 ```

+#### Gemma2 2B
+Default example using hybrid mode
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --decoder_model gemma2-2b --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1
+```
 #### Gemma3 1B
 Default example using hybrid mode
diff --git a/examples/qualcomm/oss_scripts/llama/__init__.py b/examples/qualcomm/oss_scripts/llama/__init__.py
index 6144883d036..4a8a017758b 100644
--- a/examples/qualcomm/oss_scripts/llama/__init__.py
+++ b/examples/qualcomm/oss_scripts/llama/__init__.py
@@ -15,6 +15,7 @@
     convert_weights as convert_codegen_weights,
 )
 from executorch.examples.models.gemma import convert_weights as convert_gemma_weights
+from executorch.examples.models.gemma2 import convert_weights as convert_gemma2_weights
 from executorch.examples.models.gemma3 import convert_weights as convert_gemma3_weights
 from executorch.examples.models.glm import convert_weights as convert_glm_weights

@@ -59,6 +60,7 @@
 from executorch.examples.qualcomm.oss_scripts.llama.static_llm_quant_recipe import (
     CodegenQuantRecipe,
+    Gemma2QuantRecipe,
     Gemma3QuantRecipe,
     Gemma_2BQuantRecipe,
     GLM_1_5B_InstructQuantRecipe,
@@ -88,6 +90,7 @@
     "gemma3-1b": MultiScopeAwareLlamaModel,
     "smolvlm_500m_instruct": LlamaModelWithoutEmbedding,
     "internvl3_1b": LlamaModelWithoutEmbedding,
+    "gemma2-2b": MultiScopeAwareLlamaModel,
 }


@@ -303,6 +306,26 @@ class Gemma_2B(LLMModelConfig):
     quant_recipe = Gemma_2BQuantRecipe


+@register_llm_model("gemma2-2b")
+@dataclass(init=False, frozen=True)
+class Gemma2(LLMModelConfig):
+    repo_id: str = "google/gemma-2-2b-it"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/gemma2/config/2b_config.json"
+    )
+    convert_weights = convert_gemma2_weights
+    transform_weight = False
+    instruct_model = True
+
+    num_sharding = 4
+    masked_softmax = True
+    seq_mse_candidates = 0
+    r1 = False
+    r2 = False
+    r3 = False
+    quant_recipe = Gemma2QuantRecipe
+
+
 @register_llm_model("gemma3-1b")
 @dataclass(init=False, frozen=True)
 class Gemma3(LLMModelConfig):
diff --git a/examples/qualcomm/oss_scripts/llama/decoder_constants.py b/examples/qualcomm/oss_scripts/llama/decoder_constants.py
index 3baa4b94ed6..4282a73d6bb 100644
--- a/examples/qualcomm/oss_scripts/llama/decoder_constants.py
+++ b/examples/qualcomm/oss_scripts/llama/decoder_constants.py
@@ -38,6 +38,7 @@
     "llama3_2-3b_instruct": "llama3",
     "codegen2_1b": "codegen",
     "gemma-2b": "gemma",
+    "gemma2-2b": "gemma2",
     "gemma3-1b": "gemma3",
     "granite_3_3-2b_instruct": "granite",
     "phi_4_mini": "phi_4_mini",
diff --git a/examples/qualcomm/oss_scripts/llama/model/layernorm.py b/examples/qualcomm/oss_scripts/llama/model/layernorm.py
index 7db14bdfd01..5de4a13ea9f 100644
--- a/examples/qualcomm/oss_scripts/llama/model/layernorm.py
+++ b/examples/qualcomm/oss_scripts/llama/model/layernorm.py
@@ -42,6 +42,7 @@ def __init__(self, hidden_size: int, eps=1e-5):
         super().__init__(hidden_size, eps=eps)


+@register_norm("gemma2")
 @register_norm("gemma3")
 @register_norm("rmsnorm")
 class RMSNorm(torch.nn.RMSNorm, Norm):
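
Mapping `gemma2` onto the stock `RMSNorm` above relies on a detail of the Gemma family: the Hugging Face implementation learns the norm weight as an offset from one, computing y = rmsnorm(x) * (1 + w), whereas `torch.nn.RMSNorm` multiplies by the weight directly. The wrappers.py hunk later in this diff special-cases the `norm` keys of gemma checkpoints, presumably to fold that offset into the weights at load time. A small sketch of the equivalence, in plain PyTorch; the `+1` fold here is my assumption, not code from the PR:

```python
# Sketch: Gemma-style RMSNorm (offset weight) vs. torch.nn.RMSNorm with the
# offset folded into the checkpoint weight ahead of time.
import torch

hidden, eps = 8, 1e-6
x = torch.randn(2, hidden)
w = torch.randn(hidden) * 0.02  # Gemma checkpoints store norm weights near zero

# Reference Gemma2 behaviour: normalize, then scale by (1 + w).
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
gemma_out = rms * (1.0 + w)

# Stock RMSNorm with the offset folded in.
norm = torch.nn.RMSNorm(hidden, eps=eps)
with torch.no_grad():
    norm.weight.copy_(w + 1.0)

torch.testing.assert_close(norm(x), gemma_out)
```
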
diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index a9c81e635b5..8b539d206c0 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -102,6 +102,9 @@ def __init__(self, layer_idx: int, config: ModelArgs, output_new_cache_only=Fals
             else 1.0 / config.attention_multiplier
         )

+        # gemma2 uses soft-capping on attention and final logits
+        self.attn_logit_softcapping = config.attn_logit_softcapping
+
         if getattr(config, "enable_r3", False):
             self.register_buffer(
                 "r3_weight",
@@ -276,6 +279,11 @@ def forward(
         vh = repeat_kv(vh, self.num_key_value_groups)

         attn = q @ kh
+        # gemma2: soft-cap the attention logits
+        if self.attn_logit_softcapping is not None:
+            attn = attn / self.attn_logit_softcapping
+            attn = torch.tanh(attn)
+            attn = attn * self.attn_logit_softcapping
         attn = attn / self.scale
         if self.enable_masked_softmax:
             attn_min = torch.amin(attn, dim=-1, keepdim=True)
@@ -742,6 +750,8 @@ def __init__(
             self.sliding_window = kwargs["sliding_window"]
             # Get local freq base for sliding attention
             rope_freq_base = kwargs["rope_local_base_freq"]
+            # final_logit_softcapping is optional; only some models (e.g., gemma2) set it
+            self.final_logit_softcapping = kwargs.get("final_logit_softcapping")

             local_freqs_cos, local_freqs_sin = hf_precompute_freqs_cis(
                 config.head_dim,
@@ -817,6 +827,11 @@ def forward(
             hidden_states = self.norm(hidden_states)

         logits = self.output(hidden_states)
+        if self.final_logit_softcapping:
+            logits = logits / self.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.final_logit_softcapping
+
         if self.output_cache:
             return logits, output_k_cache, output_v_cache
         return logits
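
Soft-capping, as wired in above, squashes logits smoothly into a bounded interval: softcap(x, c) = c * tanh(x / c) is close to the identity for |x| well below c and saturates at plus or minus c, which keeps extreme attention or output logits from dominating. For Gemma2 the Hugging Face config defaults are 50.0 for attention logits and 30.0 for the final logits, and wrappers.py below reads both from the config. A standalone sketch, not taken from the PR:

```python
# Sketch: logit soft-capping as used by Gemma2. Values stay within (-cap, cap)
# and are nearly unchanged when |x| << cap.
import torch

def softcap(x: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(x / cap)

scores = torch.tensor([0.5, 10.0, 80.0, -200.0])
print(softcap(scores, 50.0))  # ~[0.50, 9.87, 46.08, -49.97], bounded by +/-50
```
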
diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index 900da3906bb..095f82f75bb 100644
--- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -9,9 +9,9 @@
 /**
  * @file
  *
- * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma3 1B,
- * Granite3.3 2B, phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B / 1.7B,
- * SmolLM2 135M, SmolLM3 3B with Qualcomm AI Engine Direct.
+ * This tool can run Llama2 110M, Llama3.2 1B / 3B, Gemma 2B, Gemma2 2B, Gemma3
+ * 1B, Granite3.3 2B, phi4-mini-instruct, Qwen2.5 0.5B / 1.5B, Qwen3 0.6B
+ * / 1.7B, SmolLM2 135M, SmolLM3 3B with Qualcomm AI Engine Direct.
  *
  */
@@ -130,6 +130,12 @@ std::string get_formatted_prompt(
        formatted_prompt.append("\n");
      }
      break;
+    case example::DecoderModelVersion::kGemma2:
+      formatted_prompt.append("<start_of_turn>user\n");
+      formatted_prompt.append(prompt);
+      formatted_prompt.append("<end_of_turn>\n");
+      formatted_prompt.append("<start_of_turn>model\n");
+      break;
     case example::DecoderModelVersion::kGranite:
       if (!system_prompt.empty()) {
         formatted_prompt.append("<|start_of_role|>system<|end_of_role|>");
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index 440b56f67a0..54a5c0b1d5c 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -118,6 +118,9 @@ Runner::Runner(
     decoder_model_version_ = DecoderModelVersion::kLlama3;
   } else if (decoder_model_version == "gemma") {
     decoder_model_version_ = DecoderModelVersion::kGemma;
+  } else if (decoder_model_version == "gemma2") {
+    decoder_model_version_ = DecoderModelVersion::kGemma2;
+    cache_mode_ = CacheMode::HybridCache;
   } else if (decoder_model_version == "gemma3") {
     decoder_model_version_ = DecoderModelVersion::kGemma3;
     cache_mode_ = CacheMode::HybridCache;
@@ -202,6 +205,7 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|im_end|>", 0, 0).get()[0]);
   } else if (
       decoder_model_version_ == DecoderModelVersion::kGemma ||
+      decoder_model_version_ == DecoderModelVersion::kGemma2 ||
       decoder_model_version_ == DecoderModelVersion::kGemma3) {
     eos_ids->insert(tokenizer_->encode("<end_of_turn>", 0, 0).get()[0]);
   } else if (decoder_model_version_ == DecoderModelVersion::kCodegen) {
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index f7ad3503d19..826a2db1bf1 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -42,6 +42,7 @@ enum DecoderModelVersion {
   kSmollm3,
   kCodegen,
   kGlm,
+  kGemma2,
 };

 enum KvBitWidth {
diff --git a/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py
index 3b0cd8efb5b..81230a4888b 100644
--- a/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py
+++ b/examples/qualcomm/oss_scripts/llama/static_llm_quant_recipe.py
@@ -263,6 +263,43 @@ def __init__(self, verbose: bool = False):
         self.recipe.custom_quant_annotations.append(annotate_kv_8bit)


+class Gemma2QuantRecipe(StaticLLMQuantRecipe):
+    default_quant_dtype = QuantDtype.use_16a4w
+
+    def __init__(self, verbose: bool = False):
+        super().__init__()
+
+        self.recipe = (
+            QuantRecipe(
+                self.default_quant_dtype,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_TENSOR,
+                verbose=verbose,
+            )
+            .add_node_target(
+                {
+                    torch.ops.aten.conv2d.default,
+                },
+                QuantDtype.use_16a4w_block,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_BLOCK,
+                extra_kwargs={"block_size": (1, 32, 1, 1)},
+            )
+            .add_regex(
+                {
+                    r"layers\..*\.attention\.wv.*",
+                },
+                QuantDtype.use_16a8w,
+                False,
+                act_observer=MinMaxObserver,
+                granularity=QuantGranularity.PER_CHANNEL,
+            )
+        )
+        self.recipe.custom_quant_annotations.append(annotate_kv_8bit)
+
+
 class Gemma3QuantRecipe(StaticLLMQuantRecipe):
     default_quant_dtype = QuantDtype.use_16a4w
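
The Gemma2 recipe above keeps activations at 16 bits throughout, drops conv2d weights to 4 bits with per-block scales (blocks of 32 along the input-channel axis, per `block_size` (1, 32, 1, 1)), and leaves the value-projection weights at 8-bit per-channel. A rough sketch of what per-block 4-bit weight quantization means, assuming symmetric quantization for illustration; this is not the QNN or PT2E implementation:

```python
# Sketch: per-block symmetric int4 quantize/dequantize of a conv2d weight,
# one scale per block of 32 input channels (mirroring block_size (1, 32, 1, 1)).
import torch

def blockwise_int4_qdq(w: torch.Tensor, block: int = 32) -> torch.Tensor:
    out_ch, in_ch, kh, kw = w.shape
    blocks = w.reshape(out_ch, in_ch // block, block, kh, kw)
    scale = blocks.abs().amax(dim=2, keepdim=True).clamp_min(1e-12) / 7.0
    q = torch.clamp(torch.round(blocks / scale), -8, 7)  # int4 levels
    return (q * scale).reshape(out_ch, in_ch, kh, kw)

w = torch.randn(64, 128, 1, 1)
err = (w - blockwise_int4_qdq(w)).abs().max()
print(err)  # bounded by half a step of the largest per-block scale
```

Scoping each scale to a 32-channel group bounds the rounding error per block rather than per whole channel or tensor, which is the usual motivation for pairing 4-bit weights with 16-bit activations.
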
diff --git a/examples/qualcomm/oss_scripts/llama/wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers.py
index a6ba5a5116e..61190c5d7e9 100644
--- a/examples/qualcomm/oss_scripts/llama/wrappers.py
+++ b/examples/qualcomm/oss_scripts/llama/wrappers.py
@@ -188,16 +188,6 @@ def __init__(
             get_passes_dependency_for_capture_program() if apply_embedding else None
         )

-        # check if sharding required
-        if self.config.num_sharding > 1:
-            SplitGraph, setting = model_sharding.get_split_graph_pass(
-                self.meta["get_n_layers"],
-                shares=self.config.num_sharding,
-            )
-            self.passes_job[SplitGraph] = setting
-            self.dep_table[SplitGraph] = [FoldQDQ]
-            self.dep_table[TagQuantIO] = [SplitGraph]
-
         # load static llama model args
         params_path = (
             config.params_path if control_args.params is None else control_args.params
@@ -210,6 +200,17 @@ def __init__(
         self.decoder = None
         if (instance := self._prepare_model()) is not None:
             self.tok_embedding, self.decoder = instance
+            self.meta = self.decoder.get_metadata()
+
+        # check if sharding required
+        if instance and self.config.num_sharding > 1:
+            SplitGraph, setting = model_sharding.get_split_graph_pass(
+                self.meta["get_n_layers"],
+                shares=self.config.num_sharding,
+            )
+            self.passes_job[SplitGraph] = setting
+            self.dep_table[SplitGraph] = [FoldQDQ]
+            self.dep_table[TagQuantIO] = [SplitGraph]

     def _process_model_args(self, model_args: ModelArgs):
         # TODO: support batch inputs if necessary
@@ -246,7 +247,11 @@ def _prepare_model(self):  # noqa: C901
             state_dict = torch.load(
                 checkpoint, weights_only=True, map_location="cpu", mmap=True
             )
-            if self.control_args.decoder_model in {"gemma-2b", "gemma3-1b"}:
+            if self.control_args.decoder_model in {
+                "gemma-2b",
+                "gemma2-2b",
+                "gemma3-1b",
+            }:
                 for k, v in state_dict.items():
                     if "norm" not in k:
                         continue
@@ -395,6 +400,19 @@ def _get_model_specific_kwargs(self):
                     hf_config.text_config.rope_local_base_freq
                 )
                 kwargs["sliding_window"] = hf_config.sliding_window
+            case "gemma2-2b":
+                from transformers import Gemma2Config
+
+                hf_config = Gemma2Config.from_pretrained(self.config.repo_id)
+                kwargs["layer_types"] = hf_config.layer_types
+                kwargs["rope_local_base_freq"] = hf_config.rope_parameters[
+                    "rope_theta"
+                ]
+                kwargs["sliding_window"] = hf_config.sliding_window
+                kwargs["final_logit_softcapping"] = (
+                    hf_config.final_logit_softcapping
+                )
+                kwargs["attn_logit_softcapping"] = hf_config.attn_logit_softcapping
         return kwargs
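
The new `gemma2-2b` case pulls its runtime knobs straight from the Hugging Face config: the sliding-window size, the alternating local/global `layer_types`, the local rope base, and the two soft-capping values. A quick way to see the numbers involved without downloading the checkpoint is to instantiate the default `Gemma2Config` from `transformers`; that the `google/gemma-2-2b-it` repo ships the same values is an assumption here:

```python
# Sketch: inspect the Gemma2 fields that _get_model_specific_kwargs forwards,
# using the transformers defaults instead of the hosted checkpoint config.
from transformers import Gemma2Config

cfg = Gemma2Config()
print(cfg.sliding_window)           # 4096 -> local-attention window size
print(cfg.attn_logit_softcapping)   # 50.0 -> cap applied to attention logits
print(cfg.final_logit_softcapping)  # 30.0 -> cap applied to the output logits
```
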