19 changes: 19 additions & 0 deletions examples/models/gemma2/__init__.py
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.gemma2.convert_weights import convert_weights
from executorch.examples.models.llama.model import Llama2Model


class Gemma2Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "Gemma2Model",
    "convert_weights",
]
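For orientation, a minimal sketch of how the converter exported above might be used; the paths are placeholders, and Gemma2Model itself simply forwards its keyword arguments to Llama2Model (defined in examples/models/llama/model.py, not shown in this diff):

from executorch.examples.models.gemma2 import convert_weights

# Convert a downloaded Gemma 2 checkpoint directory into a single .pth file
# whose keys follow ExecuTorch's transformer naming (paths are placeholders).
convert_weights("/path/to/gemma-2-2b", "/path/to/gemma2_2b.pth")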
26 changes: 26 additions & 0 deletions examples/models/gemma2/config/2b_config.json
@@ -0,0 +1,26 @@
{
  "dim": 2304,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 9216,
  "n_heads": 8,
  "head_dim": 256,
  "n_kv_heads": 4,
  "n_layers": 26,
  "act_fn": "gelu_approx",
  "norm_type": "gemma3",
  "norm_eps": 1e-06,
  "post_attention_norm": true,
  "post_ffn_norm": true,
  "rope_theta": 10000.0,
  "use_scaled_rope": false,
  "apply_embedding": true,
  "embedding_scale_factor": 48.0,
  "vocab_size": 256000,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "attn_logit_softcapping": 50.0,
  "final_logit_softcapping": 30.0,
  "sliding_window": 4096,
  "layer_types": ["local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global"],
  "rope_local_base_freq": 10000.0
}
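As a quick sanity check, the config can be loaded and inspected directly; a small sketch (the file path is a placeholder for wherever the repository is checked out) that confirms the alternating local/global attention pattern matches n_layers:

import json

with open("examples/models/gemma2/config/2b_config.json") as f:
    cfg = json.load(f)

# Gemma 2 2B interleaves sliding-window ("local") and full-context ("global")
# attention layers; the local layers use a 4096-token window.
assert len(cfg["layer_types"]) == cfg["n_layers"] == 26
assert all(t in ("local", "global") for t in cfg["layer_types"])
print(cfg["sliding_window"], cfg["attn_logit_softcapping"], cfg["final_logit_softcapping"])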
113 changes: 113 additions & 0 deletions examples/models/gemma2/convert_weights.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key


# Weight mappings from Gemma 2's checkpoint to ExecuTorch's transformer parameters.
_GEMMA2_TO_EXECUTORCH = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight",
    "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
}


def gemma2_to_executorch(
    state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Convert the state dict so that it matches what ExecuTorch's transformer definition expects.
    """
    converted_state_dict = {}
    for key, value in state_dict.items():
        new_key = get_mapped_key(key, _GEMMA2_TO_EXECUTORCH)
        converted_state_dict[new_key] = value
    # Gemma 2 ties the output projection to the token embeddings, so reuse the
    # embedding weights for "output.weight".
    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]
    return converted_state_dict


def load_checkpoint_from_safetensors(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory.
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict
    else:
        # Single checkpoint.
        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
        return state_dict


def load_checkpoint(input_dir: str) -> Dict:
    pytorch_path = os.path.join(input_dir, "pytorch_model.bin")
    if os.path.exists(pytorch_path):
        print("Loading checkpoint from PyTorch .bin file")
        return torch.load(pytorch_path, map_location="cpu", weights_only=True)
    print("Loading checkpoint from safetensors directory")
    return load_checkpoint_from_safetensors(input_dir)


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = gemma2_to_executorch(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert Gemma 2 weights to ExecuTorch transformer format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to a directory containing either safetensors checkpoint files or a pytorch_model.bin checkpoint.",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint.")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
    main()
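The converter is run through the main() entry point above (python examples/models/gemma2/convert_weights.py <input_dir> <output.pth>) or by calling convert_weights() directly. As a sketch of the key renaming it performs, assuming torchtune's get_mapped_key substitutes the layer index into the "{}" placeholders of the mapping table (that behavior lives in torchtune and is not shown in this diff):

from executorch.examples.models.gemma2.convert_weights import _GEMMA2_TO_EXECUTORCH
from torchtune.models.convert_weights import get_mapped_key

# A Hugging Face attention key should map onto the ExecuTorch layer naming,
# e.g. "model.layers.0.self_attn.q_proj.weight" -> "layers.0.attention.wq.weight".
print(get_mapped_key("model.layers.0.self_attn.q_proj.weight", _GEMMA2_TO_EXECUTORCH))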