From 38d5a0049e94c3cc7707595572c93eb6562b8f39 Mon Sep 17 00:00:00 2001
From: Jim Kitchen
Date: Fri, 24 Feb 2023 15:42:37 -0600
Subject: [PATCH] Data types no longer need to match

Conversion will happen automatically and should be efficient, avoiding
creating intermediate arrays whenever possible.
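
For example (a sketch mirroring the updated tests; x and y are FP64
vectors from the test fixtures), an output of a different dtype is now
accepted and the values are cast inside the generated kernel:

    z = Vector.new(INT32, x.size())
    operations.ewise_add(z, BinaryOp.plus, x, y)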
---
 mlir_graphblas/implementations.py       | 399 ++++++++++---------
 mlir_graphblas/operations.py            | 108 +----
 mlir_graphblas/operators.py             | 124 +++---
 mlir_graphblas/tensor.py                |   3 +-
 mlir_graphblas/tests/test_operations.py | 500 +++++++++++++-----------
 mlir_graphblas/tests/test_types.py      |  17 +-
 mlir_graphblas/tests/utils.py           |   2 +
 mlir_graphblas/types.py                 |  37 ++
 8 files changed, 614 insertions(+), 576 deletions(-)

diff --git a/mlir_graphblas/implementations.py b/mlir_graphblas/implementations.py
index 4d05296..49e45c4 100644
--- a/mlir_graphblas/implementations.py
+++ b/mlir_graphblas/implementations.py
@@ -18,7 +18,7 @@ from . descriptor import Descriptor, NULL as NULL_DESC
 from .utils import (get_sparse_output_pointer,
                     get_scalar_output_pointer,
                     get_scalar_input_arg,
                     pick_and_renumber_indices, determine_sparsity)
-from .types import RankedTensorType, BOOL, INT64, FP64
+from .types import DType, RankedTensorType, BOOL, INT64, FP64, cast
 from .exceptions import GrbError, GrbIndexOutOfBounds, GrbDimensionMismatch
@@ -50,7 +50,7 @@ def select_by_mask(sp: SparseTensorBase, mask: SparseTensor, desc: Descriptor =
     # Convert value mask to structural mask
     if not desc.mask_structure:
         zero = Scalar.new(mask.dtype, 0)
-        mask = select(SelectOp.valuene, mask, thunk=zero)
+        mask = select(mask.dtype, SelectOp.valuene, mask, thunk=zero)
 
     # Build and compile if needed
     key = ('select_by_mask', *sp.get_loop_key(), *mask.get_loop_key(), desc.mask_complement)
@@ -134,7 +134,7 @@ def select_by_indices(sp: SparseTensorBase,
         if row_indices is None:
             if complement:
                 return Vector.new(sp.dtype, *sp.shape)
-            return dup(sp)
+            return dup(sp.dtype, sp)
 
         idx, vals = sp.extract_tuples()
         row_indices = np.array(row_indices, dtype=np.uint64)
@@ -147,7 +147,7 @@ def select_by_indices(sp: SparseTensorBase,
         if row_indices is None and col_indices is None:
             if complement:
                 return Matrix.new(sp.dtype, *sp.shape)
-            return dup(sp)
+            return dup(sp.dtype, sp)
 
         rowidx, colidx, vals = sp.extract_tuples()
         if row_indices is not None:
@@ -241,34 +241,39 @@ def main(x):
     return compile(module)
 
 
-def dup(sp: SparseTensorBase, intermediate: bool = True):
+def dup(out_type: DType, sp: SparseTensorBase, intermediate: bool = True):
     if sp._obj is None:
-        return sp.baseclass(sp.dtype, sp.shape, intermediate_result=intermediate)
+        return sp.baseclass(out_type, sp.shape, intermediate_result=intermediate)
+
+    if sp.ndims == 0:  # Scalar
+        return Scalar.new(out_type, sp._obj)
 
     # Build and compile if needed
-    key = ('dup', *sp.get_loop_key())
+    key = ('dup', out_type, *sp.get_loop_key())
     if key not in engine_cache:
-        engine_cache[key] = _build_dup(sp)
+        engine_cache[key] = _build_dup(out_type, sp)
 
     # Call the compiled function
     mem_out = get_sparse_output_pointer()
     arg_pointers = [sp._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return sp.baseclass(sp.dtype, sp.shape, mem_out, sp._sparsity,
+    return sp.baseclass(out_type, sp.shape, mem_out, sp._sparsity,
                         sp.perceived_ordering, intermediate_result=intermediate)
 
 
-def _build_dup(sp: SparseTensorBase):
+def _build_dup(out_type: DType, sp: SparseTensorBase):
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             rank = sp.ndims
             index = ir.IndexType.get()
             dtype = sp.dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm = ir.AffineMap.get_permutation(sp.permutation)
             perm_out = ir.AffineMap.get_permutation(range(rank))
             rtt = sp.rtt.as_mlir_type()
-            rtt_out = sp.rtt.copy(ordering=sp.perceived_ordering).as_mlir_type()
+            rtt_out = sp.rtt.copy(dtype=out_type,
+                                  ordering=sp.perceived_ordering).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt)
             def main(x):
@@ -282,10 +287,11 @@ def main(x):
                     ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (perm, perm_out)]),
                     ir.ArrayAttr.get([ir.Attribute.parse('#linalg.iterator_type<parallel>')]*rank)
                 )
-                block = generic_op.regions[0].blocks.append(dtype, dtype)
+                block = generic_op.regions[0].blocks.append(dtype, dtype_out)
                 with ir.InsertionPoint(block):
                     a, _ = block.arguments
-                    linalg.YieldOp([a])
+                    result = cast(a, sp.dtype, out_type)
+                    linalg.YieldOp([result])
                 return generic_op.result
             main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -333,69 +339,77 @@ def main(x):
     return compile(module)
 
 
-def _build_scalar_binop(op: BinaryOp, left: Scalar, right: Scalar):
-    # Both scalars are present
+def _build_scalar_binop(out_type: DType, op: BinaryOp, left: Scalar, right: Scalar):
+    # Both scalars are non-empty
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
-            dtype = left.dtype.build_mlir_type()
+            dtype_left = left.dtype.build_mlir_type()
+            dtype_right = right.dtype.build_mlir_type()
 
-            @func.FuncOp.from_py_func(dtype, dtype)
+            @func.FuncOp.from_py_func(dtype_left, dtype_right)
             def main(x, y):
-                result = op(x, y)
+                result = op(out_type, x, y)
                 return result
             main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
 
     return compile(module)
 
 
-def ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
+def ewise_add(out_type: DType, op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
     assert left.ndims == right.ndims
-    assert left.dtype == right.dtype
 
     if left._obj is None:
-        return right
+        if right.dtype == out_type:
+            return right
+        return dup(out_type, right)
     if right._obj is None:
-        return left
+        if left.dtype == out_type:
+            return left
+        return dup(out_type, left)
 
     rank = left.ndims
     if rank == 0:  # Scalar
-        key = ('scalar_binop', op.name, left.dtype, right.dtype)
+        key = ('scalar_binop', op.name, out_type, left.dtype, right.dtype)
         if key not in engine_cache:
-            engine_cache[key] = _build_scalar_binop(op, left, right)
+            engine_cache[key] = _build_scalar_binop(out_type, op, left, right)
-        mem_out = get_scalar_output_pointer(left.dtype)
+        mem_out = get_scalar_output_pointer(out_type)
         arg_pointers = [get_scalar_input_arg(left), get_scalar_input_arg(right), mem_out]
         engine_cache[key].invoke('main', *arg_pointers)
-        return Scalar(left.dtype, (), left.dtype.np_type(mem_out.contents.value))
+        return Scalar(out_type, (), out_type.np_type(mem_out.contents.value))
 
     # Build and compile if needed
-    key = ('ewise_add', op.name, *left.get_loop_key(), *right.get_loop_key())
+    key = ('ewise_add', op.name, out_type, *left.get_loop_key(), *right.get_loop_key())
     if key not in engine_cache:
-        engine_cache[key] = _build_ewise_add(op, left, right)
+        engine_cache[key] = _build_ewise_add(out_type, op, left, right)
 
     # Call the compiled function
     mem_out = get_sparse_output_pointer()
     arg_pointers = [left._obj, right._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return left.baseclass(op.get_output_type(left.dtype, right.dtype), left.shape, mem_out,
+    return left.baseclass(out_type, left.shape, mem_out,
                           determine_sparsity(left, right, union=True),
                           left.perceived_ordering, intermediate_result=True)
 
 
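+# The builder below casts each input value to out_type inside the kernel, so
+# mismatched input dtypes never materialize an intermediate converted tensor.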
-def _build_ewise_add(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
+def _build_ewise_add(out_type: DType, op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             rank = left.ndims
             index = ir.IndexType.get()
-            dtype = left.dtype.build_mlir_type()
+            dtype_left = left.dtype.build_mlir_type()
+            dtype_right = right.dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm_left = ir.AffineMap.get_permutation(left.permutation)
             perm_right = ir.AffineMap.get_permutation(right.permutation)
             perm_out = ir.AffineMap.get_permutation(range(rank))
             rtt_left = left.rtt.as_mlir_type()
             rtt_right = right.rtt.as_mlir_type()
-            rtt_out = left.rtt.copy(ordering=left.perceived_ordering,
-                                    sparsity=determine_sparsity(left, right, union=True)).as_mlir_type()
+            rtt_out = RankedTensorType(dtype=out_type,
+                                       sparsity=determine_sparsity(left, right, union=True),
+                                       ordering=left.perceived_ordering).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt_left, rtt_right)
             def main(x, y):
@@ -409,15 +423,25 @@ def main(x, y):
                     ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (perm_left, perm_right, perm_out)]),
                     ir.ArrayAttr.get([ir.Attribute.parse('#linalg.iterator_type<parallel>')]*rank)
                 )
-                block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
+                block = generic_op.regions[0].blocks.append(dtype_left, dtype_right, dtype_out)
                 with ir.InsertionPoint(block):
                     a, b, o = block.arguments
-                    res = sparse_tensor.BinaryOp(dtype, a, b, left_identity=True, right_identity=True)
-                    overlap = res.regions[0].blocks.append(dtype, dtype)
+                    res = sparse_tensor.BinaryOp(dtype_out, a, b)
+                    overlap = res.regions[0].blocks.append(dtype_left, dtype_right)
                     with ir.InsertionPoint(overlap):
                         arg0, arg1 = overlap.arguments
-                        overlap_res = op(arg0, arg1)
+                        overlap_res = op(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=overlap_res)
+                    left_region = res.regions[1].blocks.append(dtype_left)
+                    with ir.InsertionPoint(left_region):
+                        arg0, = left_region.arguments
+                        left_res = cast(arg0, left.dtype, out_type)
+                        sparse_tensor.YieldOp(result=left_res)
+                    right_region = res.regions[2].blocks.append(dtype_right)
+                    with ir.InsertionPoint(right_region):
+                        arg0, = right_region.arguments
+                        right_res = cast(arg0, right.dtype, out_type)
+                        sparse_tensor.YieldOp(result=right_res)
                     linalg.YieldOp([res])
                 return generic_op.result
             main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -425,53 +449,51 @@ def main(x, y):
     return compile(module)
 
 
-def ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
+def ewise_mult(out_type: DType, op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
     assert left.ndims == right.ndims
-    assert left.dtype == right.dtype
-    output_dtype = op.get_output_type(left.dtype, right.dtype)
 
     if left._obj is None or right._obj is None:
-        return left.baseclass(output_dtype, left.shape)
+        return left.baseclass(out_type, left.shape)
 
     rank = left.ndims
     if rank == 0:  # Scalar
-        key = ('scalar_binop', op.name, left.dtype, right.dtype)
+        key = ('scalar_binop', op.name, out_type, left.dtype, right.dtype)
         if key not in engine_cache:
-            engine_cache[key] = _build_scalar_binop(op, left, right)
-        mem_out = get_scalar_output_pointer(output_dtype)
+            engine_cache[key] = _build_scalar_binop(out_type, op, left, right)
+        mem_out = get_scalar_output_pointer(out_type)
         arg_pointers = [get_scalar_input_arg(left), get_scalar_input_arg(right), mem_out]
         engine_cache[key].invoke('main', *arg_pointers)
-        return Scalar(output_dtype, (), output_dtype.np_type(mem_out.contents.value))
+        return Scalar(out_type, (), out_type.np_type(mem_out.contents.value))
 
     # Build and compile if needed
-    key = ('ewise_mult', op.name, *left.get_loop_key(), *right.get_loop_key())
+    key = ('ewise_mult', op.name, out_type, *left.get_loop_key(), *right.get_loop_key())
     if key not in engine_cache:
-        engine_cache[key] = _build_ewise_mult(op, left, right)
+        engine_cache[key] = _build_ewise_mult(out_type, op, left, right)
 
     # Call the compiled function
     mem_out = get_sparse_output_pointer()
     arg_pointers = [left._obj, right._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return left.baseclass(output_dtype, left.shape, mem_out,
+    return left.baseclass(out_type, left.shape, mem_out,
                           determine_sparsity(left, right),
                           left.perceived_ordering, intermediate_result=True)
 
 
-def _build_ewise_mult(op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
-    op_result_dtype = op.get_output_type(left.dtype, right.dtype)
+def _build_ewise_mult(out_type: DType, op: BinaryOp, left: SparseTensorBase, right: SparseTensorBase):
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             rank = left.ndims
             index = ir.IndexType.get()
-            dtype = left.dtype.build_mlir_type()
-            dtype_out = op_result_dtype.build_mlir_type()
+            dtype_left = left.dtype.build_mlir_type()
+            dtype_right = right.dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm_left = ir.AffineMap.get_permutation(left.permutation)
             perm_right = ir.AffineMap.get_permutation(right.permutation)
             perm_out = ir.AffineMap.get_permutation(range(rank))
             rtt_left = left.rtt.as_mlir_type()
             rtt_right = right.rtt.as_mlir_type()
-            rtt_out = RankedTensorType(dtype=op_result_dtype,
+            rtt_out = RankedTensorType(dtype=out_type,
                                        sparsity=determine_sparsity(left, right),
                                        ordering=left.perceived_ordering).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt_left, rtt_right)
             def main(x, y):
@@ -487,14 +509,14 @@ def main(x, y):
                     ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (perm_left, perm_right, perm_out)]),
                     ir.ArrayAttr.get([ir.Attribute.parse('#linalg.iterator_type<parallel>')]*rank)
                 )
-                block = generic_op.regions[0].blocks.append(dtype, dtype, dtype_out)
+                block = generic_op.regions[0].blocks.append(dtype_left, dtype_right, dtype_out)
                 with ir.InsertionPoint(block):
                     a, b, o = block.arguments
                     res = sparse_tensor.BinaryOp(dtype_out, a, b)
-                    overlap = res.regions[0].blocks.append(dtype, dtype)
+                    overlap = res.regions[0].blocks.append(dtype_left, dtype_right)
                     with ir.InsertionPoint(overlap):
                         arg0, arg1 = overlap.arguments
-                        overlap_res = op(arg0, arg1)
+                        overlap_res = op(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=overlap_res)
                     linalg.YieldOp([res])
                 return generic_op.result
             main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -504,41 +526,42 @@ def main(x, y):
     return compile(module)
 
 
 # TODO: pass the mask to mxm
-def mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix, TransposedMatrix]):
+def mxm(out_type: DType, op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix, TransposedMatrix]):
     assert left.ndims == right.ndims == 2
-    assert left.dtype == right.dtype
-    optype = op.binop.get_output_type(left.dtype, right.dtype)
 
     if left._obj is None or right._obj is None:
-        return Matrix.new(optype, left.shape[0], right.shape[1])
+        return Matrix.new(out_type, left.shape[0], right.shape[1])
 
     # Build and compile if needed
-    key = ('mxm', op.name, *left.get_loop_key(), *right.get_loop_key())
+    key = ('mxm', op.name, out_type, *left.get_loop_key(), *right.get_loop_key())
     if key not in engine_cache:
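+        # Kernels are cached per output dtype because the casts are compiled in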
-        engine_cache[key] = _build_mxm(op, left, right)
+        engine_cache[key] = _build_mxm(out_type, op, left, right)
 
     # Call the compiled function
     mem_out = get_sparse_output_pointer()
     arg_pointers = [left._obj, right._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return Matrix(optype, [left.shape[0], right.shape[1]], mem_out,
+    return Matrix(out_type, [left.shape[0], right.shape[1]], mem_out,
                   determine_sparsity(left, right),
                   left.perceived_ordering, intermediate_result=True)
 
 
-def _build_mxm(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Union[Matrix, TransposedMatrix]):
-    op_result_dtype = op.binop.get_output_type(left.dtype, right.dtype)
+def _build_mxm(out_type: DType,
+               op: Semiring,
+               left: Union[Matrix, TransposedMatrix],
+               right: Union[Matrix, TransposedMatrix]):
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             index = ir.IndexType.get()
-            dtype = left.dtype.build_mlir_type()
-            dtype_out = op_result_dtype.build_mlir_type()
+            dtype_left = left.dtype.build_mlir_type()
+            dtype_right = right.dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm_left = ir.AffineMap.get(3, 0, left._permute([ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(2)]))
             perm_right = ir.AffineMap.get(3, 0, right._permute([ir.AffineDimExpr.get(2), ir.AffineDimExpr.get(1)]))
             perm_out = ir.AffineMap.get(3, 0, [ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1)])
             rtt_left = left.rtt.as_mlir_type()
             rtt_right = right.rtt.as_mlir_type()
-            rtt_out = RankedTensorType(dtype=op_result_dtype,
+            rtt_out = RankedTensorType(dtype=out_type,
                                        sparsity=determine_sparsity(left, right),
                                        ordering=left.perceived_ordering).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt_left, rtt_right)
             def main(x, y):
@@ -559,21 +582,21 @@ def main(x, y):
                     ir.ArrayAttr.get([
                         ir.Attribute.parse('#linalg.iterator_type<parallel>'),
                         ir.Attribute.parse('#linalg.iterator_type<parallel>'),
                        ir.Attribute.parse('#linalg.iterator_type<reduction>'),
                     ])
                 )
-                block = generic_op.regions[0].blocks.append(dtype, dtype, dtype_out)
+                block = generic_op.regions[0].blocks.append(dtype_left, dtype_right, dtype_out)
                 with ir.InsertionPoint(block):
                     a, b, o = block.arguments
                     bin_result = sparse_tensor.BinaryOp(dtype_out, a, b)
-                    overlap = bin_result.regions[0].blocks.append(dtype, dtype)
+                    overlap = bin_result.regions[0].blocks.append(dtype_left, dtype_right)
                     with ir.InsertionPoint(overlap):
                         arg0, arg1 = overlap.arguments
-                        overlap_res = op.binop(arg0, arg1)
+                        overlap_res = op.binop(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=overlap_res)
-                    ident = op.monoid.identity(op_result_dtype)
+                    ident = op.monoid.identity(out_type)
                     red_result = sparse_tensor.ReduceOp(bin_result, o, ident)
                     reduce = red_result.regions[0].blocks.append(dtype_out, dtype_out)
                     with ir.InsertionPoint(reduce):
                         arg0, arg1 = reduce.arguments
-                        reduce_res = op.monoid.binop(arg0, arg1)
+                        reduce_res = op.monoid.binop(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=reduce_res)
                     linalg.YieldOp([red_result])
                 return generic_op.result
             main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -583,41 +606,40 @@ def main(x, y):
     return compile(module)
 
 
 # TODO: pass the mask to mxv
-def mxv(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
+def mxv(out_type: DType, op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
     assert left.ndims == 2
     assert right.ndims == 1
-    optype = op.binop.get_output_type(left.dtype, right.dtype)
 
     if left._obj is None or right._obj is None:
-        return Vector.new(optype, left.shape[0])
+        return Vector.new(out_type, left.shape[0])
 
     # Build and compile if needed
-    key = ('mxv', op.name, *left.get_loop_key(), *right.get_loop_key())
+    key = ('mxv', op.name, out_type, *left.get_loop_key(), *right.get_loop_key())
     if key not in engine_cache:
-        engine_cache[key] = _build_mxv(op, left, right)
+        engine_cache[key] = _build_mxv(out_type, op, left, right)
 
     # Call the compiled function
     mem_out = get_sparse_output_pointer()
     arg_pointers = [left._obj, right._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return Vector(optype, [left.shape[0]], mem_out,
+    return Vector(out_type, [left.shape[0]], mem_out,
                   right._sparsity, right.perceived_ordering, intermediate_result=True)
 
 
-def _build_mxv(op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
-    op_result_dtype = op.binop.get_output_type(left.dtype, right.dtype)
+def _build_mxv(out_type: DType, op: Semiring, left: Union[Matrix, TransposedMatrix], right: Vector):
    with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             index = ir.IndexType.get()
-            dtype = left.dtype.build_mlir_type()
-            dtype_out = op_result_dtype.build_mlir_type()
+            dtype_left = left.dtype.build_mlir_type()
+            dtype_right = right.dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm_left = ir.AffineMap.get(2, 0, left._permute([ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1)]))
             perm_right = ir.AffineMap.get(2, 0, right._permute([ir.AffineDimExpr.get(1)]))
             perm_out = ir.AffineMap.get(2, 0, [ir.AffineDimExpr.get(0)])
             rtt_left = left.rtt.as_mlir_type()
             rtt_right = right.rtt.as_mlir_type()
-            rtt_out = right.rtt.copy(dtype=op_result_dtype).as_mlir_type()
+            rtt_out = right.rtt.copy(dtype=out_type).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt_left, rtt_right)
             def main(x, y):
@@ -634,21 +656,21 @@ def main(x, y):
                     ir.ArrayAttr.get([
                         ir.Attribute.parse('#linalg.iterator_type<parallel>'),
                         ir.Attribute.parse('#linalg.iterator_type<reduction>'),
                     ])
                 )
-                block = generic_op.regions[0].blocks.append(dtype, dtype, dtype_out)
+                block = generic_op.regions[0].blocks.append(dtype_left, dtype_right, dtype_out)
                 with ir.InsertionPoint(block):
                     a, b, o = block.arguments
                     bin_result = sparse_tensor.BinaryOp(dtype_out, a, b)
-                    overlap = bin_result.regions[0].blocks.append(dtype, dtype)
+                    overlap = bin_result.regions[0].blocks.append(dtype_left, dtype_right)
                     with ir.InsertionPoint(overlap):
                         arg0, arg1 = overlap.arguments
-                        overlap_res = op.binop(arg0, arg1)
+                        overlap_res = op.binop(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=overlap_res)
-                    ident = op.monoid.identity(op_result_dtype)
+                    ident = op.monoid.identity(out_type)
                     red_result = sparse_tensor.ReduceOp(bin_result, o, ident)
                     reduce = red_result.regions[0].blocks.append(dtype_out, dtype_out)
                     with ir.InsertionPoint(reduce):
                         arg0, arg1 = reduce.arguments
-                        reduce_res = op.monoid.binop(arg0, arg1)
+                        reduce_res = op.monoid.binop(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=reduce_res)
                     linalg.YieldOp([red_result])
                 return generic_op.result
@@ -658,41 +680,40 @@ def main(x, y):
     return compile(module)
 
 
 # TODO: pass the mask to vxm
-def vxm(op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
+def vxm(out_type: DType, op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
     assert left.ndims == 1
     assert right.ndims == 2
-    optype = op.binop.get_output_type(left.dtype, right.dtype)
 
     if left._obj is None or right._obj is None:
-        return Vector.new(optype, right.shape[1])
+        return Vector.new(out_type, right.shape[1])
 
     # Build and compile if needed
-    key = ('vxm', op.name, *left.get_loop_key(), *right.get_loop_key())
+    key = ('vxm', op.name, out_type, *left.get_loop_key(), *right.get_loop_key())
     if key not in engine_cache:
-        engine_cache[key] = _build_vxm(op, left, right)
+        engine_cache[key] = _build_vxm(out_type, op, left, right)
 
     # Call the compiled function
     mem_out = get_sparse_output_pointer()
     arg_pointers = [left._obj, right._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return Vector(optype, [right.shape[1]], mem_out,
+    return Vector(out_type, [right.shape[1]], mem_out,
                   left._sparsity, left.perceived_ordering, intermediate_result=True)
 
 
-def _build_vxm(op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
-    op_result_dtype = op.binop.get_output_type(left.dtype, right.dtype)
+def _build_vxm(out_type: DType, op: Semiring, left: Vector, right: Union[Matrix, TransposedMatrix]):
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             index = ir.IndexType.get()
-            dtype = left.dtype.build_mlir_type()
-            dtype_out = op_result_dtype.build_mlir_type()
+            dtype_left = left.dtype.build_mlir_type()
+            dtype_right = right.dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm_left = ir.AffineMap.get(2, 0, left._permute([ir.AffineDimExpr.get(0)]))
             perm_right = ir.AffineMap.get(2, 0, right._permute([ir.AffineDimExpr.get(0), ir.AffineDimExpr.get(1)]))
             perm_out = ir.AffineMap.get(2, 0, [ir.AffineDimExpr.get(1)])
             rtt_left = left.rtt.as_mlir_type()
             rtt_right = right.rtt.as_mlir_type()
-            rtt_out = left.rtt.copy(dtype=op_result_dtype).as_mlir_type()
+            rtt_out = left.rtt.copy(dtype=out_type).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt_left, rtt_right)
             def main(x, y):
@@ -709,21 +730,21 @@ def main(x, y):
                     ir.ArrayAttr.get([
                         ir.Attribute.parse('#linalg.iterator_type<reduction>'),
                         ir.Attribute.parse('#linalg.iterator_type<parallel>'),
                     ])
                 )
-                block = generic_op.regions[0].blocks.append(dtype, dtype, dtype_out)
+                block = generic_op.regions[0].blocks.append(dtype_left, dtype_right, dtype_out)
                 with ir.InsertionPoint(block):
                     a, b, o = block.arguments
                     bin_result = sparse_tensor.BinaryOp(dtype_out, a, b)
-                    overlap = bin_result.regions[0].blocks.append(dtype, dtype)
+                    overlap = bin_result.regions[0].blocks.append(dtype_left, dtype_right)
                     with ir.InsertionPoint(overlap):
                         arg0, arg1 = overlap.arguments
-                        overlap_res = op.binop(arg0, arg1)
+                        overlap_res = op.binop(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=overlap_res)
-                    ident = op.monoid.identity(op_result_dtype)
+                    ident = op.monoid.identity(out_type)
                     red_result = sparse_tensor.ReduceOp(bin_result, o, ident)
                     reduce = red_result.regions[0].blocks.append(dtype_out, dtype_out)
                     with ir.InsertionPoint(reduce):
                         arg0, arg1 = reduce.arguments
-                        reduce_res = op.monoid.binop(arg0, arg1)
+                        reduce_res = op.monoid.binop(out_type, arg0, arg1)
                         sparse_tensor.YieldOp(result=reduce_res)
                     linalg.YieldOp([red_result])
                 return generic_op.result
@@ -732,64 +753,53 @@ def main(x, y):
     return compile(module)
 
 
-def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
+def apply(out_type: DType, op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
           sp: SparseTensorBase,
           left: Optional[Scalar] = None,
          right: Optional[Scalar] = None,
          thunk: Optional[Scalar] = None,
          inplace: bool = False):
-    # Find output dtype
-    optype = type(op)
-    if optype is UnaryOp:
-        output_dtype = op.get_output_type(sp.dtype)
-    elif optype is BinaryOp:
-        if left is not None:
-            output_dtype = op.get_output_type(left.dtype, sp.dtype)
-        else:
-            output_dtype = op.get_output_type(sp.dtype, right.dtype)
-    else:
-        if inplace:
-            raise TypeError("apply inplace not supported for IndexUnaryOp")
-        output_dtype = op.get_output_type(sp.dtype, thunk.dtype)
-
     if sp._obj is None:
-        return sp.baseclass(output_dtype, sp.shape)
+        return sp.baseclass(out_type, sp.shape)
 
+    optype = type(op)
     rank = sp.ndims
     if rank == 0:  # Scalar
         if optype is UnaryOp:
-            key = ('scalar_apply_unary', op.name, sp.dtype)
+            key = ('scalar_apply_unary', op.name, out_type, sp.dtype)
         elif optype is BinaryOp:
             if left is not None:
-                key = ('scalar_apply_bind_first', op.name, sp.dtype, left._obj)
+                key = ('scalar_apply_bind_first', op.name, out_type, sp.dtype, left._obj)
             else:
-                key = ('scalar_apply_bind_second', op.name, sp.dtype, right._obj)
+                key = ('scalar_apply_bind_second', op.name, out_type, sp.dtype, right._obj)
         else:
             raise GrbError("apply scalar not supported for IndexUnaryOp")
 
         if key not in engine_cache:
-            engine_cache[key] = _build_scalar_apply(op, sp, left, right)
-        mem_out = get_scalar_output_pointer(output_dtype)
+            engine_cache[key] = _build_scalar_apply(out_type, op, sp, left, right)
+        mem_out = get_scalar_output_pointer(out_type)
         arg_pointers = [get_scalar_input_arg(sp), mem_out]
         engine_cache[key].invoke('main', *arg_pointers)
-        return Scalar.new(output_dtype, mem_out.contents.value)
+        return Scalar.new(out_type, mem_out.contents.value)
 
     # Build and compile if needed
     # Note that Scalars are included in the key because they are inlined in the compiled code
     if optype is UnaryOp:
-        key = ('apply_unary', op.name, *sp.get_loop_key(), inplace)
+        key = ('apply_unary', op.name, out_type, *sp.get_loop_key(), inplace)
     elif optype is BinaryOp:
         if left is not None:
-            key = ('apply_bind_first', op.name, *sp.get_loop_key(), left._obj, inplace)
+            key = ('apply_bind_first', op.name, out_type, *sp.get_loop_key(), left._obj, inplace)
         else:
-            key = ('apply_bind_second', op.name, *sp.get_loop_key(), right._obj, inplace)
+            key = ('apply_bind_second', op.name, out_type, *sp.get_loop_key(), right._obj, inplace)
     else:
-        key = ('apply_indexunary', op.name, *sp.get_loop_key(), thunk._obj)
+        if inplace:
+            raise TypeError("apply inplace not supported for IndexUnaryOp")
+        key = ('apply_indexunary', op.name, out_type, *sp.get_loop_key(), thunk._obj)
     if key not in engine_cache:
         if inplace:
             engine_cache[key] = _build_apply_inplace(op, sp, left, right)
         else:
-            engine_cache[key] = _build_apply(op, sp, left, right, thunk, output_dtype)
+            engine_cache[key] = _build_apply(out_type, op, sp, left, right, thunk)
 
     # Call the compiled function
     if inplace:
@@ -798,11 +808,12 @@ def apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
     mem_out = get_sparse_output_pointer()
     arg_pointers = [sp._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return sp.baseclass(output_dtype, sp.shape, mem_out,
+    return sp.baseclass(out_type, sp.shape, mem_out,
                         sp._sparsity, sp.perceived_ordering, intermediate_result=True)
 
 
-def _build_scalar_apply(op: Union[UnaryOp, BinaryOp],
+def _build_scalar_apply(out_type: DType,
+                        op: Union[UnaryOp, BinaryOp],
                         sp: SparseTensorBase,
                         left: Optional[Scalar],
                         right: Optional[Scalar]):
@@ -817,24 +828,24 @@ def main(x):
             if optype is BinaryOp:
                 if left is not None:
                     left_val = arith.ConstantOp(left.dtype.build_mlir_type(), left.extract_element())
-                    result = op(left_val, x)
+                    result = op(out_type, left_val, x)
                 else:
                     right_val = arith.ConstantOp(right.dtype.build_mlir_type(), right.extract_element())
-                    result = op(x, right_val)
+                    result = op(out_type, x, right_val)
             else:
-                result = op(x)
+                result = op(out_type, x)
             return result
         main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
 
     return compile(module)
 
 
-def _build_apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
+def _build_apply(out_type: DType,
+                 op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
                  sp: SparseTensorBase,
                  left: Optional[Scalar],
                  right: Optional[Scalar],
-                 thunk: Optional[Scalar],
-                 output_dtype):
+                 thunk: Optional[Scalar]):
     optype = type(op)
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
@@ -843,11 +854,11 @@ def _build_apply(op: Union[UnaryOp, BinaryOp, IndexUnaryOp],
             index = ir.IndexType.get()
             i64 = ir.IntegerType.get_signless(64)
             dtype = sp.dtype.build_mlir_type()
-            dtype_out = output_dtype.build_mlir_type()
+            dtype_out = out_type.build_mlir_type()
             perm = ir.AffineMap.get_permutation(sp.permutation)
             perm_out = ir.AffineMap.get_permutation(range(rank))
             rtt = sp.rtt.as_mlir_type()
-            rtt_out = sp.rtt.copy(dtype=output_dtype, ordering=sp.perceived_ordering).as_mlir_type()
+            rtt_out = sp.rtt.copy(dtype=out_type, ordering=sp.perceived_ordering).as_mlir_type()
 
             @func.FuncOp.from_py_func(rtt)
             def main(x):
@@ -879,16 +890,16 @@ def main(x):
                             thunk_val = arith.ConstantOp(index, thunk.extract_element())
                         else:
                             thunk_val = arith.ConstantOp(thunk.dtype.build_mlir_type(), thunk.extract_element())
-                        val = op(arg0, rowidx, colidx, thunk_val)
+                        val = op(out_type, arg0, rowidx, colidx, thunk_val)
                     elif optype is BinaryOp:
                         if left is not None:
                             left_val = arith.ConstantOp(left.dtype.build_mlir_type(), left.extract_element())
-                            val = op(left_val, arg0)
+                            val = op(out_type, left_val, arg0)
                         else:
                             right_val = arith.ConstantOp(right.dtype.build_mlir_type(), right.extract_element())
-                            val = op(arg0, right_val)
+                            val = op(out_type, arg0, right_val)
                     else:
-                        val = op(arg0)
+                        val = op(out_type, arg0)
                     sparse_tensor.YieldOp(result=val)
                     linalg.YieldOp([res])
                 return generic_op.result
@@ -902,10 +913,6 @@ def _build_apply_inplace(op: Union[UnaryOp, BinaryOp],
                          left: Optional[Scalar],
                          right: Optional[Scalar]):
     optype = type(op)
-    op_result_dtype = op.get_output_type(sp.dtype)
-    if op_result_dtype != sp.dtype:
-        raise TypeError("apply inplace is restricted from changing dtype")
-
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
@@ -927,12 +934,12 @@ def main(x):
                 if optype is BinaryOp:
                     if left is not None:
                         left_val = arith.ConstantOp(left.dtype.build_mlir_type(), left.extract_element())
-                        result = op(left_val, val)
+                        result = op(sp.dtype, left_val, val)
                     else:
                         right_val = arith.ConstantOp(right.dtype.build_mlir_type(), right.extract_element())
-                        result = op(val, right_val)
+                        result = op(sp.dtype, val, right_val)
                 else:
-                    result = op(val)
+                    result = op(sp.dtype, val)
                 memref.StoreOp(result, vals, [x])
                 scf.YieldOp([])
         main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -940,10 +947,10 @@ def main(x):
     return compile(module)
 
 
-def select(op: SelectOp, sp: SparseTensor, thunk: Scalar):
+def select(out_type: DType, op: SelectOp, sp: SparseTensor, thunk: Scalar):
     # Handle case of empty tensor
     if sp._obj is None:
-        return sp.__class__(sp.dtype, sp.shape)
+        return sp.__class__(out_type, sp.shape)
 
     rank = sp.ndims
     if rank == 0:  # Scalar
@@ -955,9 +962,9 @@ def select(op: SelectOp, sp: SparseTensor, thunk: Scalar):
         engine_cache[key].invoke('main', *arg_pointers)
         # Invocation returns True/False for whether to keep value
         if mem_out.contents.value:
-            return sp.dup()
+            return Scalar.new(out_type, sp._obj)
         else:
-            return Scalar.new(sp.dtype)
+            return Scalar.new(out_type)
 
     # Build and compile if needed
     # Note that thunk is included in the key because it is inlined in the compiled code
@@ -969,8 +976,14 @@ def select(op: SelectOp, sp: SparseTensor, thunk: Scalar):
     mem_out = get_sparse_output_pointer()
     arg_pointers = [sp._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return sp.baseclass(sp.dtype, sp.shape, mem_out,
-                        sp._sparsity, sp.perceived_ordering, intermediate_result=True)
+    res = sp.baseclass(sp.dtype, sp.shape, mem_out,
+                       sp._sparsity, sp.perceived_ordering, intermediate_result=True)
+
+    # _build_select cannot change output dtype; handle that now
+    if out_type != sp.dtype:
+        res = dup(out_type, res, intermediate=True)
+
+    return res
 
 
 def _build_scalar_select(op: SelectOp, sp: SparseTensorBase, thunk: Scalar):
@@ -987,7 +1000,7 @@ def main(x):
                 thunk_val = arith.ConstantOp(index, thunk.extract_element())
             else:
                 thunk_val = arith.ConstantOp(thunk.dtype.build_mlir_type(), thunk.extract_element())
-            cmp = op(x, c0, c0, thunk_val)
+            cmp = op(BOOL, x, c0, c0, thunk_val)
             return cmp
 
         main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
@@ -1035,7 +1048,7 @@ def main(x):
                         thunk_val = arith.ConstantOp(index, thunk.extract_element())
                     else:
                         thunk_val = arith.ConstantOp(thunk.dtype.build_mlir_type(), thunk.extract_element())
-                    cmp = op(arg0, rowidx, colidx, thunk_val)
+                    cmp = op(BOOL, arg0, rowidx, colidx, thunk_val)
                     sparse_tensor.YieldOp(result=cmp)
                     linalg.YieldOp([res])
                 return generic_op.result
@@ -1044,9 +1057,9 @@ def main(x):
     return compile(module)
 
 
-def reduce_to_vector(op: Monoid, mat: Union[Matrix, TransposedMatrix]):
+def reduce_to_vector(out_type: DType, op: Monoid, mat: Union[Matrix, TransposedMatrix]):
     if mat._obj is None:
-        return Vector.new(mat.dtype, mat.shape[0])
+        return Vector.new(out_type, mat.shape[0])
 
     # Build and compile if needed
     key = ('reduce_to_vector', op.name, *mat.get_loop_key())
@@ -1057,8 +1070,14 @@ def reduce_to_vector(op: Monoid, mat: Union[Matrix, TransposedMatrix]):
     mem_out = get_sparse_output_pointer()
     arg_pointers = [mat._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return Vector(mat.dtype, [mat.shape[0]], mem_out,
-                  [DimLevelType.compressed], [0], intermediate_result=True)
+    res = Vector(mat.dtype, [mat.shape[0]], mem_out,
+                 [DimLevelType.compressed], [0], intermediate_result=True)
+
+    # _build_reduce_to_vector cannot change output dtype; handle that now
+    if out_type != mat.dtype:
+        res = dup(out_type, res, intermediate=True)
+
+    return res
 
 
 def _build_reduce_to_vector(op: Monoid, mat: Union[Matrix, TransposedMatrix]):
@@ -1093,7 +1112,7 @@ def main(x):
                     region = res.regions[0].blocks.append(dtype, dtype)
                     with ir.InsertionPoint(region):
                         arg0, arg1 = region.arguments
-                        reduce_res = op.binop(arg0, arg1)
+                        reduce_res = op.binop(mat.dtype, arg0, arg1)
                         sparse_tensor.YieldOp(result=reduce_res)
                     linalg.YieldOp([res])
                 return generic_op.result
@@ -1102,28 +1121,27 @@ def main(x):
     return compile(module)
 
 
-def reduce_to_scalar(op: Monoid, sp: SparseTensorBase):
+def reduce_to_scalar(out_type: DType, op: Monoid, sp: SparseTensorBase):
     if sp._obj is None:
-        return Scalar.new(sp.dtype)
+        return Scalar.new(out_type)
 
     # Build and compile if needed
-    key = ('reduce_to_scalar', op.name, *sp.get_loop_key())
+    key = ('reduce_to_scalar', op.name, out_type, *sp.get_loop_key())
     if key not in engine_cache:
-        engine_cache[key] = _build_reduce_to_scalar(op, sp)
+        engine_cache[key] = _build_reduce_to_scalar(out_type, op, sp)
 
     # Call the compiled function
-    mem_out = get_scalar_output_pointer(sp.dtype)
+    mem_out = get_scalar_output_pointer(out_type)
     arg_pointers = [sp._obj, mem_out]
     engine_cache[key].invoke('main', *arg_pointers)
-    return Scalar.new(sp.dtype, mem_out.contents.value)
+    return Scalar.new(out_type, mem_out.contents.value)
 
 
-def _build_reduce_to_scalar(op: Monoid, sp: SparseTensorBase):
+def _build_reduce_to_scalar(out_type: DType, op: Monoid, sp: SparseTensorBase):
     with ir.Context(), ir.Location.unknown():
         module = ir.Module.create()
         with ir.InsertionPoint(module.body):
             rank = sp.ndims
-            index = ir.IndexType.get()
             dtype = sp.dtype.build_mlir_type()
             perm = ir.AffineMap.get_permutation(sp.permutation)
             perm_out = ir.AffineMap.get(rank, 0, [])
@@ -1148,17 +1166,18 @@ def main(x):
                     region = res.regions[0].blocks.append(dtype, dtype)
                     with ir.InsertionPoint(region):
                         arg0, arg1 = region.arguments
-                        reduce_res = op.binop(arg0, arg1)
+                        reduce_res = op.binop(sp.dtype, arg0, arg1)
                         sparse_tensor.YieldOp(result=reduce_res)
                     linalg.YieldOp([res])
                 s = tensor.ExtractOp(generic_op, [])
+                s = cast(s, sp.dtype, out_type)
                 return s.result
             main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
 
     return compile(module)
 
 
-def extract(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_size):
+def extract(out_type: DType, tensor: SparseTensorBase, row_indices, col_indices, row_size, col_size):
     # There may be a way to do this in MLIR, but for now we use numpy
     if tensor.ndims == 1:
         # Vector
@@ -1166,20 +1185,24 @@ def extract(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_si
         assert col_size is None
         if row_indices is None:
             # None indicates GrB_ALL
-            return dup(tensor)
+            return dup(out_type, tensor)
 
         idx, vals = tensor.extract_tuples()
+        if out_type != tensor.dtype:
+            vals = vals.astype(out_type.np_type)
         pick_list = np.array(row_indices, dtype=np.uint64)
         idx, vals = pick_and_renumber_indices(pick_list, idx, vals)
-        v = Vector.new(tensor.dtype, row_size)
+        v = Vector.new(out_type, row_size)
         v.build(idx, vals)
         return v
 
     # Matrix
     if row_indices is None and col_indices is None:
-        return dup(tensor)
+        return dup(out_type, tensor)
 
     rowidx, colidx, vals = tensor.extract_tuples()
+    if out_type != tensor.dtype:
+        vals = vals.astype(out_type.np_type)
     if row_indices is not None:
         if type(row_indices) is int:
             pick_list = np.array([row_indices], dtype=np.uint64)
@@ -1195,39 +1218,41 @@ def extract(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_si
     if type(row_indices) is int:
         # Extract row as Vector
         assert np.all(rowidx == 0)
-        v = Vector.new(tensor.dtype, col_size)
+        v = Vector.new(out_type, col_size)
         v.build(colidx, vals)
         return v
     if type(col_indices) is int:
         # Extract col as Vector
         assert np.all(colidx == 0)
-        v = Vector.new(tensor.dtype, row_size)
+        v = Vector.new(out_type, row_size)
         v.build(rowidx, vals)
         return v
 
-    m = Matrix.new(tensor.dtype, row_size, col_size)
+    m = Matrix.new(out_type, row_size, col_size)
     m.build(rowidx, colidx, vals)
     return m
 
 
-def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_size=None):
+def assign(out_type: DType, tensor: SparseTensorBase, row_indices, col_indices, row_size, col_size=None):
     # There may be a way to do this in MLIR, but for now we use numpy
     if tensor.ndims == 1:
         # Vector input
         if row_indices is None and col_size is None:
             # Vector output with GrB_ALL
-            return dup(tensor)
+            return dup(out_type, tensor)
 
         idx, vals = tensor.extract_tuples()
+        if out_type != tensor.dtype:
+            vals = vals.astype(out_type.np_type)
         if col_size is None:
             # Vector output
-            v = Vector.new(tensor.dtype, row_size)
+            v = Vector.new(out_type, row_size)
             # Map idx to output indices
             idx = np.array(row_indices, dtype=np.uint64)[idx]
             v.build(idx, vals, sparsity=tensor._sparsity)
             return v
 
         # Assign Vector as row or column of Matrix
-        m = Matrix.new(tensor.dtype, row_size, col_size)
+        m = Matrix.new(out_type, row_size, col_size)
         if type(row_indices) is int:
             # Map idx to output cols
             colidx = idx if col_indices is None else np.array(col_indices, dtype=np.uint64)[idx]
@@ -1240,15 +1265,17 @@ def assign(tensor: SparseTensorBase, row_indices, col_indices, row_size, col_siz
 
     # Matrix input
     if row_indices is None and col_indices is None:
-        return dup(tensor)
+        return dup(out_type, tensor)
 
     rowidx, colidx, vals = tensor.extract_tuples()
+    if out_type != tensor.dtype:
+        vals = vals.astype(out_type.np_type)
     # Map indices to output
     if row_indices is not None:
         rowidx = np.array(row_indices, dtype=np.uint64)[rowidx]
     if col_indices is not None:
         colidx = np.array(col_indices, dtype=np.uint64)[colidx]
-    m = Matrix.new(tensor.dtype, row_size, col_size)
+    m = Matrix.new(out_type, row_size, col_size)
     m.build(rowidx, colidx, vals, sparsity=["compressed", "compressed"])
     return m
diff --git a/mlir_graphblas/operations.py b/mlir_graphblas/operations.py
index 3a8b266..5ce3853 100644
--- a/mlir_graphblas/operations.py
+++ b/mlir_graphblas/operations.py
@@ -61,7 +61,7 @@ def update(output: SparseObject,
         if accum is None or output._obj is None:
             output.set_element(tensor.extract_element())
         else:
-            output._obj = impl.ewise_add(accum, output, tensor)._obj
+            output._obj = impl.ewise_add(output.dtype, accum, output, tensor)._obj
         return
 
     if not isinstance(output, SparseTensor):
@@ -80,7 +80,7 @@ def update(output: SparseObject,
         else:  # mask=Y, accum=N, replace=N
             # Apply inverted mask, then eWiseAdd
             output._replace(impl.select_by_mask(output, mask, desc_inverted))
-            result = impl.ewise_add(BinaryOp.oneb, output, tensor)
+            result = impl.ewise_add(output.dtype, BinaryOp.oneb, output, tensor)
     else:
         if mask is None:  # mask=N, accum=N, replace=?, w/ indices
             # Drop indices in output, then eWiseAdd
@@ -106,22 +106,22 @@ def update(output: SparseObject,
             # Select the row/col indices in the mask, apply it inverted to the output, then eWiseAdd
             new_mask = impl.select_by_indices(mask, row_indices, col_indices)
             output._replace(impl.select_by_mask(output, new_mask, desc_inverted))
-            result = impl.ewise_add(BinaryOp.oneb, output, tensor)
+            result = impl.ewise_add(output.dtype, BinaryOp.oneb, output, tensor)
     elif mask is None or not desc.replace:
         # eWiseAdd using accum
-        result = impl.ewise_add(accum, output, tensor)
+        result = impl.ewise_add(output.dtype, accum, output, tensor)
     else:
         # Mask the output, then perform eWiseAdd using accum
         output._replace(impl.select_by_mask(output, mask, desc))
-        result = impl.ewise_add(accum, output, tensor)
+        result = impl.ewise_add(output.dtype, accum, output, tensor)
 
     if result is output:
         # This can happen if empty tensors are used as input
         return output
 
-    # If not an intermediate result, make a copy
-    if not result._intermediate_result:
-        result = impl.dup(result)
+    # If not an intermediate result or wrong dtype, make a copy
+    if not result._intermediate_result or result.dtype != output.dtype:
+        result = impl.dup(output.dtype, result)
 
     output._replace(result)
 
@@ -132,10 +132,6 @@ def transpose(out: Matrix,
               mask: Optional[SparseTensor] = None,
               accum: Optional[BinaryOp] = None,
               desc: Descriptor = NULL_DESC):
-    # Verify dtypes
-    if out.dtype != tensor.dtype:
-        raise GrbDomainMismatch(f"output type must be {tensor.dtype}, not {out.dtype}")
-
     # Apply descriptor transpose
     if tensor.ndims != 2:
         raise TypeError(f"transpose requires Matrix, not {type(tensor)}")
@@ -171,14 +167,6 @@ def ewise_add(out: SparseTensor,
     if type(op) is not BinaryOp:
         raise TypeError(f"op must be BinaryOp, Monoid, or Semiring")
 
-    # Verify dtypes
-    if op.output is not None and type(op.output) is not int:
-        raise GrbDomainMismatch("op must return same type as inputs with ewise_add")
-    if left.dtype != right.dtype:
-        raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
-    if out.dtype != left.dtype:
-        raise GrbDomainMismatch(f"output type must be {left.dtype}, not {out.dtype}")
-
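+    # Inputs and output may now have different dtypes; values are cast to
+    # out.dtype inside the generated kernel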
     # Apply transposes
     if desc.transpose0 and left.ndims == 2:
         left = TransposedMatrix.wrap(left)
     if desc.transpose1 and right.ndims == 2:
         right = TransposedMatrix.wrap(right)
@@ -195,7 +183,7 @@ def ewise_add(out: SparseTensor,
         left = impl.select_by_mask(left, mask, desc)
         right = impl.select_by_mask(right, mask, desc)
 
-    result = impl.ewise_add(op, left, right)
+    result = impl.ewise_add(out.dtype, op, left, right)
     update(out, result, mask, accum, desc)
 
 
@@ -214,13 +202,6 @@ def ewise_mult(out: SparseTensor,
     else:
         raise TypeError(f"op must be BinaryOp, Monoid, or Semiring")
 
-    # Verify dtypes
-    if left.dtype != right.dtype:
-        raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
-    required_out_dtype = op.get_output_type(left.dtype, right.dtype)
-    if out.dtype != required_out_dtype:
-        raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}")
-
     # Apply transposes
@@ -237,7 +218,7 @@ def ewise_mult(out: SparseTensor,
         # Only need to apply mask to one of the inputs
         left = impl.select_by_mask(left, mask, desc)
 
-    result = impl.ewise_mult(op, left, right)
+    result = impl.ewise_mult(out.dtype, op, left, right)
     update(out, result, mask, accum, desc)
 
 
@@ -253,13 +234,6 @@ def mxm(out: Matrix,
     if type(op) is not Semiring:
         raise TypeError(f"op must be Semiring, not {type(op)}")
 
-    # Verify dtypes
-    if left.dtype != right.dtype:
-        raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
{right.dtype}") - required_out_dtype = op.binop.get_output_type(left.dtype, right.dtype) - if out.dtype != required_out_dtype: - raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}") - # Apply transpose if right.ndims != 2: raise GrbDimensionMismatch("vxm requires matrix as second input") @@ -365,7 +325,7 @@ def vxm(out: Vector, raise GrbDimensionMismatch(f"output size should be {right.shape[1]} not {out.shape[0]}") # TODO: apply the mask during the computation, not at the end - result = impl.vxm(op, left, right) + result = impl.vxm(out.dtype, op, left, right) if mask is not None: result = impl.select_by_mask(result, mask, desc) update(out, result, mask, accum, desc) @@ -386,7 +346,6 @@ def apply(out: SparseTensor, if optype is UnaryOp: if thunk is not None or left is not None or right is not None: raise TypeError("UnaryOp does not accept thunk, left, or right") - required_out_dtype = op.get_output_type(tensor.dtype) elif optype is BinaryOp: if thunk is not None: raise TypeError("BinaryOp accepts left or thing, not thunk") @@ -396,23 +355,16 @@ def apply(out: SparseTensor, raise TypeError("Cannot provide both left and right") if left is not None: left = ensure_scalar_of_type(left, tensor.dtype) - required_out_dtype = op.get_output_type(left.dtype, tensor.dtype) else: right = ensure_scalar_of_type(right, tensor.dtype) - required_out_dtype = op.get_output_type(tensor.dtype, right.dtype) elif optype is IndexUnaryOp: if left is not None or right is not None: raise TypeError("IndexUnaryOp accepts thunk, not left or right") thunk_dtype = INT64 if op.thunk_as_index else tensor.dtype thunk = ensure_scalar_of_type(thunk, thunk_dtype) - required_out_dtype = op.get_output_type(tensor.dtype, thunk.dtype) else: raise TypeError(f"op must be UnaryOp, BinaryOp, or IndexUnaryOp, not {type(op)}") - # Verify dtype - if out.dtype != required_out_dtype: - raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}") - # Apply transpose if desc.transpose0 and tensor.ndims == 2: tensor = TransposedMatrix.wrap(tensor) @@ -433,9 +385,9 @@ def apply(out: SparseTensor, and desc is NULL_DESC and not tensor._intermediate_result ): - impl.apply(op, tensor, left, right, None, inplace=True) + impl.apply(out.dtype, op, tensor, left, right, None, inplace=True) else: - result = impl.apply(op, tensor, left, right, thunk) + result = impl.apply(out.dtype, op, tensor, left, right, thunk) update(out, result, mask, accum, desc) @@ -452,8 +404,6 @@ def select(out: SparseTensor, raise TypeError(f"op must be SelectOp, not {type(op)}") # Verify dtypes - if out.dtype != tensor.dtype: - raise GrbDomainMismatch(f"output dtype must match input dtype: {out.dtype} != {tensor.dtype}") thunk_dtype = INT64 if op.thunk_as_index else tensor.dtype thunk = ensure_scalar_of_type(thunk, thunk_dtype) @@ -468,7 +418,7 @@ def select(out: SparseTensor, if mask is not None: tensor = impl.select_by_mask(tensor, mask, desc) - result = impl.select(op, tensor, thunk) + result = impl.select(out.dtype, op, tensor, thunk) update(out, result, mask, accum, desc) @@ -483,10 +433,6 @@ def reduce_to_vector(out: Vector, if type(op) is not Monoid: raise TypeError(f"op must be Monoid, not {type(op)}") - # Verify dtypes - if out.dtype != tensor.dtype: - raise GrbDomainMismatch(f"output dtype must match input dtype: {out.dtype} != {tensor.dtype}") - # Apply transpose if tensor.ndims != 2: raise GrbDimensionMismatch("reduce_to_vector requires matrix input") @@ -500,7 +446,7 @@ def reduce_to_vector(out: Vector, raise 
GrbDimensionMismatch(f"output size should be {tensor.shape[0]} not {out.shape[0]}") # TODO: apply the mask during the computation, not at the end - result = impl.reduce_to_vector(op, tensor) + result = impl.reduce_to_vector(out.dtype, op, tensor) if mask is not None: result = impl.select_by_mask(result, mask, desc) update(out, result, mask, accum, desc) @@ -516,15 +462,11 @@ def reduce_to_scalar(out: Scalar, if type(op) is not Monoid: raise TypeError(f"op must be Monoid, not {type(op)}") - # Verify dtypes - if out.dtype != tensor.dtype: - raise GrbDomainMismatch(f"output dtype must match input dtype: {out.dtype} != {tensor.dtype}") - # Compare shapes if out.ndims != 0: raise GrbDimensionMismatch("reduce_to_scalar requires scalar output") - result = impl.reduce_to_scalar(op, tensor) + result = impl.reduce_to_scalar(out.dtype, op, tensor) update(out, result, accum=accum, desc=desc) @@ -539,10 +481,6 @@ def extract(out: SparseTensor, """ Setting row_indices or col_indices to `None` is the equivalent of GrB_ALL """ - # Verify dtypes - if out.dtype != tensor.dtype: - raise GrbDomainMismatch(f"output must have same dtype as input: {out.dtype} != {tensor.dtype}") - # Apply transpose if desc.transpose0 and tensor.ndims == 2: tensor = TransposedMatrix.wrap(tensor) @@ -587,7 +525,7 @@ def extract(out: SparseTensor, if out.shape != expected_out_shape: raise GrbDimensionMismatch(f"output shape mismatch: {out.shape} != {expected_out_shape}") - result = impl.extract(tensor, row_indices, col_indices, row_size, col_size) + result = impl.extract(out.dtype, tensor, row_indices, col_indices, row_size, col_size) if mask is not None: result = impl.select_by_mask(result, mask, desc) update(out, result, mask, accum, desc) @@ -610,10 +548,6 @@ def assign(out: SparseTensor, raise TypeError(f"tensor must be a SparseObject or Python scalar, not {type(tensor)}") tensor = ensure_scalar_of_type(tensor, out.dtype) - # Verify dtypes - if out.dtype != tensor.dtype: - raise GrbDomainMismatch(f"output must have same dtype as input: {out.dtype} != {tensor.dtype}") - # Apply transpose if desc.transpose0 and tensor.ndims == 2: tensor = TransposedMatrix.wrap(tensor) @@ -653,7 +587,7 @@ def assign(out: SparseTensor, if mask is None: raise GrbError("This will create a dense matrix. 
Please provide a mask or indices.") # Use mask to build an iso-valued Matrix - result = impl.apply(BinaryOp.second, mask, right=tensor) + result = impl.apply(out.dtype, BinaryOp.second, mask, right=tensor) else: if out.ndims == 1: # Vector output result = impl.build_iso_vector_from_indices(out.dtype, *out.shape, row_indices, tensor) @@ -675,7 +609,7 @@ def assign(out: SparseTensor, if tensor.shape != expected_input_shape: raise GrbDimensionMismatch(f"input shape mismatch: {tensor.shape} != {expected_input_shape}") - result = impl.assign(tensor, row_indices, col_indices, *out.shape) + result = impl.assign(out.dtype, tensor, row_indices, col_indices, *out.shape) if mask is not None: result = impl.select_by_mask(result, mask, desc) update(out, result, mask, accum, desc, row_indices=row_indices, col_indices=col_indices) diff --git a/mlir_graphblas/operators.py b/mlir_graphblas/operators.py index f4f8f5e..4ae0e8d 100644 --- a/mlir_graphblas/operators.py +++ b/mlir_graphblas/operators.py @@ -12,7 +12,7 @@ from mlir.dialects import arith from .exceptions import GrbDomainMismatch -from .types import DType, BOOL, INT8, INT64, FP64 +from .types import DType, BOOL, INT8, INT64, FP64, cast, find_common_dtype from .utils import CmpFPredicate, CmpIPredicate __all__ = ["UnaryOp", "BinaryOp", "IndexUnaryOp", "SelectOp", "Monoid", "Semiring"] @@ -72,8 +72,7 @@ def __init__(self, func, *, input=None, output=None): self.input = input # Validate output if output is not None: - if type(output) is not int: - assert output in {bool, int, float} + assert output in {bool, int} self.output = output @classmethod @@ -86,31 +85,30 @@ def _register(cls, func=None, **kwargs): super()._register(op) return op - def validate_input(self, input_val): + def validate_input(self, input_type): if self.input is None: return - val_dtype = self._dtype_of(input_val) - if self.input is bool and val_dtype not in {BOOL, INT8}: - raise GrbDomainMismatch("input must be boolean type") - elif self.input is int and not val_dtype.is_int(): - raise GrbDomainMismatch("input must be int type") - elif self.input is float and not val_dtype.is_float(): - raise GrbDomainMismatch("input must be float type") - - def get_output_type(self, left_input_dtype, right_input_dtype=None): + + if self.input is bool: + if input_type not in {BOOL, INT8}: + raise GrbDomainMismatch("input must be boolean type") + elif self.input is int: + if not input_type.is_int(): + raise GrbDomainMismatch("input must be int type") + elif self.input is float: + if not input_type.is_float(): + raise GrbDomainMismatch("input must be float type") + + def validate_output(self, output_type): if self.output is None: - if right_input_dtype is None: - return left_input_dtype - if left_input_dtype != right_input_dtype: - raise TypeError(f"Unable to infer output type from {left_input_dtype} and {right_input_dtype}") - return left_input_dtype - elif self.output == 0: - return left_input_dtype - elif self.output == 1: - if right_input_dtype is None: - raise TypeError("No type provided for expected 2nd input argument") - return right_input_dtype - return self._type_convert[self.output] + return + + if self.output is bool: + if output_type != BOOL: + raise GrbDomainMismatch("output must be BOOL type") + elif self.output is int: + if output_type != INT64: + raise GrbDomainMismatch("output must be INT64 type") class UnaryOp(_FuncOp): @@ -125,9 +123,18 @@ class UnaryOp(_FuncOp): def name_of_op(x, dtype): return ... 
""" - def __call__(self, x): - self.validate_input(x) - return self.func(x, self._dtype_of(x)) + def __call__(self, out_type: DType, x): + x, xtype = self._validate(out_type, x) + return self.func(x, xtype) + + def _validate(self, out_type, x): + self.validate_output(out_type) + xtype = self._dtype_of(x) + if self.output is None: + x = cast(x, xtype, out_type) + xtype = out_type + self.validate_input(xtype) + return x, xtype class BinaryOp(_FuncOp): @@ -142,21 +149,25 @@ class BinaryOp(_FuncOp): def name_of_op(x, y, input_dtype): return ... """ - def __call__(self, x, y): - dtype = self._dtype_of(x) - dtype2 = self._dtype_of(y) - if self.output == 0: - self.validate_input(x) - return self.func(x, y, dtype) - if self.output == 1: - self.validate_input(y) - return self.func(x, y, dtype2) - # If we reached this point, inputs must have the same dtype - if dtype is not dtype2: - raise TypeError(f"Types must match, {dtype} != {dtype2}") - self.validate_input(x) + def __call__(self, out_type: DType, x, y): + x, y, dtype = self._validate(out_type, x, y) return self.func(x, y, dtype) + def _validate(self, out_type, x, y): + self.validate_output(out_type) + xtype = self._dtype_of(x) + ytype = self._dtype_of(y) + if self.output is None: + x = cast(x, xtype, out_type) + y = cast(y, ytype, out_type) + dtype = out_type + else: + dtype = find_common_dtype(xtype, ytype) + x = cast(x, xtype, dtype) + y = cast(y, ytype, dtype) + self.validate_input(dtype) + return x, y, dtype + class IndexUnaryOp(_FuncOp): """ @@ -184,20 +195,17 @@ def __init__(self, func, *, input=None, output=None, thunk_as_index=False): super().__init__(func, input=input, output=output) self.thunk_as_index = thunk_as_index - def __call__(self, val, row, col, thunk): - val_dtype = self._dtype_of(val) - self.validate_input(val) + def __call__(self, out_type: DType, val, row, col, thunk): if self.thunk_as_index: # Ensure thunk is an index thunk_type_str = self._mlirtype_of(thunk) if thunk_type_str != "index": raise GrbDomainMismatch("thunk must be index type") + val, dtype = UnaryOp._validate(self, out_type, val) else: - # Check that thunk dtype matches value dtype - thunk_dtype = self._dtype_of(thunk) - if val_dtype != thunk_dtype: - raise GrbDomainMismatch(f"Thunk dtype must match value dtype: {thunk_dtype} != {val_dtype}") - return self.func(val, row, col, thunk, val_dtype) + # Thunk dtype should make val dtype + val, thunk, dtype = BinaryOp._validate(self, out_type, val, thunk) + return self.func(val, row, col, thunk, dtype) class SelectOp(IndexUnaryOp): @@ -437,12 +445,12 @@ def oneb(x, y, dtype): BinaryOp.pair = BinaryOp.oneb -@BinaryOp._register(output=0) # dtype matches x +@BinaryOp._register def first(x, y, dtype): return x -@BinaryOp._register(output=1) # dtype matches y +@BinaryOp._register def second(x, y, dtype): return y @@ -489,59 +497,50 @@ def div(x, y, dtype): @SelectOp._register(output=bool, thunk_as_index=True) def tril(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) row_plus = arith.AddIOp(row, thunk) return arith.CmpIOp(CmpIPredicate.sle.build(), col, row_plus) @SelectOp._register(output=bool, thunk_as_index=True) def triu(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) row_plus = arith.AddIOp(row, thunk) return arith.CmpIOp(CmpIPredicate.sge.build(), col, row_plus) @SelectOp._register(output=bool, thunk_as_index=True) def diag(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) row_plus = arith.AddIOp(row, thunk) return 
arith.CmpIOp(CmpIPredicate.eq.build(), col, row_plus) @SelectOp._register(output=bool, thunk_as_index=True) def offdiag(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) row_plus = arith.AddIOp(row, thunk) return arith.CmpIOp(CmpIPredicate.ne.build(), col, row_plus) @SelectOp._register(output=bool, thunk_as_index=True) def colle(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) return arith.CmpIOp(CmpIPredicate.sle.build(), col, thunk) @SelectOp._register(output=bool, thunk_as_index=True) def colgt(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) return arith.CmpIOp(CmpIPredicate.sgt.build(), col, thunk) @SelectOp._register(output=bool, thunk_as_index=True) def rowle(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) return arith.CmpIOp(CmpIPredicate.sle.build(), row, thunk) @SelectOp._register(output=bool, thunk_as_index=True) def rowgt(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) return arith.CmpIOp(CmpIPredicate.sgt.build(), row, thunk) @SelectOp._register(output=bool) def valueeq(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) if val_dtype.is_float(): return arith.CmpFOp(CmpFPredicate.oeq.build(), val, thunk) else: @@ -550,7 +549,6 @@ def valueeq(val, row, col, thunk, val_dtype): @SelectOp._register(output=bool) def valuene(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) if val_dtype.is_float(): return arith.CmpFOp(CmpFPredicate.one.build(), val, thunk) else: @@ -559,7 +557,6 @@ def valuene(val, row, col, thunk, val_dtype): @SelectOp._register(output=bool) def valuelt(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) if val_dtype.is_float(): return arith.CmpFOp(CmpFPredicate.olt.build(), val, thunk) else: @@ -568,7 +565,6 @@ def valuelt(val, row, col, thunk, val_dtype): @SelectOp._register(output=bool) def valuele(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) if val_dtype.is_float(): return arith.CmpFOp(CmpFPredicate.ole.build(), val, thunk) else: @@ -577,7 +573,6 @@ def valuele(val, row, col, thunk, val_dtype): @SelectOp._register(output=bool) def valuegt(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) if val_dtype.is_float(): return arith.CmpFOp(CmpFPredicate.ogt.build(), val, thunk) else: @@ -586,7 +581,6 @@ def valuegt(val, row, col, thunk, val_dtype): @SelectOp._register(output=bool) def valuege(val, row, col, thunk, val_dtype): - i1 = ir.IntegerType.get_signless(1) if val_dtype.is_float(): return arith.CmpFOp(CmpFPredicate.oge.build(), val, thunk) else: diff --git a/mlir_graphblas/tensor.py b/mlir_graphblas/tensor.py index e8cfc16..f4b7ab2 100644 --- a/mlir_graphblas/tensor.py +++ b/mlir_graphblas/tensor.py @@ -74,7 +74,7 @@ def dup(self): f"""Returns a copy of the {self.baseclass}""" from . 
import implementations as impl
-        return impl.dup(self, intermediate=False)
+        return impl.dup(self.dtype, self, intermediate=False)
 
     @property
     def rtt(self):
@@ -227,6 +227,7 @@ def clear(self):
         self._obj = None
 
     def dup(self):
+        # Don't pass value to `new` to avoid `set_element` logic
         s = Scalar.new(self.dtype)
         s._obj = self._obj
         return s
diff --git a/mlir_graphblas/tests/test_operations.py b/mlir_graphblas/tests/test_operations.py
index 78241c1..bc49dc8 100644
--- a/mlir_graphblas/tests/test_operations.py
+++ b/mlir_graphblas/tests/test_operations.py
@@ -59,6 +59,12 @@ def test_transpose_op(mm):
     assert x.is_rowwise()
     matrix_compare(z, xrows, xcols, xvals)
 
+    # Transpose into different dtype
+    z = Matrix.new(INT16, x.shape[1], x.shape[0])
+    operations.transpose(z, x)
+    assert z.dtype == INT16
+    matrix_compare(z, xcols, xrows, xvals.astype(np.int16))
+
 
 def test_transpose_empty(mm):
     # Empty into empty
@@ -77,19 +83,21 @@ def test_transpose_empty(mm):
 
 def test_ewise_add_vec(vs):
     x, y = vs
-    z = Vector.new(x.dtype, x.size())
-    operations.ewise_add(z, BinaryOp.plus, x, y)
-    vector_compare(z, [0, 1, 2, 3], [1., 10., 22., 33.])
+    for typ in (x.dtype, INT32):
+        z = Vector.new(typ, x.size())
+        operations.ewise_add(z, BinaryOp.plus, x, y)
+        vector_compare(z, [0, 1, 2, 3], [1., 10., 22., 33.])
 
 
 def test_ewise_add_mat(ms):
     x, y = ms
-    z = Matrix.new(x.dtype, *x.shape)
-    operations.ewise_add(z, BinaryOp.times, x, y)
-    matrix_compare(z,
-                   [0, 0, 0, 1, 1, 1, 1],
-                   [0, 1, 3, 0, 1, 3, 4],
-                   [10, -20, -60, -3, -4, 40, -5])
+    for typ in (x.dtype, INT32):
+        z = Matrix.new(typ, *x.shape)
+        operations.ewise_add(z, BinaryOp.times, x, y)
+        matrix_compare(z,
+                       [0, 0, 0, 1, 1, 1, 1],
+                       [0, 1, 3, 0, 1, 3, 4],
+                       [10, -20, -60, -3, -4, 40, -5])
 
 
 def test_ewise_add_empty(ms):
@@ -103,25 +111,28 @@ def test_ewise_add_empty(ms):
 
 def test_ewise_add_scalar(ss):
     x, y = ss
-    z = Scalar.new(x.dtype)
-    operations.ewise_add(z, BinaryOp.times, x, y)
-    assert z._obj == x._obj
-    operations.ewise_add(z, BinaryOp.times, x, x)
-    assert z.extract_element() == x.extract_element() ** 2
+    for typ in (x.dtype, INT32):
+        z = Scalar.new(typ)
+        operations.ewise_add(z, BinaryOp.times, x, y)
+        assert z._obj == x._obj
+        operations.ewise_add(z, BinaryOp.times, x, x)
+        assert z.extract_element() == x.extract_element() ** 2
 
 
 def test_ewise_mult_vec(vs):
     x, y = vs
-    z = Vector.new(x.dtype, x.size())
-    operations.ewise_mult(z, BinaryOp.plus, x, y)
-    vector_compare(z, [2, 3], [22., 33.])
+    for typ in (x.dtype, INT32):
+        z = Vector.new(typ, x.size())
+        operations.ewise_mult(z, BinaryOp.plus, x, y)
+        vector_compare(z, [2, 3], [22, 33])
 
 
 def test_ewise_mult_mat(ms):
     x, y = ms
-    z = Matrix.new(x.dtype, *x.shape)
-    operations.ewise_mult(z, BinaryOp.first, x, y)
-    matrix_compare(z, [0, 0], [1, 3], [-1, -2])
+    for typ in (x.dtype, INT32):
+        z = Matrix.new(typ, *x.shape)
+        operations.ewise_mult(z, BinaryOp.first, x, y)
+        matrix_compare(z, [0, 0], [1, 3], [-1, -2])
 
 
 def test_ewise_mult_empty(ms):
@@ -158,22 +169,23 @@ def test_mxm(mm):
     ycol = Matrix.new(y.dtype, *y.shape)
     ycol.build(*y.extract_tuples(), colwise=True)
     expected = [0, 1, 2, 2, 4], [0, 0, 0, 4, 3], [20.9, 16.5, 5.5, 70.4, 13.2]
-    # rowwise @ rowwise
-    z = Matrix.new(x.dtype, x.shape[0], y.shape[1])
-    operations.mxm(z, Semiring.plus_times, x, y)
-    matrix_compare(z, *expected)
-    # rowwise @ colwise
-    z.clear()
-    operations.mxm(z, Semiring.plus_times, x, ycol)
-    matrix_compare(z, *expected)
-    # colwise @ colwise
-    z.clear()
-    operations.mxm(z, Semiring.plus_times, xcol, ycol)
-    matrix_compare(z, *expected)
-    # 
colwise @ rowwise - z.clear() - operations.mxm(z, Semiring.plus_times, xcol, y) - matrix_compare(z, *expected) + for typ in (x.dtype, FP64): + # rowwise @ rowwise + z = Matrix.new(typ, x.shape[0], y.shape[1]) + operations.mxm(z, Semiring.plus_times, x, y) + matrix_compare(z, *expected) + # rowwise @ colwise + z.clear() + operations.mxm(z, Semiring.plus_times, x, ycol) + matrix_compare(z, *expected) + # colwise @ colwise + z.clear() + operations.mxm(z, Semiring.plus_times, xcol, ycol) + matrix_compare(z, *expected) + # colwise @ rowwise + z.clear() + operations.mxm(z, Semiring.plus_times, xcol, y) + matrix_compare(z, *expected) def test_mxm_empty(mm): @@ -188,14 +200,15 @@ def test_mxm_empty(mm): def test_mxv(vs, mm): _, v = vs _, m = mm - z = Vector.new(m.dtype, m.shape[0]) - operations.mxv(z, Semiring.plus_times, m, v) - try: - vector_compare(z, [1, 2, 3, 5], [1., 6., 5., 7.]) - except AssertionError: - # Check for dense return, indicating lack of lex insert fix - vector_compare(z, [0, 1, 2, 3, 4, 5], [0., 1., 6., 5., 0., 7.]) - pytest.xfail("Waiting for lex insert fix") + for typ in (m.dtype, FP64): + z = Vector.new(typ, m.shape[0]) + operations.mxv(z, Semiring.plus_times, m, v) + try: + vector_compare(z, [1, 2, 3, 5], [1., 6., 5., 7.]) + except AssertionError: + # Check for dense return, indicating lack of lex insert fix + vector_compare(z, [0, 1, 2, 3, 4, 5], [0., 1., 6., 5., 0., 7.]) + pytest.xfail("Waiting for lex insert fix") def test_mxv_empty(vs, mm): @@ -211,9 +224,10 @@ def test_mxv_empty(vs, mm): def test_vxm(vs, mm): _, v = vs m, _ = mm - z = Vector.new(m.dtype, m.shape[1]) - operations.vxm(z, Semiring.plus_times, v, m) - vector_compare(z, [0, 1, 3, 5], [8.8, 11., 1.1, 2.2]) + for typ in (m.dtype, FP64): + z = Vector.new(typ, m.shape[1]) + operations.vxm(z, Semiring.plus_times, v, m) + vector_compare(z, [0, 1, 3, 5], [8.8, 11., 1.1, 2.2]) def test_vxm_empty(vs, mm): @@ -230,15 +244,16 @@ def test_apply_mat(ms): x, _ = ms xrows, xcols, xvals = x.extract_tuples() - # UnaryOp.abs - z = Matrix.new(x.dtype, *x.shape) - operations.apply(z, UnaryOp.abs, x) - matrix_compare(z, xrows, xcols, np.abs(xvals)) + for typ in (x.dtype, INT32): + # UnaryOp.abs + z = Matrix.new(typ, *x.shape) + operations.apply(z, UnaryOp.abs, x) + matrix_compare(z, xrows, xcols, np.abs(xvals)) - # BinaryOp.minus left=2 - z2 = Matrix.new(x.dtype, *x.shape) - operations.apply(z2, BinaryOp.minus, x, left=2) - matrix_compare(z2, xrows, xcols, 2 - xvals) + # BinaryOp.minus left=2 + z2 = Matrix.new(x.dtype, *x.shape) + operations.apply(z2, BinaryOp.minus, x, left=2) + matrix_compare(z2, xrows, xcols, 2 - xvals) # BinaryOp.gt right=-2 z3 = Matrix.new(BOOL, *x.shape) @@ -292,44 +307,47 @@ def test_apply_empty(mm): def test_apply_scalar(ss): x, y = ss - z = Scalar.new(x.dtype) - operations.apply(z, UnaryOp.ainv, x) - assert z.extract_element() == -x.extract_element() - operations.apply(z, UnaryOp.ainv, y) - assert z.nvals() == 0 - operations.apply(z, BinaryOp.minus, x, right=64) - assert z.extract_element() == x.extract_element() - 64 + for typ in (x.dtype, INT32): + z = Scalar.new(typ) + operations.apply(z, UnaryOp.ainv, x) + assert z.extract_element() == -x.extract_element() + operations.apply(z, UnaryOp.ainv, y) + assert z.nvals() == 0 + operations.apply(z, BinaryOp.minus, x, right=64) + assert z.extract_element() == x.extract_element() - 64 - with pytest.raises(exceptions.GrbError): - operations.apply(z, IndexUnaryOp.rowindex, x, thunk=0) + with pytest.raises(exceptions.GrbError): + operations.apply(z, 
IndexUnaryOp.rowindex, x, thunk=0) def test_select_vec(vs): x, _ = vs - # Select by index - z = Vector.new(x.dtype, x.size()) - operations.select(z, SelectOp.rowgt, x, 2) - vector_compare(z, [3], [30.]) + for typ in (x.dtype, INT32): + # Select by index + z = Vector.new(typ, x.size()) + operations.select(z, SelectOp.rowgt, x, 2) + vector_compare(z, [3], [30]) - # Select by value - z = Vector.new(x.dtype, x.size()) - operations.select(z, SelectOp.valuegt, x, 10.) - vector_compare(z, [2, 3], [20., 30.]) + # Select by value + z = Vector.new(typ, x.size()) + operations.select(z, SelectOp.valuegt, x, 10.) + vector_compare(z, [2, 3], [20, 30]) def test_select_mat(mm): _, y = mm - z = Matrix.new(y.dtype, *y.shape) - operations.select(z, SelectOp.triu, y, -1) - assert z.is_rowwise() - matrix_compare(z, [0, 1, 1, 2], [4, 0, 4, 3], [6., 1., 8., 2.]) - - # Transposed - z = Matrix.new(y.dtype, y.shape[1], y.shape[0]) - operations.select(z, SelectOp.triu, y, 0, desc=desc.T0) - assert z.is_colwise() - matrix_compare(z, [0, 0, 0], [1, 3, 5], [1., 5., 7.]) + for typ in (y.dtype, INT32): + z = Matrix.new(typ, *y.shape) + operations.select(z, SelectOp.triu, y, -1) + assert z.is_rowwise() + matrix_compare(z, [0, 1, 1, 2], [4, 0, 4, 3], [6, 1, 8, 2]) + + # Transposed + z = Matrix.new(typ, y.shape[1], y.shape[0]) + operations.select(z, SelectOp.triu, y, 0, desc=desc.T0) + assert z.is_colwise() + matrix_compare(z, [0, 0, 0], [1, 3, 5], [1, 5, 7]) def test_select_empty(vs): @@ -349,35 +367,38 @@ def test_select_empty(vs): def test_select_scalar(ss): x, y = ss - z = Scalar.new(x.dtype) - operations.select(z, SelectOp.valuegt, x, 1004) - assert z.nvals() == 0 - operations.select(z, SelectOp.valuegt, x, 4) - assert z.nvals() == 1 - assert z.extract_element() == x.extract_element() - # Scalars are treated as having row=0, col=0 for index purposes - operations.select(z, SelectOp.rowle, x, 0) - assert z.nvals() == 1 - assert z.extract_element() == x.extract_element() + for typ in (x.dtype, INT32): + z = Scalar.new(typ) + operations.select(z, SelectOp.valuegt, x, 1004) + assert z.nvals() == 0 + operations.select(z, SelectOp.valuegt, x, 4) + assert z.nvals() == 1 + assert z.extract_element() == x.extract_element() + # Scalars are treated as having row=0, col=0 for index purposes + operations.select(z, SelectOp.rowle, x, 0) + assert z.nvals() == 1 + assert z.extract_element() == x.extract_element() def test_reduce_rowwise(mm): x, _ = mm - z = Vector.new(x.dtype, x.shape[0]) - operations.reduce_to_vector(z, Monoid.plus, x) - try: - vector_compare(z, [0, 1, 2, 4], [3.3, 3.3, 9.9, 6.6]) - except AssertionError: - # Check for dense return, indicating lack of lex insert fix - vector_compare(z, [0, 1, 2, 3, 4], [3.3, 3.3, 9.9, 0.0, 6.6]) - pytest.xfail("Waiting for lex insert fix") + for typ in (x.dtype, FP64): + z = Vector.new(typ, x.shape[0]) + operations.reduce_to_vector(z, Monoid.plus, x) + try: + vector_compare(z, [0, 1, 2, 4], [3.3, 3.3, 9.9, 6.6]) + except AssertionError: + # Check for dense return, indicating lack of lex insert fix + vector_compare(z, [0, 1, 2, 3, 4], [3.3, 3.3, 9.9, 0.0, 6.6]) + pytest.xfail("Waiting for lex insert fix") def test_reduce_colwise(mm): x, _ = mm - z = Vector.new(x.dtype, x.shape[1]) - operations.reduce_to_vector(z, Monoid.times, x, desc=desc.T0) - vector_compare(z, [0, 1, 2, 3, 5], [4.4, 5.5, 6.6, 3.63, 2.2]) + for typ in (x.dtype, FP64): + z = Vector.new(typ, x.shape[1]) + operations.reduce_to_vector(z, Monoid.times, x, desc=desc.T0) + vector_compare(z, [0, 1, 2, 3, 5], [4.4, 5.5, 6.6, 
3.63, 2.2]) def test_reduce_to_vector_empty(vs, mm): @@ -398,21 +419,23 @@ def test_reduce_to_vector_empty(vs, mm): def test_reduce_scalar_mat(mm): x, _ = mm _, _, xvals = x.extract_tuples() - s = Scalar.new(x.dtype) - operations.reduce_to_scalar(s, Monoid.times, x) - np_assert_allclose(s.extract_element(), functools.reduce(operator.mul, xvals)) + for typ in (x.dtype, FP64): + s = Scalar.new(typ) + operations.reduce_to_scalar(s, Monoid.times, x) + np_assert_allclose(s.extract_element(), functools.reduce(operator.mul, xvals)) - # Verify transpose has no effect on scalar reduction - operations.reduce_to_scalar(s, Monoid.plus, x, desc=desc.T0) - np_assert_allclose(s.extract_element(), functools.reduce(operator.add, xvals)) + # Verify transpose has no effect on scalar reduction + operations.reduce_to_scalar(s, Monoid.plus, x, desc=desc.T0) + np_assert_allclose(s.extract_element(), functools.reduce(operator.add, xvals)) def test_reduce_scalar_vec(vs): x, _ = vs _, xvals = x.extract_tuples() - s = Scalar.new(x.dtype) - operations.reduce_to_scalar(s, Monoid.times, x) - np_assert_allclose(s.extract_element(), functools.reduce(operator.mul, xvals)) + for typ in (x.dtype, INT32, FP64): + s = Scalar.new(typ) + operations.reduce_to_scalar(s, Monoid.times, x) + np_assert_allclose(s.extract_element(), functools.reduce(operator.mul, xvals)) def test_reduce_to_scalar_empty(): @@ -431,75 +454,77 @@ def test_reduce_to_scalar_empty(): def test_extract_vec(vs): x, _ = vs xidx, xvals = x.extract_tuples() - z = Vector.new(x.dtype, 3) - operations.extract(z, x, [0, 1, 3]) - vector_compare(z, [1, 2], [10., 30.]) + for typ in (x.dtype, INT32): + z = Vector.new(typ, 3) + operations.extract(z, x, [0, 1, 3]) + vector_compare(z, [1, 2], [10., 30.]) - # Extract all - z2 = Vector.new(x.dtype, *x.shape) - operations.extract(z2, x, None) - vector_compare(z2, xidx, xvals) + # Extract all + z2 = Vector.new(typ, *x.shape) + operations.extract(z2, x, None) + vector_compare(z2, xidx, xvals) - # Extract duplicates - z3 = Vector.new(x.dtype, 5) - operations.extract(z3, x, [0, 3, 3, 3, 2]) - vector_compare(z3, [1, 2, 3, 4], [30., 30., 30., 20.]) + # Extract duplicates + z3 = Vector.new(typ, 5) + operations.extract(z3, x, [0, 3, 3, 3, 2]) + vector_compare(z3, [1, 2, 3, 4], [30., 30., 30., 20.]) def test_extract_mat(mm): x, _ = mm xrows, xcols, xvals = x.extract_tuples() - - # Extract all rows, all cols - z = Matrix.new(x.dtype, *x.shape) - operations.extract(z, x, None, None) - matrix_compare(z, xrows, xcols, xvals) - - # Extract some rows, some cols - z2 = Matrix.new(x.dtype, 2, 5) - operations.extract(z2, x, [0, 4], [1, 2, 3, 5, 3]) - matrix_compare(z2, [0, 0, 0, 1], [2, 3, 4, 1], [1.1, 2.2, 1.1, 6.6]) - - # Extract some rows, all cols - z3 = Matrix.new(x.dtype, 4, x.shape[1]) - operations.extract(z3, x, [0, 4, 3, 0], None) - matrix_compare(z3, [0, 0, 1, 3, 3], [3, 5, 2, 3, 5], [1.1, 2.2, 6.6, 1.1, 2.2]) - - # Extract all rows, some cols - z4 = Matrix.new(x.dtype, x.shape[0], 5) - operations.extract(z4, x, None, [1, 5, 3, 2, 1]) - matrix_compare(z4, - [0, 0, 1, 2, 2, 4], - [1, 2, 2, 0, 4, 3], - [2.2, 1.1, 3.3, 5.5, 5.5, 6.6]) + for typ in (x.dtype, FP64): + # Extract all rows, all cols + z = Matrix.new(typ, *x.shape) + operations.extract(z, x, None, None) + matrix_compare(z, xrows, xcols, xvals) + + # Extract some rows, some cols + z2 = Matrix.new(typ, 2, 5) + operations.extract(z2, x, [0, 4], [1, 2, 3, 5, 3]) + matrix_compare(z2, [0, 0, 0, 1], [2, 3, 4, 1], [1.1, 2.2, 1.1, 6.6]) + + # Extract some rows, all cols + z3 = 
Matrix.new(typ, 4, x.shape[1]) + operations.extract(z3, x, [0, 4, 3, 0], None) + matrix_compare(z3, [0, 0, 1, 3, 3], [3, 5, 2, 3, 5], [1.1, 2.2, 6.6, 1.1, 2.2]) + + # Extract all rows, some cols + z4 = Matrix.new(typ, x.shape[0], 5) + operations.extract(z4, x, None, [1, 5, 3, 2, 1]) + matrix_compare(z4, + [0, 0, 1, 2, 2, 4], + [1, 2, 2, 0, 4, 3], + [2.2, 1.1, 3.3, 5.5, 5.5, 6.6]) def test_extract_vec_from_mat(mm): x, _ = mm - # Extract partial column - z = Vector.new(x.dtype, 3) - operations.extract(z, x, [0, 1, 4], 2) - vector_compare(z, [2], [6.6]) + for typ in (x.dtype, FP64): + # Extract partial column + z = Vector.new(typ, 3) + operations.extract(z, x, [0, 1, 4], 2) + vector_compare(z, [2], [6.6]) - # Extract full column - z1 = Vector.new(x.dtype, x.shape[0]) - operations.extract(z1, x, None, 3) - vector_compare(z1, [0, 1], [1.1, 3.3]) + # Extract full column + z1 = Vector.new(typ, x.shape[0]) + operations.extract(z1, x, None, 3) + vector_compare(z1, [0, 1], [1.1, 3.3]) - # Extract partial row - z2 = Vector.new(x.dtype, 8) - operations.extract(z2, x, 0, [0, 1, 3, 4, 5, 3, 5, 3]) - vector_compare(z2, [2, 4, 5, 6, 7], [1.1, 2.2, 1.1, 2.2, 1.1]) + # Extract partial row + z2 = Vector.new(typ, 8) + operations.extract(z2, x, 0, [0, 1, 3, 4, 5, 3, 5, 3]) + vector_compare(z2, [2, 4, 5, 6, 7], [1.1, 2.2, 1.1, 2.2, 1.1]) - # Extract full row - z3 = Vector.new(x.dtype, x.shape[1]) - operations.extract(z3, x, 0, None) - vector_compare(z3, [3, 5], [1.1, 2.2]) + # Extract full row + z3 = Vector.new(typ, x.shape[1]) + operations.extract(z3, x, 0, None) + vector_compare(z3, [3, 5], [1.1, 2.2]) - # Extract partial column via transposed input - z3 = Vector.new(x.dtype, 3) - operations.extract(z3, x, 2, [0, 1, 4], desc=desc.T0) - vector_compare(z3, [2], [6.6]) + # Extract partial column via transposed input + z3 = Vector.new(typ, 3) + operations.extract(z3, x, 2, [0, 1, 4], desc=desc.T0) + vector_compare(z3, [2], [6.6]) def test_extract_empty(mm): @@ -519,20 +544,21 @@ def test_extract_empty(mm): def test_assign_vec(vs): x, y = vs - # Assign all - operations.assign(y, x, accum=BinaryOp.plus) - vector_compare(y, [0, 1, 2, 3], [1., 10., 22., 33.]) + # # Assign all + # operations.assign(y, x, accum=BinaryOp.plus) + # vector_compare(y, [0, 1, 2, 3], [1., 10., 22., 33.]) - # Expand - z = Vector.new(x.dtype, 16) - operations.assign(z, x, [1, 3, 13, 10, 2]) - assert z.size() == 16 - vector_compare(z, [3, 10, 13], [10., 30., 20.]) + for typ in (x.dtype, INT32): + # Expand + z = Vector.new(typ, 16) + operations.assign(z, x, [1, 3, 13, 10, 2]) + assert z.size() == 16 + vector_compare(z, [3, 10, 13], [10., 30., 20.]) - # Empty input - a = Vector.new(x.dtype, 3) - operations.assign(z, a, [1, 3, 13]) - vector_compare(z, [10], [30.]) + # Empty input + a = Vector.new(typ, 3) + operations.assign(z, a, [1, 3, 13]) + vector_compare(z, [10], [30.]) def test_assign_mat(ms): @@ -546,29 +572,30 @@ def test_assign_mat(ms): [0, 1, 3, 0, 1, 3, 4], [10, 19, 28, -3, -4, 40, -5]) - # Assign new rows, new cols - z2 = Matrix.new(x.dtype, 4, 8) - operations.assign(z2, x, [3, 0], [0, 3, 4, 1, 7]) - matrix_compare(z2, - [0, 0, 0, 3, 3], - [0, 3, 7, 1, 3], - [-3, -4, -5, -2, -1]) - - # Assign identical rows, new cols - z3 = Matrix.new(x.dtype, x.shape[0], 8) - operations.assign(z3, x, None, [0, 3, 4, 1, 7]) - matrix_compare(z3, - [0, 0, 1, 1, 1], - [1, 3, 0, 3, 7], - [-2, -1, -3, -4, -5]) - - # Assign new rows, identical cols - z4 = Matrix.new(x.dtype, 4, x.shape[1]) - operations.assign(z4, x, [3, 0], None) - matrix_compare(z4, - [0, 0, 
0, 3, 3], - [0, 1, 4, 1, 3], - [-3, -4, -5, -1, -2]) + for typ in (x.dtype, INT32): + # Assign new rows, new cols + z2 = Matrix.new(typ, 4, 8) + operations.assign(z2, x, [3, 0], [0, 3, 4, 1, 7]) + matrix_compare(z2, + [0, 0, 0, 3, 3], + [0, 3, 7, 1, 3], + [-3, -4, -5, -2, -1]) + + # Assign identical rows, new cols + z3 = Matrix.new(typ, x.shape[0], 8) + operations.assign(z3, x, None, [0, 3, 4, 1, 7]) + matrix_compare(z3, + [0, 0, 1, 1, 1], + [1, 3, 0, 3, 7], + [-2, -1, -3, -4, -5]) + + # Assign new rows, identical cols + z4 = Matrix.new(typ, 4, x.shape[1]) + operations.assign(z4, x, [3, 0], None) + matrix_compare(z4, + [0, 0, 0, 3, 3], + [0, 1, 4, 1, 3], + [-3, -4, -5, -1, -2]) def test_assign_vec_to_mat(ms): @@ -584,41 +611,42 @@ def test_assign_vec_to_mat(ms): [0, 1, 2, 3, 4, 0, 1, 4], [5, -1, 4, 1, 2, -3, -4, -5]) - # Assign row with new indices - z2 = x.dup() - r1 = Vector.new(x.dtype, 3) - r1.build([0, 2], [100, 150]) - operations.assign(z2, r1, 0, [4, 0, 2], accum=BinaryOp.plus) - matrix_compare(z2, - [0, 0, 0, 0, 1, 1, 1], - [1, 2, 3, 4, 0, 1, 4], - [-1, 150, -2, 100, -3, -4, -5]) - - # Assign col with identical indices - z3 = x.dup() - c0 = Vector.new(x.dtype, x.shape[0]) - c0.build([0, 1], [97, 99]) - operations.assign(z3, c0, None, 3, accum=BinaryOp.plus) - matrix_compare(z3, - [0, 0, 1, 1, 1, 1], - [1, 3, 0, 1, 3, 4], - [-1, 95, -3, -4, 99, -5]) - - # Assign col with new indices - z4 = x.dup() - c1 = Vector.new(x.dtype, 1) - c1.build([0], [101]) - operations.assign(z4, c1, [1], 3, accum=BinaryOp.plus) - matrix_compare(z4, - [0, 0, 1, 1, 1, 1], - [1, 3, 0, 1, 3, 4], - [-1, -2, -3, -4, 101, -5]) - - # Empty input col - z5 = x.dup() - a = Vector.new(x.dtype, 2) - operations.assign(z5, a, [0, 1], 1) - matrix_compare(z5, [0, 1, 1], [3, 0, 4], [-2, -3, -5]) + for typ in (x.dtype, INT32): + # Assign row with new indices + z2 = x.dup() + r1 = Vector.new(typ, 3) + r1.build([0, 2], [100, 150]) + operations.assign(z2, r1, 0, [4, 0, 2], accum=BinaryOp.plus) + matrix_compare(z2, + [0, 0, 0, 0, 1, 1, 1], + [1, 2, 3, 4, 0, 1, 4], + [-1, 150, -2, 100, -3, -4, -5]) + + # Assign col with identical indices + z3 = x.dup() + c0 = Vector.new(typ, x.shape[0]) + c0.build([0, 1], [97, 99]) + operations.assign(z3, c0, None, 3, accum=BinaryOp.plus) + matrix_compare(z3, + [0, 0, 1, 1, 1, 1], + [1, 3, 0, 1, 3, 4], + [-1, 95, -3, -4, 99, -5]) + + # Assign col with new indices + z4 = x.dup() + c1 = Vector.new(typ, 1) + c1.build([0], [101]) + operations.assign(z4, c1, [1], 3, accum=BinaryOp.plus) + matrix_compare(z4, + [0, 0, 1, 1, 1, 1], + [1, 3, 0, 1, 3, 4], + [-1, -2, -3, -4, 101, -5]) + + # Empty input col + z5 = x.dup() + a = Vector.new(typ, 2) + operations.assign(z5, a, [0, 1], 1) + matrix_compare(z5, [0, 1, 1], [3, 0, 4], [-2, -3, -5]) def test_assign_scalar_to_vec(vs): diff --git a/mlir_graphblas/tests/test_types.py b/mlir_graphblas/tests/test_types.py index 1facb33..0f66623 100644 --- a/mlir_graphblas/tests/test_types.py +++ b/mlir_graphblas/tests/test_types.py @@ -1,6 +1,9 @@ import pytest import numpy as np -from ..types import DType, BOOL, INT8, INT16, INT32, INT64, FP32, FP64, RankedTensorType +from ..types import ( + DType, BOOL, INT8, INT16, INT32, INT64, FP32, FP64, RankedTensorType, + find_common_dtype +) from mlir.dialects.sparse_tensor import DimLevelType from mlir import ir @@ -53,3 +56,15 @@ def test_rtt(): 'dimOrdering = affine_map<(d0, d1) -> (d1, d0)> ' '}>>' ) + + +def test_find_common_dtype(): + assert find_common_dtype(FP64, FP64) == FP64 + assert find_common_dtype(FP64, FP32) == 
FP64 + assert find_common_dtype(FP64, INT8) == FP64 + assert find_common_dtype(INT8, FP64) == FP64 + assert find_common_dtype(INT16, FP32) == FP32 + assert find_common_dtype(INT16, INT16) == INT16 + assert find_common_dtype(INT8, INT64) == INT64 + assert find_common_dtype(INT32, BOOL) == INT32 + assert find_common_dtype(BOOL, BOOL) == BOOL diff --git a/mlir_graphblas/tests/utils.py b/mlir_graphblas/tests/utils.py index 077a850..fe9bdaf 100644 --- a/mlir_graphblas/tests/utils.py +++ b/mlir_graphblas/tests/utils.py @@ -6,6 +6,7 @@ def vector_compare(vec, i, v): assert vec.ndims == 1 idx, vals = vec.extract_tuples() + assert vals.dtype == vec.dtype.np_type np_assert_equal(idx, i) np_assert_allclose(vals, v) @@ -13,6 +14,7 @@ def vector_compare(vec, i, v): def matrix_compare(mat, r, c, v): assert mat.ndims == 2 rows, cols, vals = mat.extract_tuples() + assert vals.dtype == mat.dtype.np_type if mat.is_rowwise(): sort_order = np.argsort(r) else: diff --git a/mlir_graphblas/types.py b/mlir_graphblas/types.py index 2f9c4eb..7080020 100644 --- a/mlir_graphblas/types.py +++ b/mlir_graphblas/types.py @@ -32,6 +32,9 @@ def is_float(self): def is_int(self): return self.mlir_name[0] == 'i' + def bitwidth(self): + return int(self.mlir_name[1:]) + def build_max(self): if self.is_float(): return arith.ConstantOp(self.build_mlir_type(), math.inf) @@ -109,6 +112,7 @@ def from_np(cls, nptype): FP32 = DType._register('FP32', 'f32', np.float32, ir.F32Type.get) FP64 = DType._register('FP64', 'f64', np.float64, ir.F64Type.get) # TODO: should we handle complex types? +# TODO: should we handle unsigned ints? If so, utility functions need to change. class RankedTensorType: @@ -138,3 +142,36 @@ def copy(self, dtype=None, sparsity=None, ordering=None): if ordering is None: ordering = self.ordering return RankedTensorType(dtype, sparsity, ordering) + + +def find_common_dtype(left: DType, right: DType): + """ + Follow C convention for unifying types + """ + if left == right: + return left + + bits = max(left.bitwidth(), right.bitwidth()) + # There are no unsigned types; if that changes, this will need to change as well + if left.is_float() or right.is_float(): + return DType.from_gb(f"FP{bits}") + return DType.from_gb(f"INT{bits}") + + +def cast(val, intype: DType, outtype: DType): + if intype == outtype: + return val + + out_mlir = outtype.build_mlir_type() + if outtype.is_float(): + if intype.is_int(): + return arith.SIToFPOp(out_mlir, val) + if intype.bitwidth() < outtype.bitwidth(): + return arith.ExtFOp(out_mlir, val) + return arith.TruncFOp(out_mlir, val) + # Output is int + if intype.is_float(): + return arith.FPToSIOp(out_mlir, val) + if intype.bitwidth() < outtype.bitwidth(): + return arith.ExtSIOp(out_mlir, val) + return arith.TruncIOp(out_mlir, val)
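
A note on the new promotion rules, with an illustrative sketch (not part of the patch): `find_common_dtype` follows C-style promotion, which the plain-Python model below mimics. Dtypes are modeled here as (kind, bits) pairs instead of the real `DType` registry, and BOOL is assumed to behave as a 1-bit integer for this sketch; the real bitwidth is derived from `DType.mlir_name`.

    # Minimal model of the C-style promotion used by find_common_dtype.
    # ('i', 8) stands in for INT8, ('f', 32) for FP32, and so on; these
    # tuples are illustrative stand-ins, not the library's DType objects.
    def common_dtype(left, right):
        if left == right:
            return left
        bits = max(left[1], right[1])
        # Any float operand makes the result float. Float widths are only
        # 32/64, so `bits` always names a real FP type; all ints are signed,
        # so no unsigned corner cases arise.
        if left[0] == 'f' or right[0] == 'f':
            return ('f', bits)
        return ('i', bits)

    # Mirrors the cases covered by test_find_common_dtype:
    assert common_dtype(('f', 64), ('f', 32)) == ('f', 64)  # FP64, FP32 -> FP64
    assert common_dtype(('i', 16), ('f', 32)) == ('f', 32)  # INT16, FP32 -> FP32
    assert common_dtype(('i', 8), ('i', 64)) == ('i', 64)   # INT8, INT64 -> INT64
    assert common_dtype(('i', 1), ('i', 32)) == ('i', 32)   # BOOL, INT32 -> INT32

Once the common type is chosen, `cast` picks the arith op from the direction of the conversion: `SIToFPOp` for int-to-float, `FPToSIOp` for float-to-int, `ExtFOp`/`ExtSIOp` to widen, and `TruncFOp`/`TruncIOp` to narrow.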