@@ -182,73 +182,60 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
182
182
}
183
183
};
184
184
185
- class DifferentiableFIR : public torch ::autograd::Function<DifferentiableFIR> {
186
- public:
187
- static torch::Tensor forward (
188
- torch::autograd::AutogradContext* ctx,
189
- const torch::Tensor& waveform,
190
- const torch::Tensor& b_coeffs) {
191
- int64_t n_order = b_coeffs.size (1 );
192
- int64_t n_channel = b_coeffs.size (0 );
193
-
194
- namespace F = torch::nn::functional;
195
- auto b_coeff_flipped = b_coeffs.flip (1 ).contiguous ();
196
- auto padded_waveform =
197
- F::pad (waveform, F::PadFuncOptions ({n_order - 1 , 0 }));
198
-
199
- auto output = F::conv1d (
200
- padded_waveform,
201
- b_coeff_flipped.unsqueeze (1 ),
202
- F::Conv1dFuncOptions ().groups (n_channel));
203
-
204
- ctx->save_for_backward ({waveform, b_coeffs, output});
205
- return output;
206
- }
207
-
208
- static torch::autograd::tensor_list backward (
209
- torch::autograd::AutogradContext* ctx,
210
- torch::autograd::tensor_list grad_outputs) {
211
- auto saved = ctx->get_saved_variables ();
212
- auto x = saved[0 ];
213
- auto b_coeffs = saved[1 ];
214
- auto y = saved[2 ];
215
-
216
- int64_t n_batch = x.size (0 );
217
- int64_t n_channel = x.size (1 );
218
- int64_t n_order = b_coeffs.size (1 );
185
+ // FIR filter forward and backward functions (no autograd inheritance)
186
+ torch::Tensor fir_forward (
187
+ const torch::Tensor& waveform,
188
+ const torch::Tensor& b_coeffs) {
189
+ int64_t n_order = b_coeffs.size (1 );
190
+ int64_t n_channel = b_coeffs.size (0 );
219
191
220
- auto dx = torch::Tensor ();
221
- auto db = torch::Tensor ();
222
- auto dy = grad_outputs[0 ];
192
+ namespace F = torch::nn::functional;
193
+ auto b_coeff_flipped = b_coeffs.flip (1 ).contiguous ();
194
+ auto padded_waveform =
195
+ F::pad (waveform, F::PadFuncOptions ({n_order - 1 , 0 }));
223
196
224
- namespace F = torch::nn::functional;
197
+ auto output = F::conv1d (
198
+ padded_waveform,
199
+ b_coeff_flipped.unsqueeze (1 ),
200
+ F::Conv1dFuncOptions ().groups (n_channel));
225
201
226
- if (b_coeffs.requires_grad ()) {
227
- db = F::conv1d (
228
- F::pad (x, F::PadFuncOptions ({n_order - 1 , 0 }))
229
- .view ({1 , n_batch * n_channel, -1 }),
230
- dy.view ({n_batch * n_channel, 1 , -1 }),
231
- F::Conv1dFuncOptions ().groups (n_batch * n_channel))
232
- .view ({n_batch, n_channel, -1 })
233
- .sum (0 )
234
- .flip (1 );
235
- }
202
+ return output;
203
+ }
236
204
237
- if (x.requires_grad ()) {
238
- dx = F::conv1d (
239
- F::pad (dy, F::PadFuncOptions ({0 , n_order - 1 })),
240
- b_coeffs.unsqueeze (1 ),
241
- F::Conv1dFuncOptions ().groups (n_channel));
242
- }
205
+ std::tuple<torch::Tensor, torch::Tensor> fir_backward (
206
+ const torch::Tensor& grad_output,
207
+ const torch::Tensor& waveform,
208
+ const torch::Tensor& b_coeffs) {
209
+ int64_t n_batch = waveform.size (0 );
210
+ int64_t n_channel = waveform.size (1 );
211
+ int64_t n_order = b_coeffs.size (1 );
212
+
213
+ auto dx = torch::Tensor ();
214
+ auto db = torch::Tensor ();
215
+
216
+ namespace F = torch::nn::functional;
217
+
218
+ // Compute gradient w.r.t. b_coeffs
219
+ if (b_coeffs.requires_grad ()) {
220
+ db = F::conv1d (
221
+ F::pad (waveform, F::PadFuncOptions ({n_order - 1 , 0 }))
222
+ .view ({1 , n_batch * n_channel, -1 }),
223
+ grad_output.view ({n_batch * n_channel, 1 , -1 }),
224
+ F::Conv1dFuncOptions ().groups (n_batch * n_channel))
225
+ .view ({n_batch, n_channel, -1 })
226
+ .sum (0 )
227
+ .flip (1 );
228
+ }
243
229
244
- return {dx, db};
230
+ // Compute gradient w.r.t. waveform
231
+ if (waveform.requires_grad ()) {
232
+ dx = F::conv1d (
233
+ F::pad (grad_output, F::PadFuncOptions ({0 , n_order - 1 })),
234
+ b_coeffs.unsqueeze (1 ),
235
+ F::Conv1dFuncOptions ().groups (n_channel));
245
236
}
246
- };
247
237
248
- torch::Tensor differentiable_fir_apply (
249
- const torch::Tensor& waveform,
250
- const torch::Tensor& b_coeffs) {
251
- return DifferentiableFIR::apply (waveform, b_coeffs);
238
+ return std::make_tuple (dx, db);
252
239
}
253
240
254
241
torch::Tensor differentiable_iir_apply (
@@ -267,14 +254,15 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
267
254
268
255
// Operator schema registrations for the torchaudio namespace.
// NOTE(review): schema strings reconstructed from a garbled scrape —
// confirm exact spelling against the Python-side call sites.
TORCH_LIBRARY(torchaudio, m) {
  m.def(
      "torchaudio::_differentiable_iir_apply(Tensor waveform, Tensor a_coeffs_normalized) -> Tensor");
  m.def(
      "torchaudio::_fir_forward(Tensor waveform, Tensor b_coeffs) -> Tensor");
  m.def(
      "torchaudio::_fir_backward(Tensor grad_output, Tensor waveform, Tensor b_coeffs) -> (Tensor, Tensor)");
}
276
263
277
264
// Bind the CompositeImplicitAutograd kernels to the schemas above.
TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
  m.impl("torchaudio::_differentiable_iir_apply", differentiable_iir_apply);
  m.impl("torchaudio::_fir_forward", fir_forward);
  m.impl("torchaudio::_fir_backward", fir_backward);
}
0 commit comments