Skip to content

Commit e18793e

Browse files
committed
Remove autograd from FIR
1 parent c4bd42f commit e18793e

File tree

2 files changed

+83
-71
lines changed

2 files changed

+83
-71
lines changed

src/libtorchaudio/lfilter.cpp

Lines changed: 52 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -182,73 +182,60 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
182182
}
183183
};
184184

185-
class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
186-
public:
187-
static torch::Tensor forward(
188-
torch::autograd::AutogradContext* ctx,
189-
const torch::Tensor& waveform,
190-
const torch::Tensor& b_coeffs) {
191-
int64_t n_order = b_coeffs.size(1);
192-
int64_t n_channel = b_coeffs.size(0);
193-
194-
namespace F = torch::nn::functional;
195-
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
196-
auto padded_waveform =
197-
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
198-
199-
auto output = F::conv1d(
200-
padded_waveform,
201-
b_coeff_flipped.unsqueeze(1),
202-
F::Conv1dFuncOptions().groups(n_channel));
203-
204-
ctx->save_for_backward({waveform, b_coeffs, output});
205-
return output;
206-
}
207-
208-
static torch::autograd::tensor_list backward(
209-
torch::autograd::AutogradContext* ctx,
210-
torch::autograd::tensor_list grad_outputs) {
211-
auto saved = ctx->get_saved_variables();
212-
auto x = saved[0];
213-
auto b_coeffs = saved[1];
214-
auto y = saved[2];
215-
216-
int64_t n_batch = x.size(0);
217-
int64_t n_channel = x.size(1);
218-
int64_t n_order = b_coeffs.size(1);
185+
// FIR filter forward and backward functions (no autograd inheritance)
186+
torch::Tensor fir_forward(
187+
const torch::Tensor& waveform,
188+
const torch::Tensor& b_coeffs) {
189+
int64_t n_order = b_coeffs.size(1);
190+
int64_t n_channel = b_coeffs.size(0);
219191

220-
auto dx = torch::Tensor();
221-
auto db = torch::Tensor();
222-
auto dy = grad_outputs[0];
192+
namespace F = torch::nn::functional;
193+
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
194+
auto padded_waveform =
195+
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
223196

224-
namespace F = torch::nn::functional;
197+
auto output = F::conv1d(
198+
padded_waveform,
199+
b_coeff_flipped.unsqueeze(1),
200+
F::Conv1dFuncOptions().groups(n_channel));
225201

226-
if (b_coeffs.requires_grad()) {
227-
db = F::conv1d(
228-
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
229-
.view({1, n_batch * n_channel, -1}),
230-
dy.view({n_batch * n_channel, 1, -1}),
231-
F::Conv1dFuncOptions().groups(n_batch * n_channel))
232-
.view({n_batch, n_channel, -1})
233-
.sum(0)
234-
.flip(1);
235-
}
202+
return output;
203+
}
236204

237-
if (x.requires_grad()) {
238-
dx = F::conv1d(
239-
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
240-
b_coeffs.unsqueeze(1),
241-
F::Conv1dFuncOptions().groups(n_channel));
242-
}
205+
std::tuple<torch::Tensor, torch::Tensor> fir_backward(
206+
const torch::Tensor& grad_output,
207+
const torch::Tensor& waveform,
208+
const torch::Tensor& b_coeffs) {
209+
int64_t n_batch = waveform.size(0);
210+
int64_t n_channel = waveform.size(1);
211+
int64_t n_order = b_coeffs.size(1);
212+
213+
auto dx = torch::Tensor();
214+
auto db = torch::Tensor();
215+
216+
namespace F = torch::nn::functional;
217+
218+
// Compute gradient w.r.t. b_coeffs
219+
if (b_coeffs.requires_grad()) {
220+
db = F::conv1d(
221+
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}))
222+
.view({1, n_batch * n_channel, -1}),
223+
grad_output.view({n_batch * n_channel, 1, -1}),
224+
F::Conv1dFuncOptions().groups(n_batch * n_channel))
225+
.view({n_batch, n_channel, -1})
226+
.sum(0)
227+
.flip(1);
228+
}
243229

244-
return {dx, db};
230+
// Compute gradient w.r.t. waveform
231+
if (waveform.requires_grad()) {
232+
dx = F::conv1d(
233+
F::pad(grad_output, F::PadFuncOptions({0, n_order - 1})),
234+
b_coeffs.unsqueeze(1),
235+
F::Conv1dFuncOptions().groups(n_channel));
245236
}
246-
};
247237

