
Commit d0fe026

Qualcomm AI Engine Direct - GA Static Phi-4-mini
Summary:
- Support Phi-4-mini-instruct for the static llama path
- Add P-RoPE (partial RoPE) for Phi-4-mini
1 parent: 264a91b

File tree: 8 files changed (+79, -8 lines)


examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@ This file provides you the instructions to run LLM Decoder model with different
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
 4. QWEN2.5 0.5B
+5. QWEN3 0.6B / 1.7B
+6. Phi4-mini-instruct
 
 We offer the following modes to execute the model:
 

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -9,6 +9,10 @@
 from dataclasses import dataclass, field
 from typing import Callable, Dict, Type
 
+from executorch.examples.models.phi_4_mini import (
+    convert_weights as convert_phi_4_mini_weights,
+)
+
 from executorch.examples.models.qwen2_5 import (
     convert_weights as convert_qwen2_5_weights,
 )

@@ -71,3 +75,14 @@ class Qwen3_1_7B(HFModel):
     )
     runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
+
+
+@register_hf_model("phi_4_mini")
+@dataclass(init=False, frozen=True)
+class Phi4Mini(HFModel):
+    repo_id: str = "microsoft/Phi-4-mini-instruct"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/phi_4_mini/config/config.json"
+    )
+    runner_version: str = field(default=DECODER_MODEL_VERSION["phi_4_mini"])
+    convert_weights = convert_phi_4_mini_weights
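For readers unfamiliar with this flow: `@register_hf_model("phi_4_mini")` is what later lets `llama.py` resolve `args.decoder_model == "phi_4_mini"` to this descriptor via `SUPPORTED_HF_MODELS`. The sketch below shows the registry pattern this suggests; `register_hf_model` is reimplemented as a stand-in that fills a name-to-class dict, and the real decorator, `HFModel` base class, and registry internals are not reproduced here.

```python
from dataclasses import dataclass

# Stand-in registry; assumed to mirror SUPPORTED_HF_MODELS in the example scripts.
SUPPORTED_HF_MODELS = {}


def register_hf_model(name):
    """Hypothetical reimplementation: map a decoder-model name to its descriptor class."""
    def wrapper(cls):
        SUPPORTED_HF_MODELS[name] = cls
        return cls
    return wrapper


@register_hf_model("phi_4_mini")
@dataclass(frozen=True)
class Phi4Mini:
    repo_id: str = "microsoft/Phi-4-mini-instruct"


# llama.py later reads SUPPORTED_HF_MODELS[args.decoder_model].repo_id;
# with this sketch that lookup resolves as expected.
print(SUPPORTED_HF_MODELS["phi_4_mini"].repo_id)  # microsoft/Phi-4-mini-instruct
```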

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
@@ -15,4 +15,5 @@
     "stories110m": "llama2",
     "llama3_2": "llama3",
     "qwen2_5": "qwen2_5",
+    "phi_4_mini": "phi_4_mini",
 }

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 12 additions & 0 deletions
@@ -362,6 +362,7 @@ def compile(args, pte_filename, tokenizer):
     kv_config.use_kv_cache = True
     kv_config.enable_masked_softmax = args.enable_masked_softmax
     kv_config.enable_r3 = args.r3
+    kv_config.base_model_name_or_path = args.decoder_model
 
     prefill_config = copy.copy(kv_config)
     prefill_config.use_kv_cache = (
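A side note on the hunk above: `prefill_config = copy.copy(kv_config)` makes a shallow copy, so the new `base_model_name_or_path` attribute is carried over to the prefill config while flat overrides such as `use_kv_cache` stay local to it. A minimal illustration with a stand-in config object (not the real `ModelArgs`):

```python
import copy
from types import SimpleNamespace

# Stand-in for the exported model config; only flat attributes are shown.
kv_config = SimpleNamespace(use_kv_cache=True, base_model_name_or_path="phi_4_mini")

prefill_config = copy.copy(kv_config)  # shallow copy, as in compile()
prefill_config.use_kv_cache = False    # prefill-only override

print(kv_config.use_kv_cache)                  # True  (original unchanged)
print(prefill_config.base_model_name_or_path)  # phi_4_mini (carried over)
```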
@@ -565,6 +566,7 @@ def permute(w, heads):
         llama_instance_list[i] = SingleLlama(
             llama_instance_list[i].eval(), pte_filename
         )
+
         if args.embedding_quantize:
             llama_instance_list[i].passes_job[I64toI32][
                 QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY

@@ -1180,6 +1182,16 @@ def export_llama(args) -> None:
             # TODO: Encountered the following error during runtime, so switched behavior for now.
             # Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: Unsupported Normalizer type: NFC.
             data.pop("normalizer")
+    elif args.decoder_model == "phi_4_mini":
+        model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
+        tokenizer = get_tokenizer(runtime_tokenizer_path)
+        with open(runtime_tokenizer_path, "r+") as file:
+            data = json.load(file)
+            # TODO: Encountered the following error during runtime, so switched behavior for now.
+            # Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: invert=true is not supported for Split PreTokenizer. Only invert=false is supported.
+            data["pre_tokenizer"]["pretokenizers"][-2]["invert"] = False
             file.seek(0)
             json.dump(data, file, indent=4)
             file.truncate()
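The new `phi_4_mini` branch above works around a runtime tokenizer limitation by rewriting the exported tokenizer.json in place. Below is a hedged, standalone sketch of that patch step; `patch_split_pretokenizer` is a hypothetical helper name, and the assumption that the Split entry sits at index `-2` of `pre_tokenizer.pretokenizers` comes straight from the diff, not from any guarantee about the upstream tokenizer layout.

```python
import json


def patch_split_pretokenizer(tokenizer_json_path: str) -> None:
    """Force invert=False on the Split pre-tokenizer entry, mirroring the
    phi_4_mini workaround in export_llama (hypothetical helper)."""
    with open(tokenizer_json_path, "r+") as file:
        data = json.load(file)
        # Assumes the Split pre-tokenizer is the second-to-last entry, as in the diff.
        data["pre_tokenizer"]["pretokenizers"][-2]["invert"] = False
        file.seek(0)
        json.dump(data, file, indent=4)
        file.truncate()
```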

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 34 additions & 6 deletions
@@ -34,11 +34,35 @@ def apply_rotary_emb_single(
     freqs_sin = freqs_sin[None, :, None, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
-
     x_out = torch.cat([x_out_r, x_out_i], dim=-1)
     return x_out
 
 
+def apply_partial_rotary_emb_single(
+    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+
+    if x.dim() == 4:
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
+
+    rotary_dim = freqs_cos.shape[-1] * 2
+
+    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+    x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :]
+    x_out_r = x_r * freqs_cos - x_i * freqs_sin
+    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+    x_rotated = torch.cat([x_out_r, x_out_i], dim=-1)
+    return torch.cat([x_rotated, x_pass], dim=-1)
+
+
+APPLY_ROPE_EMBEDDING_FUNCTIONS = {
+    "phi_4_mini": apply_partial_rotary_emb_single,
+    "qwen2_5": apply_rotary_emb_single,
+    "llama3_2": apply_rotary_emb_single,
+}
+
+
 class LlamaAttention(nn.Module):
     def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
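The partial-RoPE (P-RoPE) helper added above rotates only the first `rotary_dim` channels of each head and passes the remaining channels through unchanged, which is what Phi-4-mini's partial rotary factor requires. The self-contained check below restates that behavior outside the model code; the tensor shapes and the way `freqs_cos`/`freqs_sin` are built are illustrative assumptions, chosen only so the two properties (untouched tail, rotated slice matching full RoPE on that slice) can be verified numerically.

```python
import torch


def rope_pair(x, cos, sin):
    # Rotate a tensor whose last dim is split into (real, imag) halves.
    half = x.shape[-1] // 2
    x_r, x_i = x[..., :half], x[..., half:]
    return torch.cat([x_r * cos - x_i * sin, x_r * sin + x_i * cos], dim=-1)


def partial_rope(x, cos, sin):
    # Mirrors apply_partial_rotary_emb_single for a 3D (bsz, seq, head_dim) input.
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    return torch.cat([rope_pair(x_rot, cos, sin), x_pass], dim=-1)


bsz, seq, head_dim, rotary_dim = 1, 4, 8, 4   # illustrative sizes
x = torch.randn(bsz, seq, head_dim)
angles = torch.rand(seq, rotary_dim // 2)      # stand-in rotation angles
cos, sin = angles.cos(), angles.sin()

out = partial_rope(x, cos, sin)
# Channels beyond rotary_dim are passed through untouched.
assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])
# The rotated slice matches applying full RoPE to just that slice.
assert torch.allclose(out[..., :rotary_dim], rope_pair(x[..., :rotary_dim], cos, sin))
```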
@@ -59,6 +83,9 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             k_norm_dim = self.head_dim
             self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
             self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
+        self.apply_rope_emb = APPLY_ROPE_EMBEDDING_FUNCTIONS[
+            config.base_model_name_or_path
+        ]
 
         self.wq = nn.Linear(
             self.dim,

@@ -199,15 +226,15 @@ def forward_sha( # noqa: C901
         for i in range(len(q)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
-            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+            q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 q[i] = torch.matmul(q[i], self.r3_weight.T)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 k[i] = torch.matmul(k[i], self.r3_weight.T)
             if self.use_qk_norm and not self.qk_norm_before_rope:

@@ -272,8 +299,8 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
-        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
-        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
+        q = self.apply_rope_emb(q, freqs_cos, freqs_sin)
+        k = self.apply_rope_emb(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)

@@ -368,7 +395,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.attention = LlamaAttention(
-            config=config, output_new_cache_only=output_new_cache_only
+            config=config,
+            output_new_cache_only=output_new_cache_only,
         )
         self.feed_forward = FeedForward(config)
         self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 12 additions & 2 deletions
@@ -9,8 +9,8 @@
 /**
  * @file
  *
- * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B with Qualcomm
- * AI Engine Direct.
+ * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B, Qwen3 0.6B
+ * / 1.7B phi4-mini-instruct with Qualcomm AI Engine Direct.
  *
  */
 

@@ -104,6 +104,16 @@ std::string get_formatted_prompt(
     case example::DecoderModelVersion::kQwen2_5:
       formatted_prompt.append(prompt);
       break;
+    case example::DecoderModelVersion::kPhi4:
+      if (!system_prompt.empty()) {
+        formatted_prompt.append("<|system|>");
+        formatted_prompt.append(system_prompt);
+        formatted_prompt.append("<|end|>");
+      }
+      formatted_prompt.append("<|user|>");
+      formatted_prompt.append(prompt);
+      formatted_prompt.append("<|end|><|assistant|>");
+      break;
     case example::DecoderModelVersion::kLlama3:
       if (!system_prompt.empty()) {
         formatted_prompt.append(
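The new `kPhi4` case above formats prompts with Phi-4's chat markers. For reference, a Python sketch that produces the same string as that C++ switch case; the marker tokens are taken directly from the diff, while the helper name and everything else here is illustrative.

```python
def format_phi4_prompt(prompt: str, system_prompt: str = "") -> str:
    # Mirrors the kPhi4 branch of get_formatted_prompt in qnn_llama_runner.cpp.
    formatted = ""
    if system_prompt:
        formatted += "<|system|>" + system_prompt + "<|end|>"
    formatted += "<|user|>" + prompt + "<|end|><|assistant|>"
    return formatted


# Example:
# format_phi4_prompt("What is 1+1?", "You are a helpful assistant.")
# -> "<|system|>You are a helpful assistant.<|end|><|user|>What is 1+1?<|end|><|assistant|>"
```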

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,8 @@ Runner::Runner(
     decoder_model_version_ = DecoderModelVersion::kLlama3;
   } else if (decoder_model_version == "qwen2_5") {
     decoder_model_version_ = DecoderModelVersion::kQwen2_5;
+  } else if (decoder_model_version == "phi_4_mini") {
+    decoder_model_version_ = DecoderModelVersion::kPhi4;
   } else {
     ET_CHECK_MSG(false, "Unsupported Decoder Model");
   }

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ enum DecoderModelVersion {
   kLlama2 = 0,
   kLlama3,
   kQwen2_5,
+  kPhi4,
 };
 class Runner {
  public:
