diff --git a/src/useMergedRefs.ts b/src/useMergedRefs.ts index ec6be9b..c497a3d 100644 --- a/src/useMergedRefs.ts +++ b/src/useMergedRefs.ts @@ -1,22 +1,42 @@ -import { useMemo } from 'react' +import { RefCallback, useMemo, version } from 'react' + +const isReact19 = parseInt(version.split('.')[0]!, 10) >= 19 type CallbackRef = (ref: T | null) => void type Ref = React.MutableRefObject | CallbackRef function toFnRef(ref?: Ref | null) { - return !ref || typeof ref === 'function' - ? ref - : (value: T | null) => { - ref.current = value as T - } + if (!ref || typeof ref === 'function') { + return ref + } + + return (value: T | null) => { + ref.current = value as T + } +} + +function cleanUp(cleanup: unknown, ref: RefCallback | null | undefined) { + if (typeof cleanup === 'function') { + cleanup() + } else if (ref) { + ref(null) + } } export function mergeRefs(refA?: Ref | null, refB?: Ref | null) { - const a = toFnRef(refA) - const b = toFnRef(refB) + const refASetter = toFnRef(refA) + const refBSetter = toFnRef(refB) + return (value: T | null) => { - if (a) a(value) - if (b) b(value) + const cleanupA = refASetter?.(value) + const cleanupB = refBSetter?.(value) + + if (isReact19) { + return () => { + cleanUp(cleanupA, refASetter) + cleanUp(cleanupB, refBSetter) + } + } } } @@ -36,7 +56,10 @@ export function mergeRefs(refA?: Ref | null, refB?: Ref | null) { * @param refB A Callback or mutable Ref * @category refs */ -function useMergedRefs(refA?: Ref | null, refB?: Ref | null) { +function useMergedRefs( + refA?: Ref | null, + refB?: Ref | null, +) { return useMemo(() => mergeRefs(refA, refB), [refA, refB]) } diff --git a/test/useMergedRefs.test.tsx b/test/useMergedRefs.test.tsx index b1f6678..26d949c 100644 --- a/test/useMergedRefs.test.tsx +++ b/test/useMergedRefs.test.tsx @@ -5,7 +5,7 @@ import useMergedRefs from '../src/useMergedRefs.js' import { render } from '@testing-library/react' describe('useMergedRefs', () => { - it('should return a function that returns mount state', () => { + it('should work with forwardRef', () => { let innerRef: HTMLButtonElement const outerRef = React.createRef() @@ -18,9 +18,80 @@ describe('useMergedRefs', () => { return