From 0990591b41fda3ac1f00a2e74eb51d8888ee7290 Mon Sep 17 00:00:00 2001 From: Oscar Benjamin Date: Mon, 2 Oct 2023 17:22:56 +0100 Subject: [PATCH] Fix pow(int, int, fmpz) --- src/flint/test/test.py | 4 +++ src/flint/types/fmpz.pyx | 57 ++++++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/flint/test/test.py b/src/flint/test/test.py index fb2746d7..6b809a48 100644 --- a/src/flint/test/test.py +++ b/src/flint/test/test.py @@ -136,12 +136,16 @@ def test_fmpz(): (2, 2, 3, 1), (2, -1, 5, 3), (2, 0, 5, 1), + (2, 5, 1000, 32), ] for a, b, c, ab_mod_c in pow_mod_examples: assert pow(a, b, c) == ab_mod_c assert pow(flint.fmpz(a), b, c) == ab_mod_c assert pow(a, flint.fmpz(b), c) == ab_mod_c + assert pow(a, b, flint.fmpz(c)) == ab_mod_c assert pow(flint.fmpz(a), flint.fmpz(b), c) == ab_mod_c + assert pow(flint.fmpz(a), b, flint.fmpz(c)) == ab_mod_c + assert pow(a, flint.fmpz(b), flint.fmpz(c)) == ab_mod_c assert pow(flint.fmpz(a), flint.fmpz(b), flint.fmpz(c)) == ab_mod_c assert raises(lambda: pow(flint.fmpz(2), 2, 0), ValueError) diff --git a/src/flint/types/fmpz.pyx b/src/flint/types/fmpz.pyx index 31659c4f..1d27bc70 100644 --- a/src/flint/types/fmpz.pyx +++ b/src/flint/types/fmpz.pyx @@ -360,36 +360,41 @@ cdef class fmpz(flint_scalar): return u def __pow__(s, t, m): + cdef fmpz_struct sval[1] cdef fmpz_struct tval[1] cdef fmpz_struct mval[1] + cdef int stype = FMPZ_UNKNOWN cdef int ttype = FMPZ_UNKNOWN cdef int mtype = FMPZ_UNKNOWN cdef int success u = NotImplemented - ttype = fmpz_set_any_ref(tval, t) - if ttype == FMPZ_UNKNOWN: - return NotImplemented - if m is None: - # fmpz_pow_fmpz throws if x is negative - if fmpz_sgn(tval) == -1: - if ttype == FMPZ_TMP: fmpz_clear(tval) - raise ValueError("negative exponent") + try: + stype = fmpz_set_any_ref(sval, s) + if stype == FMPZ_UNKNOWN: + return NotImplemented + ttype = fmpz_set_any_ref(tval, t) + if ttype == FMPZ_UNKNOWN: + return NotImplemented + if m is None: + # fmpz_pow_fmpz throws if x is negative + if fmpz_sgn(tval) == -1: + raise ValueError("negative exponent") - u = fmpz.__new__(fmpz) - success = fmpz_pow_fmpz((u).val, (s).val, tval) + u = fmpz.__new__(fmpz) + success = fmpz_pow_fmpz((u).val, (s).val, tval) - if not success: - if ttype == FMPZ_TMP: fmpz_clear(tval) - raise OverflowError("fmpz_pow_fmpz: exponent too large") - else: - # Modular exponentiation - mtype = fmpz_set_any_ref(mval, m) - if mtype != FMPZ_UNKNOWN: + if not success: + raise OverflowError("fmpz_pow_fmpz: exponent too large") + + return u + else: + # Modular exponentiation + mtype = fmpz_set_any_ref(mval, m) + if mtype == FMPZ_UNKNOWN: + return NotImplemented if fmpz_is_zero(mval): - if ttype == FMPZ_TMP: fmpz_clear(tval) - if mtype == FMPZ_TMP: fmpz_clear(mval) raise ValueError("pow(): modulus cannot be zero") # The Flint docs say that fmpz_powm will throw if m is zero @@ -397,16 +402,16 @@ cdef class fmpz(flint_scalar): # e.g. pow(2, 2, -3) == (2^2) % (-3) == -2. We could implement # that here as well but it is not clear how useful it is. if fmpz_sgn(mval) == -1: - if ttype == FMPZ_TMP: fmpz_clear(tval) - if mtype == FMPZ_TMP: fmpz_clear(mval) - raise ValueError("pow(): negative modulua not supported") + raise ValueError("pow(): negative modulus not supported") u = fmpz.__new__(fmpz) - fmpz_powm((u).val, (s).val, tval, mval) + fmpz_powm((u).val, sval, tval, mval) - if ttype == FMPZ_TMP: fmpz_clear(tval) - if mtype == FMPZ_TMP: fmpz_clear(mval) - return u + return u + finally: + if stype == FMPZ_TMP: fmpz_clear(sval) + if ttype == FMPZ_TMP: fmpz_clear(tval) + if mtype == FMPZ_TMP: fmpz_clear(mval) def __rpow__(s, t, m): t = any_as_fmpz(t)