
Commit db3e22b

Enable optimized deepseekv2 -- part 1 (#3420)
1 parent 4196b77 commit db3e22b

File tree

25 files changed: +3448 −8 lines changed

csrc/cpu/aten/MoE.cpp

Lines changed: 185 additions & 0 deletions
@@ -7,6 +7,7 @@ namespace cpu {
 
 IPEX_DEFINE_DISPATCH(mixtral_moe_tpp_kernel_stub);
 IPEX_DEFINE_DISPATCH(mixtral_moe_woq_kernel_stub);
+IPEX_DEFINE_DISPATCH(deepseek_moe_woq_kernel_stub);
 IPEX_DEFINE_DISPATCH(mixtral_moe_kernel_stub);
 
 at::Tensor mixtral_moe_tpp(
@@ -38,6 +39,41 @@ at::Tensor mixtral_moe_tpp(
       is_distributed);
 }
 
+at::Tensor deepseek_moe_tpp(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<at::Tensor>& down_wei,
+    bool tpp_fallback,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe_tpp", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_wei.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+    output = mixtral_moe_tpp_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_wei[i],
+        up_wei[i],
+        down_wei[i],
+        tpp_fallback,
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
+
 at::Tensor mixtral_moe(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
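The new deepseek_moe_tpp loops over experts and reuses the existing per-expert Mixtral TPP kernel stub unchanged: expert_mask[i].nonzero() lists the (top-k slot, token) pairs routed to expert i, select(1, 0) extracts the slot index (idx) and select(1, 1) the token index (top_x). A minimal Python sketch of that mask layout and decomposition, assuming the Hugging Face Mixtral convention (one-hot of the router's top-k choices, permuted to [num_experts, top_k, num_tokens]); it is illustrative only and not part of this commit:

import torch

num_tokens, num_experts, top_k = 8, 4, 2
router_logits = torch.randn(num_tokens, num_experts)
routing_weights, selected = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)
# [num_tokens, top_k, num_experts] -> [num_experts, top_k, num_tokens]
expert_mask = torch.nn.functional.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)

for i in range(num_experts):
    non_zero = expert_mask[i].nonzero()  # mirrors expert_mask[i].nonzero() above
    if non_zero.size(0) == 0:
        continue                         # expert i received no tokens
    idx = non_zero[:, 0]                 # top-k slot, column 0 (select(1, 0))
    top_x = non_zero[:, 1]               # token index, column 1 (select(1, 1))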
@@ -72,6 +108,87 @@ at::Tensor mixtral_moe(
       output,
       is_distributed);
 }
+
+at::Tensor deepseek_moe(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
+    const std::vector<at::Tensor>& down_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_wei.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+
+    output = mixtral_moe_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_wei[i],
+        gate_op_ctx[i]->get_data_handle(),
+        up_wei[i],
+        up_op_ctx[i]->get_data_handle(),
+        down_wei[i],
+        down_op_ctx[i]->get_data_handle(),
+        true,
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
+
+at::Tensor deepseek_moe_mkl(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
+    const std::vector<at::Tensor>& down_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe_mkl", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_wei.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+    output = mixtral_moe_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_wei[i],
+        gate_op_ctx[i]->get_data_handle(),
+        up_wei[i],
+        up_op_ctx[i]->get_data_handle(),
+        down_wei[i],
+        down_op_ctx[i]->get_data_handle(),
+        false,
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
 at::Tensor mixtral_moe_woq(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
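Aside from the op-context type, deepseek_moe and deepseek_moe_mkl differ only in the hard-coded boolean forwarded to mixtral_moe_kernel_stub: true in deepseek_moe (the oneDNN prepacked LinearOpContext path) versus false in deepseek_moe_mkl (the MKLOpContext path), matching the position of the use_dnnl parameter in the mixtral_moe schema registered below.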
@@ -98,6 +215,38 @@ at::Tensor mixtral_moe_woq(
       output,
       is_distributed);
 }
+at::Tensor deepseek_moe_woq(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed) {
+  RECORD_FUNCTION("ipex::deepseek_moe_woq", c10::ArrayRef<c10::IValue>({}));
+
+  int num_experts = gate_ctx.size();
+  for (auto i = 0; i < num_experts; i++) {
+    auto non_zero = expert_mask[i].nonzero();
+    if (non_zero.sizes()[0] == 0)
+      continue;
+    auto idx = non_zero.select(1, 0);
+    auto top_x = non_zero.select(1, 1);
+    output = mixtral_moe_woq_kernel_stub(
+        kCPU,
+        hidden_states,
+        top_x,
+        idx,
+        gate_ctx[i]->get_data_handle(),
+        up_ctx[i]->get_data_handle(),
+        down_ctx[i]->get_data_handle(),
+        routing_weights,
+        output,
+        is_distributed);
+  }
+  return output;
+}
 } // namespace cpu
 } // namespace torch_ipex
 
@@ -112,17 +261,53 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
       "mixtral_moe_tpp",
       c10::DispatchKey::CPU,
       torch_ipex::cpu::mixtral_moe_tpp);
+  m.def(
+      "deepseek_moe_tpp(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+      Tensor[] up_wei, Tensor[] down_wei, bool tpp_fallback, Tensor routing_weights, \
+      Tensor output, bool is_distributed) -> Tensor");
+  m.impl(
+      "deepseek_moe_tpp",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moe_tpp);
   m.def(
       "mixtral_moe(Tensor hidden_states, Tensor top_x, Tensor idx, Tensor gate_wei, \
      Tensor gate_op_ctx, Tensor up_wei, Tensor up_op_ctx, Tensor down_wei, \
      Tensor down_op_ctx, bool use_dnnl, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
   m.impl("mixtral_moe", c10::DispatchKey::CPU, torch_ipex::cpu::mixtral_moe);
+  m.def(
+      "deepseek_moe(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+      __torch__.torch.classes.ipex_prepack.LinearOpContext[] gate_op_ctx, Tensor[] up_wei, \
+      __torch__.torch.classes.ipex_prepack.LinearOpContext[] up_op_ctx, Tensor[] down_wei, \
+      __torch__.torch.classes.ipex_prepack.LinearOpContext[] down_op_ctx, Tensor routing_weights, \
+      Tensor output, bool is_distributed) -> Tensor");
+  m.impl("deepseek_moe", c10::DispatchKey::CPU, torch_ipex::cpu::deepseek_moe);
+  m.def(
+      "deepseek_moe_mkl(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
+      __torch__.torch.classes.ipex_prepack.MKLOpContext[] gate_op_ctx, Tensor[] up_wei, \
+      __torch__.torch.classes.ipex_prepack.MKLOpContext[] up_op_ctx, \
+      Tensor[] down_wei, __torch__.torch.classes.ipex_prepack.MKLOpContext[] down_op_ctx, \
+      Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
+  m.impl(
+      "deepseek_moe_mkl",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moe_mkl);
   m.def(
       "mixtral_moe_woq(Tensor hidden_states, Tensor top_x, Tensor idx, Tensor gate_wei, \
      Tensor up_wei, Tensor down_wei, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
   m.impl(
       "mixtral_moe_woq",
       c10::DispatchKey::CPU,
       torch_ipex::cpu::mixtral_moe_woq);
+  m.def(
+      "deepseek_moe_woq(Tensor hidden_states, Tensor expert_mask, \
+      __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] gate_ctx, \
+      __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] up_ctx, \
+      __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] down_ctx, \
+      Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor");
+
+  m.impl(
+      "deepseek_moe_woq",
+      c10::DispatchKey::CPU,
+      torch_ipex::cpu::deepseek_moe_woq);
 }
 } // namespace
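Because these schemas are registered on the torch_ipex library fragment, the ops surface in Python as torch.ops.torch_ipex.*. A hypothetical call site for deepseek_moe_tpp follows; the shapes, sizes, and mask construction are assumptions chosen for illustration, and running it requires an intel_extension_for_pytorch build that contains this commit:

import torch
import intel_extension_for_pytorch  # noqa: F401  -- loads the torch_ipex op library

E, K, T, H, I = 4, 2, 8, 256, 512   # experts, top-k, tokens, hidden, intermediate (assumed)
hidden_states = torch.randn(T, H)
gate_wei = [torch.randn(I, H) for _ in range(E)]   # per-expert gate/up/down projections (assumed shapes)
up_wei = [torch.randn(I, H) for _ in range(E)]
down_wei = [torch.randn(H, I) for _ in range(E)]
routing_weights, selected = torch.topk(torch.randn(T, E).softmax(-1), K, dim=-1)
expert_mask = torch.nn.functional.one_hot(selected, E).permute(2, 1, 0)
output = torch.zeros(T, H)

output = torch.ops.torch_ipex.deepseek_moe_tpp(
    hidden_states, expert_mask, gate_wei, up_wei, down_wei,
    True,               # tpp_fallback
    routing_weights, output,
    False,              # is_distributed
)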

csrc/cpu/aten/MoE.h

Lines changed: 101 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 #include <ATen/ATen.h>
 #include <dyndisp/DispatchStub.h>
+#include "Linear.h"
 
 namespace torch_ipex {
 namespace cpu {
@@ -16,6 +17,35 @@ at::Tensor mixtral_moe_tpp(
     const at::Tensor&,
     at::Tensor&,
     bool);
+at::Tensor deepseek_moe_tpp(
+    const at::Tensor&,
+    const at::Tensor&,
+    const std::vector<at::Tensor>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<at::Tensor>&,
+    bool,
+    const at::Tensor&,
+    at::Tensor&,
+    bool);
+at::Tensor mixtral_moe_woq(
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    const at::Tensor&,
+    at::Tensor&,
+    bool);
+at::Tensor deepseek_moe_woq(
+    const at::Tensor&,
+    const at::Tensor&,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>&,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>&,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>&,
+    const at::Tensor&,
+    at::Tensor&,
+    bool);
 at::Tensor mixtral_moe_woq(
     const at::Tensor&,
     const at::Tensor&,
@@ -40,6 +70,30 @@ at::Tensor mixtral_moe(
     const at::Tensor&,
     at::Tensor&,
     bool);
+at::Tensor deepseek_moe(
+    const at::Tensor&,
+    const at::Tensor&,
+    const std::vector<at::Tensor>&,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>&,
+    const at::Tensor&,
+    at::Tensor&,
+    bool);
+at::Tensor deepseek_moe_mkl(
+    const at::Tensor&,
+    const at::Tensor&,
+    const std::vector<at::Tensor>&,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>&,
+    const std::vector<at::Tensor>&,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>&,
+    const at::Tensor&,
+    at::Tensor&,
+    bool);
 using mixtral_moe_tpp_kernel_fn = at::Tensor (*)(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -51,6 +105,16 @@ using mixtral_moe_tpp_kernel_fn = at::Tensor (*)(
     const at::Tensor& routing_weights,
     at::Tensor& output,
     bool is_distributed);
+using deepseek_moe_tpp_kernel_fn = at::Tensor (*)(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<at::Tensor>& down_wei,
+    bool tpp_fallback,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed);
 using mixtral_moe_woq_kernel_fn = at::Tensor (*)(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -61,6 +125,15 @@ using mixtral_moe_woq_kernel_fn = at::Tensor (*)(
     const at::Tensor& routing_weights,
     at::Tensor& output,
     bool is_distributed);
+using deepseek_moe_woq_kernel_fn = at::Tensor (*)(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
+    const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed);
 using mixtral_moe_kernel_fn = at::Tensor (*)(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -75,8 +148,36 @@ using mixtral_moe_kernel_fn = at::Tensor (*)(
     const at::Tensor& routing_weights,
     at::Tensor& output,
     bool is_distributed);
+using deepseek_moe_kernel_fn = at::Tensor (*)(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
+    const std::vector<at::Tensor>& down_wei,
+    const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed);
+using deepseek_moe_mkl_kernel_fn = at::Tensor (*)(
+    const at::Tensor& hidden_states,
+    const at::Tensor& expert_mask,
+    const std::vector<at::Tensor>& gate_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
+    const std::vector<at::Tensor>& up_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
+    const std::vector<at::Tensor>& down_wei,
+    const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
+    const at::Tensor& routing_weights,
+    at::Tensor& output,
+    bool is_distributed);
 IPEX_DECLARE_DISPATCH(mixtral_moe_tpp_kernel_fn, mixtral_moe_tpp_kernel_stub);
+IPEX_DECLARE_DISPATCH(deepseek_moe_tpp_kernel_fn, deepseek_moe_tpp_kernel_stub);
 IPEX_DECLARE_DISPATCH(mixtral_moe_woq_kernel_fn, mixtral_moe_woq_kernel_stub);
+IPEX_DECLARE_DISPATCH(deepseek_moe_woq_kernel_fn, deepseek_moe_woq_kernel_stub);
 IPEX_DECLARE_DISPATCH(mixtral_moe_kernel_fn, mixtral_moe_kernel_stub);
+IPEX_DECLARE_DISPATCH(deepseek_moe_kernel_fn, deepseek_moe_kernel_stub);
+IPEX_DECLARE_DISPATCH(deepseek_moe_mkl_kernel_fn, deepseek_moe_mkl_kernel_stub);
 } // namespace cpu
 } // namespace torch_ipex
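The header mirrors the definitions in MoE.cpp: for each deepseek_moe_* entry point it declares the public function, a *_kernel_fn pointer type, and an IPEX_DECLARE_DISPATCH stub that pairs with the corresponding IPEX_DEFINE_DISPATCH so ISA-specific kernel implementations can register behind a single dispatch point. (Note the hunk also introduces a second declaration of mixtral_moe_woq just above the existing one; redeclaration is legal C++ and harmless.)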

examples/cpu/llm/inference/distributed/run_accuracy_with_deepspeed.py

Lines changed: 4 additions & 2 deletions
@@ -400,12 +400,14 @@ def get_repo_root(model_name_or_path):
 
 def get_checkpoint_files(model_name_or_path):
     cached_repo_dir = get_repo_root(model_name_or_path)
-
+    glob_pattern = "*.[bp][it][n]"
+    if re.search("deepseek-v2", model_name_or_path, re.IGNORECASE):
+        glob_pattern = "*.[sbp][ait][fn][e][t][e][n][s][o][r][s]"
     # extensions: .bin | .pt
     # creates a list of paths from all downloaded files in cache dir
     file_list = [
         str(entry)
-        for entry in Path(cached_repo_dir).rglob("*.[bp][it][n]")
+        for entry in Path(cached_repo_dir).rglob(glob_pattern)
         if entry.is_file()
     ]
     return file_list
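The widened glob is a character-class template rather than a true alternation: each [...] fixes the set of characters allowed at one position, so the deepseek-v2 branch effectively selects .safetensors shards (glob classes cannot express ".bin or .pt or .safetensors" exactly, hence the two separate patterns). An illustrative check, not part of the commit:

import fnmatch

names = ["model-00001-of-00002.safetensors", "pytorch_model.bin"]
# default pattern: three class positions after the dot, e.g. "bin"
print(fnmatch.filter(names, "*.[bp][it][n]"))
# -> ['pytorch_model.bin']
# deepseek-v2 pattern: eleven class positions, satisfied by "safetensors"
print(fnmatch.filter(names, "*.[sbp][ait][fn][e][t][e][n][s][o][r][s]"))
# -> ['model-00001-of-00002.safetensors']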
