Skip to content

Commit dc14d61

Browse files
authored
Allow pass-through keyword arguments in Python callbacks (#73)
1 parent 91d88e9 commit dc14d61

File tree

4 files changed

+88
-17
lines changed

4 files changed

+88
-17
lines changed

config/ci/python-env.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ dependencies:
1212
- numpy
1313
- meson
1414
- ninja
15+
- typing-extensions

python/minpack/library.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@ class UserData:
2828
disturbing foreign runtime.
2929
"""
3030

31-
def __init__(self, fcn):
31+
def __init__(self, fcn, **kwargs):
3232
self.fcn = fcn
3333
self.exception = None
34+
self.kwargs = kwargs
3435

3536

3637
@ffi.def_extern()
@@ -46,6 +47,7 @@ def func(n, x, fvec, iflag, data) -> None:
4647
handle.fcn(
4748
np.frombuffer(ffi.buffer(x, n * real.itemsize), dtype=real),
4849
np.frombuffer(ffi.buffer(fvec, n * real.itemsize), dtype=real),
50+
**handle.kwargs,
4951
)
5052
except BaseException as e:
5153
iflag[0] = -1
@@ -68,6 +70,7 @@ def fcn_hybrj(n, x, fvec, fjac, ldfjac, iflag, data) -> None:
6870
np.frombuffer(ffi.buffer(fvec, n * real.itemsize), dtype=real),
6971
np.reshape(fjac, (n, ldfjac)),
7072
iflag[0] == 2,
73+
**handle.kwargs,
7174
)
7275
except BaseException as e:
7376
iflag[0] = -1
@@ -90,6 +93,7 @@ def fcn_lmder(m, n, x, fvec, fjac, ldfjac, iflag, data) -> None:
9093
np.frombuffer(ffi.buffer(fvec, m * real.itemsize), dtype=real),
9194
np.reshape(fjac, (n, ldfjac)),
9295
iflag[0] == 2,
96+
**handle.kwargs,
9397
)
9498
except BaseException as e:
9599
iflag[0] = -1
@@ -109,6 +113,7 @@ def func2(m, n, x, fvec, iflag, data) -> None:
109113
handle.fcn(
110114
np.frombuffer(ffi.buffer(x, n * real.itemsize), dtype=real),
111115
np.frombuffer(ffi.buffer(fvec, m * real.itemsize), dtype=real),
116+
**handle.kwargs,
112117
)
113118
except BaseException as e:
114119
iflag[0] = -1
@@ -130,6 +135,7 @@ def fcn_lmstr(m, n, x, fvec, fjrow, iflag, data) -> None:
130135
np.frombuffer(ffi.buffer(fvec, m * real.itemsize), dtype=real),
131136
np.frombuffer(ffi.buffer(fjrow, n * real.itemsize), dtype=real),
132137
iflag[0] - 2 if iflag[0] > 1 else None,
138+
**handle.kwargs,
133139
)
134140
except BaseException as e:
135141
iflag[0] = -1
@@ -162,8 +168,8 @@ def cffi_callback(func, callback):
162168
"""
163169

164170
@functools.wraps(func)
165-
def entry_point(fcn, *args):
166-
data = UserData(fcn)
171+
def entry_point(fcn, *args, **kwargs):
172+
data = UserData(fcn, **kwargs)
167173
handle = ffi.new_handle(data)
168174
func(callback, *args, handle)
169175
if data.exception is not None:
@@ -191,6 +197,7 @@ def hybrd1(
191197
x: np.ndarray,
192198
fvec: np.ndarray,
193199
tol: float = math.sqrt(np.finfo(real).eps),
200+
**kwargs,
194201
) -> int:
195202
"""
196203
Find a zero of a system of n nonlinear functions in n variables
@@ -262,6 +269,7 @@ def hybrd1(
262269
info,
263270
ffi.cast("double*", wa.ctypes.data),
264271
lwa,
272+
**kwargs,
265273
)
266274
ex = info_hy(info[0])
267275
if ex is not None:
@@ -286,6 +294,7 @@ def hybrd(
286294
fjac: Optional[np.ndarray] = None,
287295
r: Optional[np.ndarray] = None,
288296
qtf: Optional[np.ndarray] = None,
297+
**kwargs,
289298
) -> int:
290299
"""
291300
Find a zero of a system of n nonlinear functions in n variables
@@ -353,6 +362,7 @@ def hybrd(
353362
ffi.cast("double*", wa2.ctypes.data),
354363
ffi.cast("double*", wa3.ctypes.data),
355364
ffi.cast("double*", wa4.ctypes.data),
365+
**kwargs,
356366
)
357367
ex = info_hy(info[0])
358368
if ex is not None:
@@ -366,6 +376,7 @@ def hybrj1(
366376
fvec: np.ndarray,
367377
fjac: np.ndarray,
368378
tol: float = math.sqrt(np.finfo(real).eps),
379+
**kwargs,
369380
) -> int:
370381
"""
371382
Find a zero of a system of n nonlinear functions in n variables
@@ -403,6 +414,7 @@ def hybrj1(
403414
info,
404415
ffi.cast("double*", wa.ctypes.data),
405416
lwa,
417+
**kwargs,
406418
)
407419
ex = info_hy(info[0])
408420
if ex is not None:
@@ -424,6 +436,7 @@ def hybrj(
424436
nprint: int = 0,
425437
r: Optional[np.ndarray] = None,
426438
qtf: Optional[np.ndarray] = None,
439+
**kwargs,
427440
) -> int:
428441
"""
429442
Find a zero of a system of n nonlinear functions in n variables
@@ -486,6 +499,7 @@ def hybrj(
486499
ffi.cast("double*", wa2.ctypes.data),
487500
ffi.cast("double*", wa3.ctypes.data),
488501
ffi.cast("double*", wa4.ctypes.data),
502+
**kwargs,
489503
)
490504
ex = info_hy(info[0])
491505
if ex is not None:
@@ -499,6 +513,7 @@ def lmder1(
499513
fvec: np.ndarray,
500514
fjac: np.ndarray,
501515
tol: float = math.sqrt(np.finfo(real).eps),
516+
**kwargs,
502517
) -> int:
503518
"""
504519
Minimize the sum of the squares of m nonlinear functions in n variables
@@ -539,6 +554,7 @@ def lmder1(
539554
ffi.cast("int*", ipvt.ctypes.data),
540555
ffi.cast("double*", wa.ctypes.data),
541556
lwa,
557+
**kwargs,
542558
)
543559
ex = info_lm(info[0])
544560
if ex is not None:
@@ -562,6 +578,7 @@ def lmder(
562578
nprint=0,
563579
ipvt: Optional[np.ndarray] = None,
564580
qtf: Optional[np.ndarray] = None,
581+
**kwargs,
565582
) -> int:
566583
"""
567584
Minimize the sum of the squares of m nonlinear functions in n variables
@@ -624,6 +641,7 @@ def lmder(
624641
ffi.cast("double*", wa2.ctypes.data),
625642
ffi.cast("double*", wa3.ctypes.data),
626643
ffi.cast("double*", wa4.ctypes.data),
644+
**kwargs,
627645
)
628646
ex = info_lm(info[0])
629647
if ex is not None:
@@ -636,6 +654,7 @@ def lmdif1(
636654
x: np.ndarray,
637655
fvec: np.ndarray,
638656
tol: float = math.sqrt(np.finfo(real).eps),
657+
**kwargs,
639658
) -> int:
640659
"""
641660
Minimize the sum of the squares of m nonlinear functions in n variables
@@ -709,6 +728,7 @@ def lmdif1(
709728
ffi.cast("int*", ipvt.ctypes.data),
710729
ffi.cast("double*", wa.ctypes.data),
711730
lwa,
731+
**kwargs,
712732
)
713733
ex = info_lm(info[0])
714734
if ex is not None:
@@ -733,6 +753,7 @@ def lmdif(
733753
fjac: Optional[np.ndarray] = None,
734754
ipvt: Optional[np.ndarray] = None,
735755
qtf: Optional[np.ndarray] = None,
756+
**kwargs,
736757
) -> int:
737758
"""
738759
Minimize the sum of the squares of m nonlinear functions in n variables
@@ -797,6 +818,7 @@ def lmdif(
797818
ffi.cast("double*", wa2.ctypes.data),
798819
ffi.cast("double*", wa3.ctypes.data),
799820
ffi.cast("double*", wa4.ctypes.data),
821+
**kwargs,
800822
)
801823
ex = info_lm(info[0])
802824
if ex is not None:
@@ -810,6 +832,7 @@ def lmstr1(
810832
fvec: np.ndarray,
811833
fjac: np.ndarray,
812834
tol: float = math.sqrt(np.finfo(real).eps),
835+
**kwargs,
813836
) -> int:
814837
"""
815838
Minimize the sum of the squares of m nonlinear functions in n variables by
@@ -851,6 +874,7 @@ def lmstr1(
851874
ffi.cast("int*", ipvt.ctypes.data),
852875
ffi.cast("double*", wa.ctypes.data),
853876
lwa,
877+
**kwargs,
854878
)
855879
ex = info_lm(info[0])
856880
if ex is not None:
@@ -874,6 +898,7 @@ def lmstr(
874898
nprint=0,
875899
ipvt: Optional[np.ndarray] = None,
876900
qtf: Optional[np.ndarray] = None,
901+
**kwargs,
877902
) -> int:
878903
"""
879904
Minimize the sum of the squares of m nonlinear functions in n variables by
@@ -937,6 +962,7 @@ def lmstr(
937962
ffi.cast("double*", wa2.ctypes.data),
938963
ffi.cast("double*", wa3.ctypes.data),
939964
ffi.cast("double*", wa4.ctypes.data),
965+
**kwargs,
940966
)
941967
ex = info_lm(info[0])
942968
if ex is not None:

python/minpack/test_library.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_lmder(driver):
127127
]
128128
)
129129

