Commit c8a0706
Qualcomm AI Engine Direct - GA Static Phi-4-mini (#13179)
### Summary

- Support Phi-4-mini-instruct for the static llama path
- Add P-ROPE (partial rotary position embedding) for Phi-4-mini
- Add the EOS token for Phi-4-mini

### Test plan

```
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8750 --prompt "I would like to learn python, could you teach me with a simple example?" --temperature 0 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --ptq 16a8w --decoder_model phi_4_mini --num_sharding 4
```

cc: @haowhsu-quic, @shewu-quic, @winskuo-quic, @cccclai
1 parent 6d56713 commit c8a0706

File tree

- examples/qualcomm/oss_scripts/llama/README.md
- examples/qualcomm/oss_scripts/llama/__init__.py
- examples/qualcomm/oss_scripts/llama/decoder_constants.py
- examples/qualcomm/oss_scripts/llama/llama.py
- examples/qualcomm/oss_scripts/llama/model/static_llama.py
- examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
- examples/qualcomm/oss_scripts/llama/runner/runner.cpp
- examples/qualcomm/oss_scripts/llama/runner/runner.h

8 files changed: +73 −12 lines changed

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -6,6 +6,8 @@ This file provides you the instructions to run LLM Decoder model with different
 2. LLAMA3.2 1B
 3. LLAMA3.2 3B
 4. QWEN2.5 0.5B
+5. QWEN3 0.6B / 1.7B
+6. Phi4-mini-instruct
 
 We offer the following modes to execute the model:
 
```

examples/qualcomm/oss_scripts/llama/__init__.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -9,6 +9,9 @@
 from dataclasses import dataclass, field
 from typing import Callable, Dict, Type
 
+from executorch.examples.models.phi_4_mini import (
+    convert_weights as convert_phi_4_mini_weights,
+)
 from executorch.examples.models.qwen2_5 import (
     convert_weights as convert_qwen2_5_weights,
 )
@@ -71,3 +74,14 @@ class Qwen3_1_7B(HFModel):
     )
     runner_version: str = field(default=DECODER_MODEL_VERSION["qwen2_5"])
     convert_weights = convert_qwen3_weights
+
+
+@register_hf_model("phi_4_mini")
+@dataclass(init=False, frozen=True)
+class Phi4Mini(HFModel):
+    repo_id: str = "microsoft/Phi-4-mini-instruct"
+    params_path: str = os.path.join(
+        BASE_DIR, "../../../models/phi_4_mini/config/config.json"
+    )
+    runner_version: str = field(default=DECODER_MODEL_VERSION["phi_4_mini"])
+    convert_weights = convert_phi_4_mini_weights
```
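The decorator registers the new entry under the `phi_4_mini` key, which is how the export script later resolves the Hugging Face repo (see the `SUPPORTED_HF_MODELS` lookup in the llama.py diff below). A minimal sketch of that lookup, assuming `SUPPORTED_HF_MODELS` is the registry populated by `@register_hf_model`; the import path is an assumption, not taken from the diff:

```python
# Hypothetical usage sketch, not part of the diff.
# Assumption: SUPPORTED_HF_MODELS is the dict filled by @register_hf_model and
# is importable from this package (the exact import path may differ).
from executorch.examples.qualcomm.oss_scripts.llama import SUPPORTED_HF_MODELS

phi4 = SUPPORTED_HF_MODELS["phi_4_mini"]
print(phi4.repo_id)         # "microsoft/Phi-4-mini-instruct"
print(phi4.runner_version)  # DECODER_MODEL_VERSION["phi_4_mini"], i.e. "phi_4_mini"
```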

examples/qualcomm/oss_scripts/llama/decoder_constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -15,4 +15,5 @@
     "stories110m": "llama2",
     "llama3_2": "llama3",
     "qwen2_5": "qwen2_5",
+    "phi_4_mini": "phi_4_mini",
 }
```

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -579,7 +579,7 @@ def permute(w, heads):
             annotate_conv=args.ptq != "16a8w",
         ),
     )
-    if args.decoder_model == {"stories110m", "stories260k"}:
+    if args.decoder_model in {"stories110m", "stories260k"}:
         custom_annotations = custom_annotations + (
             annotate_linear_16a8w_in_affine_layer,
         )
@@ -1175,11 +1175,16 @@ def export_llama(args) -> None:
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
         tokenizer = get_tokenizer(runtime_tokenizer_path)
+    elif args.decoder_model == "phi_4_mini":
+        model_id = SUPPORTED_HF_MODELS[args.decoder_model].repo_id
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        runtime_tokenizer_path = tokenizer.save_pretrained(args.artifact)[-1]
+        tokenizer = get_tokenizer(runtime_tokenizer_path)
         with open(runtime_tokenizer_path, "r+") as file:
             data = json.load(file)
             # TODO: Encountered the following error during runtime, so switched behavior for now.
-            # Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: Unsupported Normalizer type: NFC.
-            data.pop("normalizer")
+            # Error: libc++abi: terminating due to uncaught exception of type std::runtime_error: invert=true is not supported for Split PreTokenizer. Only invert=false is supported.
+            data["pre_tokenizer"]["pretokenizers"][-2]["invert"] = False
             file.seek(0)
             json.dump(data, file, indent=4)
             file.truncate()
```
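The first hunk above is a small but real bug fix: comparing a string to a set literal with `==` is always False, so the stories-model annotation branch could never run, while `in` performs the intended membership test. A quick illustration:

```python
decoder_model = "stories110m"

# Equality against a set literal never matches a string ...
print(decoder_model == {"stories110m", "stories260k"})  # False
# ... while membership does.
print(decoder_model in {"stories110m", "stories260k"})  # True
```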

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 31 additions & 7 deletions
```diff
@@ -39,6 +39,24 @@ def apply_rotary_emb_single(
     return x_out
 
 
+def apply_partial_rotary_emb_single(
+    x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
+) -> torch.Tensor:
+
+    if x.dim() == 4:
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
+
+    rotary_dim = freqs_cos.shape[-1] * 2
+
+    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
+    x_r, x_i = x_rot[..., : x_rot.shape[-1] // 2], x_rot[..., x_rot.shape[-1] // 2 :]
+    x_out_r = x_r * freqs_cos - x_i * freqs_sin
+    x_out_i = x_r * freqs_sin + x_i * freqs_cos
+    x_rotated = torch.cat([x_out_r, x_out_i], dim=-1)
+    return torch.cat([x_rotated, x_pass], dim=-1)
+
+
 class LlamaAttention(nn.Module):
     def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
@@ -60,6 +78,11 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
             self.q_norm_fn = torch.nn.RMSNorm(q_norm_dim, eps=config.norm_eps)
             self.k_norm_fn = torch.nn.RMSNorm(k_norm_dim, eps=config.norm_eps)
 
+        if config.partial_rotary_factor < 1:
+            self.apply_rope_emb = apply_partial_rotary_emb_single
+        else:
+            self.apply_rope_emb = apply_rotary_emb_single
+
         self.wq = nn.Linear(
             self.dim,
             self.n_heads * self.head_dim,
@@ -199,17 +222,17 @@ def forward_sha( # noqa: C901
         for i in range(len(q)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
-            q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
+            q[i] = self.apply_rope_emb(q[i], freqs_cos, freqs_sin)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                q[i] = torch.matmul(q[i], self.r3_weight.T)
+                q[i] = torch.matmul(q[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 q[i] = self.q_norm_fn(q[i])
         for i in range(len(k)):
             if self.use_qk_norm and self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
+            k[i] = self.apply_rope_emb(k[i], freqs_cos, freqs_sin).transpose(1, 2)
             if hasattr(self.config, "enable_r3") and self.config.enable_r3:
-                k[i] = torch.matmul(k[i], self.r3_weight.T)
+                k[i] = torch.matmul(k[i], self.r3_weight)
             if self.use_qk_norm and not self.qk_norm_before_rope:
                 k[i] = self.k_norm_fn(k[i])
 
@@ -272,8 +295,8 @@ def forward(
             q = self.q_norm_fn(q)
             k = self.k_norm_fn(k)
 
-        q = apply_rotary_emb_single(q, freqs_cos, freqs_sin)
-        k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
+        q = self.apply_rope_emb(q, freqs_cos, freqs_sin)
+        k = self.apply_rope_emb(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
@@ -368,7 +391,8 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.attention = LlamaAttention(
-            config=config, output_new_cache_only=output_new_cache_only
+            config=config,
+            output_new_cache_only=output_new_cache_only,
         )
         self.feed_forward = FeedForward(config)
         self.attention_norm = torch.nn.RMSNorm(config.dim, eps=config.norm_eps)
```
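`apply_partial_rotary_emb_single` rotates only the leading `rotary_dim` channels of each head (with `rotary_dim` derived from the precomputed frequency tables, i.e. from the model's `partial_rotary_factor`) and passes the remaining channels through unchanged, which is the P-ROPE behavior Phi-4-mini needs. A small sanity-check sketch, assuming the import path mirrors the file location shown above; the shapes are toy values chosen for illustration, with the sequence axis at dim 1 as the `freqs` indexing above implies:

```python
import torch

# Assumed import path, mirroring examples/qualcomm/oss_scripts/llama/model/static_llama.py.
from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
    apply_partial_rotary_emb_single,
)

seq_len, n_heads, head_dim = 4, 2, 8
rotary_dim = 6  # toy example of partial_rotary_factor = 0.75: 6 of 8 channels rotate

x = torch.randn(1, seq_len, n_heads, head_dim)     # [batch, seq, heads, head_dim]
freqs_cos = torch.randn(seq_len, rotary_dim // 2)  # frequency tables only cover
freqs_sin = torch.randn(seq_len, rotary_dim // 2)  # the rotated half-dimensions

out = apply_partial_rotary_emb_single(x, freqs_cos, freqs_sin)

assert out.shape == x.shape
# The trailing head_dim - rotary_dim channels are passed through untouched.
assert torch.equal(out[..., rotary_dim:], x[..., rotary_dim:])
```

In the attention module, `self.apply_rope_emb` is bound to this partial variant only when `config.partial_rotary_factor < 1`, so existing full-RoPE models keep their current path.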

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 12 additions & 2 deletions
```diff
@@ -9,8 +9,8 @@
 /**
  * @file
  *
- * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B with Qualcomm
- * AI Engine Direct.
+ * This tool can run Llama2 110M, Llama3.2 1B / 3B, Qwen2.5 0.5B, Qwen3 0.6B
+ * / 1.7B phi4-mini-instruct with Qualcomm AI Engine Direct.
  *
  */
 
@@ -104,6 +104,16 @@ std::string get_formatted_prompt(
     case example::DecoderModelVersion::kQwen2_5:
       formatted_prompt.append(prompt);
       break;
+    case example::DecoderModelVersion::kPhi4:
+      if (!system_prompt.empty()) {
+        formatted_prompt.append("<|system|>");
+        formatted_prompt.append(system_prompt);
+        formatted_prompt.append("<|end|>");
+      }
+      formatted_prompt.append("<|user|>");
+      formatted_prompt.append(prompt);
+      formatted_prompt.append("<|end|><|assistant|>");
+      break;
     case example::DecoderModelVersion::kLlama3:
       if (!system_prompt.empty()) {
         formatted_prompt.append(
```
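The new `kPhi4` case wraps the prompt in the Phi-4 chat tags shown in the diff. A rough Python sketch of the same string construction (the helper name is mine, not part of the runner):

```python
def format_phi4_prompt(prompt: str, system_prompt: str = "") -> str:
    # Mirrors the kPhi4 case in get_formatted_prompt() above.
    formatted = ""
    if system_prompt:
        formatted += "<|system|>" + system_prompt + "<|end|>"
    formatted += "<|user|>" + prompt + "<|end|><|assistant|>"
    return formatted


print(format_phi4_prompt("What is 2 + 2?", "You are a helpful assistant."))
# <|system|>You are a helpful assistant.<|end|><|user|>What is 2 + 2?<|end|><|assistant|>
```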

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -130,6 +130,8 @@ Runner::Runner(
     decoder_model_version_ = DecoderModelVersion::kLlama3;
   } else if (decoder_model_version == "qwen2_5") {
     decoder_model_version_ = DecoderModelVersion::kQwen2_5;
+  } else if (decoder_model_version == "phi_4_mini") {
+    decoder_model_version_ = DecoderModelVersion::kPhi4;
   } else {
     ET_CHECK_MSG(false, "Unsupported Decoder Model");
   }
@@ -185,6 +187,8 @@ Error Runner::load() {
   }
   if (decoder_model_version_ == DecoderModelVersion::kLlama3) {
     eos_ids->insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]);
+  } else if (decoder_model_version_ == DecoderModelVersion::kPhi4) {
+    eos_ids->insert(tokenizer_->encode("<|end|>", 0, 0).get()[0]);
   }
   // Try avoid getMetadataHelper as it is time consuming.
   Result<MethodMeta> method_meta =
```
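Together with the Llama 3 branch above it, each decoder version now adds its own chat-template end token to the EOS set, so decoding stops as soon as the model emits it. Summarized as a sketch (the dict name is mine; the values come from the diff):

```python
# Extra stop strings the runner encodes and inserts into eos_ids, per decoder version.
EXTRA_EOS_TOKEN = {
    "llama3": "<|eot_id|>",
    "phi_4_mini": "<|end|>",  # added in this commit; matches the <|end|> tag in the prompt format
}
```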

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ enum DecoderModelVersion {
   kLlama2 = 0,
   kLlama3,
   kQwen2_5,
+  kPhi4,
 };
 class Runner {
  public:
```
