@@ -7,6 +7,7 @@ namespace cpu {
7
7
8
8
// Dispatch-stub definitions for the MoE kernels invoked by the operators in
// this file (presumably paired with matching declarations in the header —
// TODO confirm).
IPEX_DEFINE_DISPATCH (mixtral_moe_tpp_kernel_stub);
9
9
IPEX_DEFINE_DISPATCH (mixtral_moe_woq_kernel_stub);
10
// NOTE(review): no code visible in this file calls
// deepseek_moe_woq_kernel_stub; deepseek_moe_woq dispatches through
// mixtral_moe_woq_kernel_stub instead — confirm this stub is used elsewhere.
+ IPEX_DEFINE_DISPATCH (deepseek_moe_woq_kernel_stub);
10
11
IPEX_DEFINE_DISPATCH (mixtral_moe_kernel_stub);
11
12
12
13
at::Tensor mixtral_moe_tpp (
@@ -38,6 +39,41 @@ at::Tensor mixtral_moe_tpp(
38
39
is_distributed);
39
40
}
40
41
42
+ at::Tensor deepseek_moe_tpp (
43
+ const at::Tensor& hidden_states,
44
+ const at::Tensor& expert_mask,
45
+ const std::vector<at::Tensor>& gate_wei,
46
+ const std::vector<at::Tensor>& up_wei,
47
+ const std::vector<at::Tensor>& down_wei,
48
+ bool tpp_fallback,
49
+ const at::Tensor& routing_weights,
50
+ at::Tensor& output,
51
+ bool is_distributed) {
52
+ RECORD_FUNCTION (" ipex::deepseek_moe_tpp" , c10::ArrayRef<c10::IValue>({}));
53
+
54
+ int num_experts = gate_wei.size ();
55
+ for (auto i = 0 ; i < num_experts; i++) {
56
+ auto non_zero = expert_mask[i].nonzero ();
57
+ if (non_zero.sizes ()[0 ] == 0 )
58
+ continue ;
59
+ auto idx = non_zero.select (1 , 0 );
60
+ auto top_x = non_zero.select (1 , 1 );
61
+ output = mixtral_moe_tpp_kernel_stub (
62
+ kCPU ,
63
+ hidden_states,
64
+ top_x,
65
+ idx,
66
+ gate_wei[i],
67
+ up_wei[i],
68
+ down_wei[i],
69
+ tpp_fallback,
70
+ routing_weights,
71
+ output,
72
+ is_distributed);
73
+ }
74
+ return output;
75
+ }
76
+
41
77
at::Tensor mixtral_moe (
42
78
const at::Tensor& hidden_states,
43
79
const at::Tensor& top_x,
@@ -72,6 +108,87 @@ at::Tensor mixtral_moe(
72
108
output,
73
109
is_distributed);
74
110
}
111
+
112
+ at::Tensor deepseek_moe (
113
+ const at::Tensor& hidden_states,
114
+ const at::Tensor& expert_mask,
115
+ const std::vector<at::Tensor>& gate_wei,
116
+ const std::vector<c10::intrusive_ptr<LinearOpContext>>& gate_op_ctx,
117
+ const std::vector<at::Tensor>& up_wei,
118
+ const std::vector<c10::intrusive_ptr<LinearOpContext>>& up_op_ctx,
119
+ const std::vector<at::Tensor>& down_wei,
120
+ const std::vector<c10::intrusive_ptr<LinearOpContext>>& down_op_ctx,
121
+ const at::Tensor& routing_weights,
122
+ at::Tensor& output,
123
+ bool is_distributed) {
124
+ RECORD_FUNCTION (" ipex::deepseek_moe" , c10::ArrayRef<c10::IValue>({}));
125
+
126
+ int num_experts = gate_wei.size ();
127
+ for (auto i = 0 ; i < num_experts; i++) {
128
+ auto non_zero = expert_mask[i].nonzero ();
129
+ if (non_zero.sizes ()[0 ] == 0 )
130
+ continue ;
131
+ auto idx = non_zero.select (1 , 0 );
132
+ auto top_x = non_zero.select (1 , 1 );
133
+
134
+ output = mixtral_moe_kernel_stub (
135
+ kCPU ,
136
+ hidden_states,
137
+ top_x,
138
+ idx,
139
+ gate_wei[i],
140
+ gate_op_ctx[i]->get_data_handle (),
141
+ up_wei[i],
142
+ up_op_ctx[i]->get_data_handle (),
143
+ down_wei[i],
144
+ down_op_ctx[i]->get_data_handle (),
145
+ true ,
146
+ routing_weights,
147
+ output,
148
+ is_distributed);
149
+ }
150
+ return output;
151
+ }
152
+
153
+ at::Tensor deepseek_moe_mkl (
154
+ const at::Tensor& hidden_states,
155
+ const at::Tensor& expert_mask,
156
+ const std::vector<at::Tensor>& gate_wei,
157
+ const std::vector<c10::intrusive_ptr<MKLOpContext>>& gate_op_ctx,
158
+ const std::vector<at::Tensor>& up_wei,
159
+ const std::vector<c10::intrusive_ptr<MKLOpContext>>& up_op_ctx,
160
+ const std::vector<at::Tensor>& down_wei,
161
+ const std::vector<c10::intrusive_ptr<MKLOpContext>>& down_op_ctx,
162
+ const at::Tensor& routing_weights,
163
+ at::Tensor& output,
164
+ bool is_distributed) {
165
+ RECORD_FUNCTION (" ipex::deepseek_moe_mkl" , c10::ArrayRef<c10::IValue>({}));
166
+
167
+ int num_experts = gate_wei.size ();
168
+ for (auto i = 0 ; i < num_experts; i++) {
169
+ auto non_zero = expert_mask[i].nonzero ();
170
+ if (non_zero.sizes ()[0 ] == 0 )
171
+ continue ;
172
+ auto idx = non_zero.select (1 , 0 );
173
+ auto top_x = non_zero.select (1 , 1 );
174
+ output = mixtral_moe_kernel_stub (
175
+ kCPU ,
176
+ hidden_states,
177
+ top_x,
178
+ idx,
179
+ gate_wei[i],
180
+ gate_op_ctx[i]->get_data_handle (),
181
+ up_wei[i],
182
+ up_op_ctx[i]->get_data_handle (),
183
+ down_wei[i],
184
+ down_op_ctx[i]->get_data_handle (),
185
+ false ,
186
+ routing_weights,
187
+ output,
188
+ is_distributed);
189
+ }
190
+ return output;
191
+ }
75
192
at::Tensor mixtral_moe_woq (
76
193
const at::Tensor& hidden_states,
77
194
const at::Tensor& top_x,
@@ -98,6 +215,38 @@ at::Tensor mixtral_moe_woq(
98
215
output,
99
216
is_distributed);
100
217
}
218
+ at::Tensor deepseek_moe_woq (
219
+ const at::Tensor& hidden_states,
220
+ const at::Tensor& expert_mask,
221
+ const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& gate_ctx,
222
+ const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& up_ctx,
223
+ const std::vector<c10::intrusive_ptr<WoqLinearOpContext>>& down_ctx,
224
+ const at::Tensor& routing_weights,
225
+ at::Tensor& output,
226
+ bool is_distributed) {
227
+ RECORD_FUNCTION (" ipex::deepseek_moe_woq" , c10::ArrayRef<c10::IValue>({}));
228
+
229
+ int num_experts = gate_ctx.size ();
230
+ for (auto i = 0 ; i < num_experts; i++) {
231
+ auto non_zero = expert_mask[i].nonzero ();
232
+ if (non_zero.sizes ()[0 ] == 0 )
233
+ continue ;
234
+ auto idx = non_zero.select (1 , 0 );
235
+ auto top_x = non_zero.select (1 , 1 );
236
+ output = mixtral_moe_woq_kernel_stub (
237
+ kCPU ,
238
+ hidden_states,
239
+ top_x,
240
+ idx,
241
+ gate_ctx[i]->get_data_handle (),
242
+ up_ctx[i]->get_data_handle (),
243
+ down_ctx[i]->get_data_handle (),
244
+ routing_weights,
245
+ output,
246
+ is_distributed);
247
+ }
248
+ return output;
249
+ }
101
250
} // namespace cpu
102
251
} // namespace torch_ipex
103
252
@@ -112,17 +261,53 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
112
261
" mixtral_moe_tpp" ,
113
262
c10::DispatchKey::CPU,
114
263
torch_ipex::cpu::mixtral_moe_tpp);
264
+ m.def (
265
+ " deepseek_moe_tpp(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
266
+ Tensor[] up_wei, Tensor[] down_wei, bool tpp_fallback, Tensor routing_weights, \
267
+ Tensor output, bool is_distributed) -> Tensor" );
268
+ m.impl (
269
+ " deepseek_moe_tpp" ,
270
+ c10::DispatchKey::CPU,
271
+ torch_ipex::cpu::deepseek_moe_tpp);
115
272
m.def (
116
273
" mixtral_moe(Tensor hidden_states, Tensor top_x, Tensor idx, Tensor gate_wei, \
117
274
Tensor gate_op_ctx, Tensor up_wei, Tensor up_op_ctx, Tensor down_wei, \
118
275
Tensor down_op_ctx, bool use_dnnl, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor" );
119
276
m.impl (" mixtral_moe" , c10::DispatchKey::CPU, torch_ipex::cpu::mixtral_moe);
277
+ m.def (
278
+ " deepseek_moe(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
279
+ __torch__.torch.classes.ipex_prepack.LinearOpContext[] gate_op_ctx, Tensor[] up_wei, \
280
+ __torch__.torch.classes.ipex_prepack.LinearOpContext[] up_op_ctx, Tensor[] down_wei, \
281
+ __torch__.torch.classes.ipex_prepack.LinearOpContext[] down_op_ctx, Tensor routing_weights, \
282
+ Tensor output, bool is_distributed) -> Tensor" );
283
+ m.impl (" deepseek_moe" , c10::DispatchKey::CPU, torch_ipex::cpu::deepseek_moe);
284
+ m.def (
285
+ " deepseek_moe_mkl(Tensor hidden_states, Tensor expert_mask, Tensor[] gate_wei, \
286
+ __torch__.torch.classes.ipex_prepack.MKLOpContext[] gate_op_ctx, Tensor[] up_wei, \
287
+ __torch__.torch.classes.ipex_prepack.MKLOpContext[] up_op_ctx, \
288
+ Tensor[] down_wei, __torch__.torch.classes.ipex_prepack.MKLOpContext[] down_op_ctx, \
289
+ Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor" );
290
+ m.impl (
291
+ " deepseek_moe_mkl" ,
292
+ c10::DispatchKey::CPU,
293
+ torch_ipex::cpu::deepseek_moe_mkl);
120
294
m.def (
121
295
" mixtral_moe_woq(Tensor hidden_states, Tensor top_x, Tensor idx, Tensor gate_wei, \
122
296
Tensor up_wei, Tensor down_wei, Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor" );
123
297
m.impl (
124
298
" mixtral_moe_woq" ,
125
299
c10::DispatchKey::CPU,
126
300
torch_ipex::cpu::mixtral_moe_woq);
301
+ m.def (
302
+ " deepseek_moe_woq(Tensor hidden_states, Tensor expert_mask, \
303
+ __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] gate_ctx, \
304
+ __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] up_ctx, \
305
+ __torch__.torch.classes.ipex_prepack.WoqLinearOpContext[] down_ctx, \
306
+ Tensor routing_weights, Tensor output, bool is_distributed) -> Tensor" );
307
+
308
+ m.impl (
309
+ " deepseek_moe_woq" ,
310
+ c10::DispatchKey::CPU,
311
+ torch_ipex::cpu::deepseek_moe_woq);
127
312
}
128
313
} // namespace
0 commit comments