130-
def fcn(x, fvec, fjac, jacobian: bool) -> None:
130+
def fcn(x, fvec, fjac, jacobian: bool, y) -> None:
131131
if jacobian:
132132
for i in range(fvec.size):
133133
tmp1, tmp2 = i + 1, 16 - i - 1
@@ -151,14 +151,14 @@ def fcn(x, fvec, fjac, jacobian: bool) -> None:
151151
fvecp = np.zeros(15, dtype=np.float64)
152152
err = np.zeros(15, dtype=np.float64)
153153
minpack.library.chkder(x, fvec, fjac, xp, fvecp, False, err)
154-
fcn(x, fvec, fjac, False)
155-
fcn(x, fvec, fjac, True)
156-
fcn(xp, fvecp, fjac, False)
154+
fcn(x, fvec, fjac, False, y=y)
155+
fcn(x, fvec, fjac, True, y=y)
156+
fcn(xp, fvecp, fjac, False, y=y)
157157
minpack.library.chkder(x, fvec, fjac, xp, fvecp, True, err)
158158

159159
assert pytest.approx(err) == 15 * [1.0]
160160

161-
assert driver(fcn, x, fvec, fjac, tol) == 1
161+
assert driver(fcn, x, fvec, fjac, tol, y=y) == 1
162162