248-
torch::Tensor differentiable_fir_apply(
249-
const torch::Tensor& waveform,
250-
const torch::Tensor& b_coeffs) {
251-
return DifferentiableFIR::apply(waveform, b_coeffs);
238+
return std::make_tuple(dx, db);
252239
}
253240

254241
torch::Tensor differentiable_iir_apply(
@@ -267,14 +254,15 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
267254

268255
// Operator schema declarations for the torchaudio extension. The FIR filter
// is exposed as a separate forward/backward pair so that autograd can be
// wired up on the Python side.
TORCH_LIBRARY(torchaudio, m) {
  m.def(
      "torchaudio::_differentiable_iir_apply(Tensor waveform, Tensor a_coeffs_normalized) -> Tensor");
  m.def(
      "torchaudio::_fir_forward(Tensor waveform, Tensor b_coeffs) -> Tensor");
  m.def(
      "torchaudio::_fir_backward(Tensor grad_output, Tensor waveform, Tensor b_coeffs) -> (Tensor, Tensor)");
}
276263

277264
// Kernel registrations under the CompositeImplicitAutograd key: each op is
// composed of ATen calls, so this single implementation serves all backends.
TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
  m.impl("torchaudio::_differentiable_iir_apply", differentiable_iir_apply);
  m.impl("torchaudio::_fir_forward", fir_forward);
  m.impl("torchaudio::_fir_backward", fir_backward);
}

src/torchaudio/functional/filtering.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -933,8 +933,9 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
933933

934934
if _IS_TORCHAUDIO_EXT_AVAILABLE:
    # The compiled extension is present: bind the C++ kernels directly.
    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
    _differentiable_iir_apply = torch.ops.torchaudio._differentiable_iir_apply
    _fir_forward = torch.ops.torchaudio._fir_forward
    _fir_backward = torch.ops.torchaudio._fir_backward
else:
    # No extension: fall back to the pure-Python loop implementation.
    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
940941

@@ -993,8 +994,26 @@ def _lfilter_core(
993994
return output
994995

995996

class _DifferentiableFIRFunction(torch.autograd.Function):
    """Autograd glue for the C++ FIR forward/backward kernels.

    The C++ side no longer subclasses ``torch::autograd::Function``; this
    Python class supplies the forward/backward pairing instead, delegating
    the actual math to the ``_fir_forward`` / ``_fir_backward`` ops.
    """

    @staticmethod
    def forward(ctx, waveform, b_coeffs):
        # Save the raw inputs; the C++ backward recomputes both gradients
        # from them plus the upstream gradient.
        ctx.save_for_backward(waveform, b_coeffs)
        return _fir_forward(waveform, b_coeffs)

    @staticmethod
    def backward(ctx, grad_output):
        waveform, b_coeffs = ctx.saved_tensors
        # The C++ kernel returns (grad_waveform, grad_b_coeffs), matching
        # the order of the forward inputs.
        grad_waveform, grad_b_coeffs = _fir_backward(grad_output, waveform, b_coeffs)
        return grad_waveform, grad_b_coeffs
def _lfilter_core_python(
9981017
waveform: Tensor,
9991018
a_coeffs: Tensor,
10001019
b_coeffs: Tensor,
@@ -1006,13 +1025,18 @@ def _lfilter_core_in_python_calling_into_cpp_FIR_and_IIR(
10061025
a0 = a_coeffs[:, 0:1] # Keep dimension for broadcasting
10071026
b_coeffs_normalized = b_coeffs / a0
10081027
a_coeffs_normalized = a_coeffs / a0
1009-
1010-
filtered_waveform = _differentiable_fir_apply(waveform, b_coeffs_normalized)
1011-
return _differentiable_iir_apply(filtered_waveform, a_coeffs_normalized)
1028+
1029+
# Apply FIR filter using Python autograd function
1030+
filtered_waveform = _DifferentiableFIRFunction.apply(waveform, b_coeffs_normalized)
1031+
1032+
# Apply IIR filter (still using C++ autograd)
1033+
output = _differentiable_iir_apply(filtered_waveform, a_coeffs_normalized)
1034+
1035+
return output
10121036

10131037

10141038
# Select the lfilter implementation: the Python-orchestrated C++ FIR/IIR
# path when the extension is available, the generic fallback otherwise.
if _IS_TORCHAUDIO_EXT_AVAILABLE:
    _lfilter = _lfilter_core_python
else:
    _lfilter = _lfilter_core
10181042

0 commit comments

Comments
 (0)