Skip to content

Commit a3fe94e

Browse files
samanklesariaSam AnklesariaNicolasHug
authored
Add missing deprecations (#3959)
Co-authored-by: Sam Anklesaria <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 70caf76 commit a3fe94e

File tree

7 files changed

+54
-20
lines changed

7 files changed

+54
-20
lines changed

src/torchaudio/functional/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
preemphasis,
5959
psd,
6060
resample,
61-
rnnt_loss as _rnnt_loss,
61+
rnnt_loss,
6262
rtf_evd,
6363
rtf_power,
6464
sliding_window_cmn,
@@ -67,7 +67,6 @@
6767
speed,
6868
)
6969

70-
rnnt_loss = dropping_support(_rnnt_loss)
7170
__all__ = [
7271
"amplitude_to_DB",
7372
"compute_deltas",

src/torchaudio/functional/filtering.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch import Tensor
77

88
from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
9+
from torchaudio._internal.module_utils import dropping_support
910

1011

1112
def _dB2Linear(x: float) -> float:
@@ -324,7 +325,7 @@ def biquad(waveform: Tensor, b0: float, b1: float, b2: float, a0: float, a1: flo
324325
a1 = torch.as_tensor(a1, dtype=dtype, device=device).view(1)
325326
a2 = torch.as_tensor(a2, dtype=dtype, device=device).view(1)
326327

327-
output_waveform = lfilter(
328+
output_waveform = _lfilter_deprecated(
328329
waveform,
329330
torch.cat([a0, a1, a2]),
330331
torch.cat([b0, b1, b2]),
@@ -698,8 +699,8 @@ def filtfilt(
698699
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
699700
are 2D Tensors, or `(..., time)` otherwise.
700701
"""
701-
forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
702-
backward_filtered = lfilter(
702+
forward_filtered = _lfilter_deprecated(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
703+
backward_filtered = _lfilter_deprecated(
703704
forward_filtered.flip(-1),
704705
a_coeffs,
705706
b_coeffs,
@@ -997,7 +998,7 @@ def _lfilter_core(
997998
_lfilter = _lfilter_core
998999

9991000

1000-
def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
1001+
def _lfilter_deprecated(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
10011002
r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
10021003
developed separately by *Yu et al.* :cite:`ismir_YuF23` and *Forgione et al.* :cite:`forgione2021dynonet`.
10031004
The gradients of ``a_coeffs`` are computed based on a faster algorithm from :cite:`ycy2024diffapf`.
@@ -1066,6 +1067,7 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =
10661067

10671068
return output
10681069

1070+
lfilter = dropping_support(_lfilter_deprecated)
10691071

10701072
def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
10711073
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
@@ -1115,7 +1117,7 @@ def _overdrive_core_loop_generic(
11151117
_overdrive_core_loop_cpu = _overdrive_core_loop_generic
11161118

11171119

1118-
def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
1120+
def _overdrive_deprecated(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
11191121
r"""Apply a overdrive effect to the audio. Similar to SoX implementation.
11201122
11211123
.. devices:: CPU CUDA
@@ -1170,6 +1172,7 @@ def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
11701172

11711173
return output_waveform.clamp(min=-1, max=1).view(actual_shape)
11721174

1175+
overdrive = dropping_support(_overdrive_deprecated)
11731176

11741177
def phaser(
11751178
waveform: Tensor,

src/torchaudio/functional/functional.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import torch
1010
import torchaudio
1111
from torch import Tensor
12-
from torchaudio._internal.module_utils import deprecated
12+
from torchaudio._internal.module_utils import deprecated, dropping_support
13+
1314

1415
from .filtering import highpass_biquad, treble_biquad
1516

@@ -1760,7 +1761,7 @@ def _fix_waveform_shape(
17601761
return waveform_shift
17611762

17621763

1763-
def rnnt_loss(
1764+
def _rnnt_loss(
17641765
logits: Tensor,
17651766
targets: Tensor,
17661767
logit_lengths: Tensor,
@@ -1864,6 +1865,9 @@ def psd(
18641865
psd = psd.sum(dim=-3)
18651866
return psd
18661867

1868+
# Expose both deprecated wrapper as well as original because torchscript breaks on
1869+
# wrapped functions.
1870+
rnnt_loss = dropping_support(_rnnt_loss)
18671871

18681872
def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
18691873
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
@@ -2472,7 +2476,7 @@ def preemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
24722476
return waveform
24732477

24742478

2475-
def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
2479+
def _deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
24762480
r"""De-emphasizes a waveform along its last dimension.
24772481
Inverse of :meth:`preemphasis`. Concretely, for each signal
24782482
:math:`x` in ``waveform``, computes output :math:`y` as
@@ -2494,8 +2498,9 @@ def deemphasis(waveform, coeff: float = 0.97) -> torch.Tensor:
24942498
"""
24952499
a_coeffs = torch.tensor([1.0, -coeff], dtype=waveform.dtype, device=waveform.device)
24962500
b_coeffs = torch.tensor([1.0, 0.0], dtype=waveform.dtype, device=waveform.device)
2497-
return torchaudio.functional.lfilter(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
2501+
return torchaudio.functional.filtering._lfilter_deprecated(waveform, a_coeffs=a_coeffs, b_coeffs=b_coeffs)
24982502

2503+
deemphasis = dropping_support(_deemphasis)
24992504

25002505
def frechet_distance(mu_x, sigma_x, mu_y, sigma_y):
25012506
r"""Computes the Fréchet distance between two multivariate normal distributions :cite:`dowson1982frechet`.

src/torchaudio/models/decoder/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from torchaudio._internal.module_utils import dropping_support
1+
from torchaudio._internal.module_utils import dropping_support, dropping_class_support
2+
import inspect
23
_CTC_DECODERS = [
34
"CTCHypothesis",
45
"CTCDecoder",
@@ -34,7 +35,11 @@ def __getattr__(name: str):
3435
"To use CUCTC decoder, please set BUILD_CUDA_CTC_DECODER=1 when building from source."
3536
) from err
3637

37-
item = dropping_support(getattr(_cuda_ctc_decoder, name))
38+
orig_item = getattr(_cuda_ctc_decoder, name)
39+
if inspect.isclass(orig_item):
40+
item = dropping_class_support(orig_item)
41+
else:
42+
item = dropping_support(orig_item)
3843
globals()[name] = item
3944
return item
4045
raise AttributeError(f"module {__name__} has no attribute {name}")

src/torchaudio/transforms/_transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from torch.nn.parameter import UninitializedParameter
1111

1212
from torchaudio import functional as F
13-
from torchaudio.functional.functional import rnnt_loss
13+
from torchaudio.functional.functional import _rnnt_loss
1414
from torchaudio.functional.functional import (
1515
_apply_sinc_resample_kernel,
1616
_check_convolve_mode,
@@ -1847,7 +1847,7 @@ def forward(
18471847
Tensor: Loss with the reduction option applied. If ``reduction`` is ``"none"``, then size (batch),
18481848
otherwise scalar.
18491849
"""
1850-
return rnnt_loss(
1850+
return _rnnt_loss(
18511851
logits,
18521852
targets,
18531853
logit_lengths,
@@ -2135,4 +2135,4 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
21352135
Returns:
21362136
torch.Tensor: De-emphasized waveform, with shape `(..., N)`.
21372137
"""
2138-
return F.deemphasis(waveform, coeff=self.coeff)
2138+
return F.functional._deemphasis(waveform, coeff=self.coeff)

test/torchaudio_unittest/common_utils/func_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
def torch_script(obj):
77
"""TorchScript the given function or Module"""
88
buffer = io.BytesIO()
9+
if hasattr(obj, '__wrapped__'):
10+
# This is hack for those functions which are deprecated with decorators
11+
# like @deprecated or @dropping_support. Adding the decorators breaks
12+
# TorchScript. We need to unwrap the function to get the original one,
13+
# which make the tests pass, but that's a lie: the public (deprecated)
14+
# function doesn't support torchscript anymore
15+
obj = obj.__wrapped__
916
torch.jit.save(torch.jit.script(obj), buffer)
1017
buffer.seek(0)
1118
return torch.jit.load(buffer)

test/torchaudio_unittest/functional/torchscript_consistency_impl.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,12 @@ def test_lfilter(self):
285285
device=waveform.device,
286286
dtype=waveform.dtype,
287287
)
288-
self._assert_consistency(F.lfilter, (waveform, a_coeffs, b_coeffs, True, True))
288+
# This is hack for those functions which are deprecated with decorators
289+
# like @deprecated or @dropping_support. Adding the decorators breaks
290+
# TorchScript. So here we use the private function which make the tests
291+
# pass, but that's a lie: the public (deprecated) function doesn't
292+
# support torchscript anymore
293+
self._assert_consistency(F.filtering._lfilter_deprecated, (waveform, a_coeffs, b_coeffs, True, True))
289294

290295
def test_filtfilt(self):
291296
waveform = common_utils.get_whitenoise(sample_rate=8000)
@@ -485,7 +490,7 @@ def test_perf_biquad_filtering(self):
485490
def func(tensor):
486491
a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
487492
b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
488-
return F.lfilter(tensor, a, b)
493+
return F.filtering._lfilter_deprecated(tensor, a, b)
489494

490495
self._assert_consistency(func, (waveform,))
491496

@@ -530,7 +535,12 @@ def test_overdrive(self):
530535
def func(tensor):
531536
gain = 30.0
532537
colour = 50.0
533-
return F.overdrive(tensor, gain, colour)
538+
# This is hack for those functions which are deprecated with decorators
539+
# like @deprecated or @dropping_support. Adding the decorators breaks
540+
# TorchScript. So here we use the private function which make the tests
541+
# pass, but that's a lie: the public (deprecated) function doesn't
542+
# support torchscript anymore
543+
return F.filtering._overdrive_deprecated(tensor, gain, colour)
534544

535545
self._assert_consistency(func, (waveform,))
536546

@@ -803,7 +813,12 @@ def func(tensor):
803813
targets = torch.tensor([[1, 2]], device=tensor.device, dtype=torch.int32)
804814
logit_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
805815
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
806-
return rnnt_loss(tensor, targets, logit_lengths, target_lengths)
816+
# This is hack for those functions which are deprecated with decorators
817+
# like @deprecated or @dropping_support. Adding the decorators breaks
818+
# TorchScript. So here we use the private function which make the tests
819+
# pass, but that's a lie: the public (deprecated) function doesn't
820+
# support torchscript anymore
821+
return F.functional._rnnt_loss(tensor, targets, logit_lengths, target_lengths)
807822

808823
logits = torch.tensor(
809824
[

0 commit comments

Comments
 (0)