163163
assert pytest.approx(x, abs=100 * tol) == [0.8241058e-1, 0.1133037e1, 0.2343695e1]
164164

@@ -202,7 +202,7 @@ def test_lmdif(driver):
202202
]
203203
)
204204

205-
def fcn(x, fvec) -> None:
205+
def fcn(x, fvec, y) -> None:
206206
for i in range(fvec.size):
207207
tmp1, tmp2 = i + 1, 16 - i - 1
208208
tmp3 = tmp2 if i >= 8 else tmp1
@@ -213,7 +213,7 @@ def fcn(x, fvec) -> None:
213213
fjac = np.zeros((3, 15), dtype=np.float64)
214214
tol = sqrt(np.finfo(np.float64).eps)
215215

216-
assert driver(fcn, x, fvec, tol) == 1
216+
assert driver(fcn, x, fvec, tol, y=y) == 1
217217

218218
assert pytest.approx(x, abs=100 * tol) == [0.8241058e-1, 0.1133037e1, 0.2343695e1]
219219

@@ -260,7 +260,7 @@ def test_lmstr_exception(driver):
260260
class DummyException(Exception):
261261
...
262262

263-
def fcn(x, fvec, fjac, row) -> None:
263+
def fcn(x, fvec, fjrow, row) -> None:
264264
raise DummyException()
265265

266266
x = np.array([-1.2, 1.0])

python/minpack/typing.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,54 @@
44
"""
55

66
import numpy as np
7-
from typing import Optional, Callable
7+
from typing import Optional
8+
try:
9+
from typing import Protocol
10+
except ImportError:
11+
from typing_extensions import Protocol
812

9-
CallableHybrd = Callable[[np.ndarray, np.ndarray], None]
10-
CallableHybrj = Callable[[np.ndarray, np.ndarray, np.ndarray, bool], None]
11-
CallableLmder = Callable[[np.ndarray, np.ndarray, np.ndarray, bool], None]
12-
CallableLmdif = Callable[[np.ndarray, np.ndarray], None]
13-
CallableLmstr = Callable[[np.ndarray, np.ndarray, np.ndarray, Optional[int]], None]
13+
14+
class CallableHybrd(Protocol):
15+
def __call__(self, x: np.ndarray, fvec: np.ndarray, **kwargs) -> None:
16+
...
17+
18+
19+
class CallableHybrj(Protocol):
20+
def __call__(
21+
self,
22+
x: np.ndarray,
23+
fvec: np.ndarray,
24+
fjac: np.ndarray,
25+
jacobian: bool,
26+
**kwargs
27+
) -> None:
28+
...
29+
30+
31+
class CallableLmder(Protocol):
32+
def __call__(
33+
self,
34+
x: np.ndarray,
35+
fvec: np.ndarray,
36+
fjac: np.ndarray,
37+
jacobian: bool,
38+
**kwargs
39+
) -> None:
40+
...
41+
42+
43+
class CallableLmdif(Protocol):
44+
def __call__(self, x: np.ndarray, fvec: np.ndarray, **kwargs) -> None:
45+
...
46+
47+
48+
class CallableLmstr(Protocol):
49+
def __call__(
50+
self,
51+
x: np.ndarray,
52+
fvec: np.ndarray,
53+
fjrow: np.ndarray,
54+
row: Optional[int],
55+
**kwargs
56+
) -> None:
57+
...

0 commit comments

Comments
 (0)