Skip to content

Data types no longer need to match #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
399 changes: 213 additions & 186 deletions mlir_graphblas/implementations.py

Large diffs are not rendered by default.

108 changes: 21 additions & 87 deletions mlir_graphblas/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def update(output: SparseObject,
if accum is None or output._obj is None:
output.set_element(tensor.extract_element())
else:
output._obj = impl.ewise_add(accum, output, tensor)._obj
output._obj = impl.ewise_add(output.dtype, accum, output, tensor)._obj
return

if not isinstance(output, SparseTensor):
Expand All @@ -80,7 +80,7 @@ def update(output: SparseObject,
else: # mask=Y, accum=N, replace=N
# Apply inverted mask, then eWiseAdd
output._replace(impl.select_by_mask(output, mask, desc_inverted))
result = impl.ewise_add(BinaryOp.oneb, output, tensor)
result = impl.ewise_add(output.dtype, BinaryOp.oneb, output, tensor)
else:
if mask is None: # mask=N, accum=N, replace=?, w/ indices
# Drop indices in output, then eWiseAdd
Expand All @@ -106,22 +106,22 @@ def update(output: SparseObject,
# Select the row/col indices in the mask, apply it inverted to the output, then eWiseAdd
new_mask = impl.select_by_indices(mask, row_indices, col_indices)
output._replace(impl.select_by_mask(output, new_mask, desc_inverted))
result = impl.ewise_add(BinaryOp.oneb, output, tensor)
result = impl.ewise_add(output.dtype, BinaryOp.oneb, output, tensor)
elif mask is None or not desc.replace:
# eWiseAdd using accum
result = impl.ewise_add(accum, output, tensor)
result = impl.ewise_add(output.dtype, accum, output, tensor)
else:
# Mask the output, then perform eWiseAdd using accum
output._replace(impl.select_by_mask(output, mask, desc))
result = impl.ewise_add(accum, output, tensor)
result = impl.ewise_add(output.dtype, accum, output, tensor)

if result is output:
# This can happen if empty tensors are used as input
return output

# If not an intermediate result, make a copy
if not result._intermediate_result:
result = impl.dup(result)
# If not an intermediate result or wrong dtype, make a copy
if not result._intermediate_result or result.dtype != output.dtype:
result = impl.dup(output.dtype, result)

output._replace(result)

Expand All @@ -132,10 +132,6 @@ def transpose(out: Matrix,
mask: Optional[SparseTensor] = None,
accum: Optional[BinaryOp] = None,
desc: Descriptor = NULL_DESC):
# Verify dtypes
if out.dtype != tensor.dtype:
raise GrbDomainMismatch(f"output type must be {tensor.dtype}, not {out.dtype}")

# Apply descriptor transpose
if tensor.ndims != 2:
raise TypeError(f"transpose requires Matrix, not {type(tensor)}")
Expand Down Expand Up @@ -171,14 +167,6 @@ def ewise_add(out: SparseTensor,
if type(op) is not BinaryOp:
raise TypeError(f"op must be BinaryOp, Monoid, or Semiring")

# Verify dtypes
if op.output is not None and type(op.output) is not int:
raise GrbDomainMismatch("op must return same type as inputs with ewise_add")
if left.dtype != right.dtype:
raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
if out.dtype != left.dtype:
raise GrbDomainMismatch(f"output type must be {left.dtype}, not {out.dtype}")

# Apply transposes
if desc.transpose0 and left.ndims == 2:
left = TransposedMatrix.wrap(left)
Expand All @@ -195,7 +183,7 @@ def ewise_add(out: SparseTensor,
left = impl.select_by_mask(left, mask, desc)
right = impl.select_by_mask(right, mask, desc)

result = impl.ewise_add(op, left, right)
result = impl.ewise_add(out.dtype, op, left, right)
update(out, result, mask, accum, desc)


Expand All @@ -214,13 +202,6 @@ def ewise_mult(out: SparseTensor,
else:
raise TypeError(f"op must be BinaryOp, Monoid, or Semiring")

# Verify dtypes
if left.dtype != right.dtype:
raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
required_out_dtype = op.get_output_type(left.dtype, right.dtype)
if out.dtype != required_out_dtype:
raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}")

# Apply transposes
if desc.transpose0 and left.ndims == 2:
left = TransposedMatrix.wrap(left)
Expand All @@ -237,7 +218,7 @@ def ewise_mult(out: SparseTensor,
# Only need to apply mask to one of the inputs
left = impl.select_by_mask(left, mask, desc)

result = impl.ewise_mult(op, left, right)
result = impl.ewise_mult(out.dtype, op, left, right)
update(out, result, mask, accum, desc)


Expand All @@ -253,13 +234,6 @@ def mxm(out: Matrix,
if type(op) is not Semiring:
raise TypeError(f"op must be Semiring, not {type(op)}")

# Verify dtypes
if left.dtype != right.dtype:
raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
required_out_dtype = op.binop.get_output_type(left.dtype, right.dtype)
if out.dtype != required_out_dtype:
raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}")

# Apply transposes
if left.ndims != right.ndims != 2:
raise GrbDimensionMismatch("mxm requires rank 2 tensors")
Expand All @@ -285,7 +259,7 @@ def mxm(out: Matrix,
right = impl.flip_layout(right)

# TODO: apply the mask during the computation, not at the end
result = impl.mxm(op, left, right)
result = impl.mxm(out.dtype, op, left, right)
if mask is not None:
result = impl.select_by_mask(result, mask, desc)
update(out, result, mask, accum, desc)
Expand All @@ -303,13 +277,6 @@ def mxv(out: Vector,
if type(op) is not Semiring:
raise TypeError(f"op must be Semiring, not {type(op)}")

# Verify dtypes
if left.dtype != right.dtype:
raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
required_out_dtype = op.binop.get_output_type(left.dtype, right.dtype)
if out.dtype != required_out_dtype:
raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}")

# Apply transpose
if left.ndims != 2:
raise GrbDimensionMismatch("mxv requires matrix as first input")
Expand All @@ -325,7 +292,7 @@ def mxv(out: Vector,
raise GrbDimensionMismatch(f"output size should be {left.shape[0]} not {out.shape[0]}")

# TODO: apply the mask during the computation, not at the end
result = impl.mxv(op, left, right)
result = impl.mxv(out.dtype, op, left, right)
if mask is not None:
result = impl.select_by_mask(result, mask, desc)
update(out, result, mask, accum, desc)
Expand All @@ -343,13 +310,6 @@ def vxm(out: Vector,
if type(op) is not Semiring:
raise TypeError(f"op must be Semiring, not {type(op)}")

# Verify dtypes
if left.dtype != right.dtype:
raise GrbDomainMismatch(f"inputs must have same dtype: {left.dtype} != {right.dtype}")
required_out_dtype = op.binop.get_output_type(left.dtype, right.dtype)
if out.dtype != required_out_dtype:
raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}")

# Apply transpose
if right.ndims != 2:
raise GrbDimensionMismatch("vxm requires matrix as second input")
Expand All @@ -365,7 +325,7 @@ def vxm(out: Vector,
raise GrbDimensionMismatch(f"output size should be {right.shape[1]} not {out.shape[0]}")

# TODO: apply the mask during the computation, not at the end
result = impl.vxm(op, left, right)
result = impl.vxm(out.dtype, op, left, right)
if mask is not None:
result = impl.select_by_mask(result, mask, desc)
update(out, result, mask, accum, desc)
Expand All @@ -386,7 +346,6 @@ def apply(out: SparseTensor,
if optype is UnaryOp:
if thunk is not None or left is not None or right is not None:
raise TypeError("UnaryOp does not accept thunk, left, or right")
required_out_dtype = op.get_output_type(tensor.dtype)
elif optype is BinaryOp:
if thunk is not None:
raise TypeError("BinaryOp accepts left or thing, not thunk")
Expand All @@ -396,23 +355,16 @@ def apply(out: SparseTensor,
raise TypeError("Cannot provide both left and right")
if left is not None:
left = ensure_scalar_of_type(left, tensor.dtype)
required_out_dtype = op.get_output_type(left.dtype, tensor.dtype)
else:
right = ensure_scalar_of_type(right, tensor.dtype)
required_out_dtype = op.get_output_type(tensor.dtype, right.dtype)
elif optype is IndexUnaryOp:
if left is not None or right is not None:
raise TypeError("IndexUnaryOp accepts thunk, not left or right")
thunk_dtype = INT64 if op.thunk_as_index else tensor.dtype
thunk = ensure_scalar_of_type(thunk, thunk_dtype)
required_out_dtype = op.get_output_type(tensor.dtype, thunk.dtype)
else:
raise TypeError(f"op must be UnaryOp, BinaryOp, or IndexUnaryOp, not {type(op)}")

# Verify dtype
if out.dtype != required_out_dtype:
raise GrbDomainMismatch(f"output type must be {required_out_dtype}, not {out.dtype}")

# Apply transpose
if desc.transpose0 and tensor.ndims == 2:
tensor = TransposedMatrix.wrap(tensor)
Expand All @@ -433,9 +385,9 @@ def apply(out: SparseTensor,
and desc is NULL_DESC
and not tensor._intermediate_result
):
impl.apply(op, tensor, left, right, None, inplace=True)
impl.apply(out.dtype, op, tensor, left, right, None, inplace=True)
else:
result = impl.apply(op, tensor, left, right, thunk)
result = impl.apply(out.dtype, op, tensor, left, right, thunk)
update(out, result, mask, accum, desc)


Expand All @@ -452,8 +404,6 @@ def select(out: SparseTensor,
raise TypeError(f"op must be SelectOp, not {type(op)}")

# Verify dtypes
if out.dtype != tensor.dtype:
raise GrbDomainMismatch(f"output dtype must match input dtype: {out.dtype} != {tensor.dtype}")
thunk_dtype = INT64 if op.thunk_as_index else tensor.dtype
thunk = ensure_scalar_of_type(thunk, thunk_dtype)

Expand All @@ -468,7 +418,7 @@ def select(out: SparseTensor,
if mask is not None:
tensor = impl.select_by_mask(tensor, mask, desc)

result = impl.select(op, tensor, thunk)
result = impl.select(out.dtype, op, tensor, thunk)
update(out, result, mask, accum, desc)


Expand All @@ -483,10 +433,6 @@ def reduce_to_vector(out: Vector,
if type(op) is not Monoid:
raise TypeError(f"op must be Monoid, not {type(op)}")

# Verify dtypes
if out.dtype != tensor.dtype:
raise GrbDomainMismatch(f"output dtype must match input dtype: {out.dtype} != {tensor.dtype}")

# Apply transpose
if tensor.ndims != 2:
raise GrbDimensionMismatch("reduce_to_vector requires matrix input")
Expand All @@ -500,7 +446,7 @@ def reduce_to_vector(out: Vector,
raise GrbDimensionMismatch(f"output size should be {tensor.shape[0]} not {out.shape[0]}")

# TODO: apply the mask during the computation, not at the end
result = impl.reduce_to_vector(op, tensor)
result = impl.reduce_to_vector(out.dtype, op, tensor)
if mask is not None:
result = impl.select_by_mask(result, mask, desc)
update(out, result, mask, accum, desc)
Expand All @@ -516,15 +462,11 @@ def reduce_to_scalar(out: Scalar,
if type(op) is not Monoid:
raise TypeError(f"op must be Monoid, not {type(op)}")

# Verify dtypes
if out.dtype != tensor.dtype:
raise GrbDomainMismatch(f"output dtype must match input dtype: {out.dtype} != {tensor.dtype}")

# Compare shapes
if out.ndims != 0:
raise GrbDimensionMismatch("reduce_to_scalar requires scalar output")

result = impl.reduce_to_scalar(op, tensor)
result = impl.reduce_to_scalar(out.dtype, op, tensor)
update(out, result, accum=accum, desc=desc)


Expand All @@ -539,10 +481,6 @@ def extract(out: SparseTensor,
"""
Setting row_indices or col_indices to `None` is the equivalent of GrB_ALL
"""
# Verify dtypes
if out.dtype != tensor.dtype:
raise GrbDomainMismatch(f"output must have same dtype as input: {out.dtype} != {tensor.dtype}")

# Apply transpose
if desc.transpose0 and tensor.ndims == 2:
tensor = TransposedMatrix.wrap(tensor)
Expand Down Expand Up @@ -587,7 +525,7 @@ def extract(out: SparseTensor,
if out.shape != expected_out_shape:
raise GrbDimensionMismatch(f"output shape mismatch: {out.shape} != {expected_out_shape}")

result = impl.extract(tensor, row_indices, col_indices, row_size, col_size)
result = impl.extract(out.dtype, tensor, row_indices, col_indices, row_size, col_size)
if mask is not None:
result = impl.select_by_mask(result, mask, desc)
update(out, result, mask, accum, desc)
Expand All @@ -610,10 +548,6 @@ def assign(out: SparseTensor,
raise TypeError(f"tensor must be a SparseObject or Python scalar, not {type(tensor)}")
tensor = ensure_scalar_of_type(tensor, out.dtype)

# Verify dtypes
if out.dtype != tensor.dtype:
raise GrbDomainMismatch(f"output must have same dtype as input: {out.dtype} != {tensor.dtype}")

# Apply transpose
if desc.transpose0 and tensor.ndims == 2:
tensor = TransposedMatrix.wrap(tensor)
Expand Down Expand Up @@ -653,7 +587,7 @@ def assign(out: SparseTensor,
if mask is None:
raise GrbError("This will create a dense matrix. Please provide a mask or indices.")
# Use mask to build an iso-valued Matrix
result = impl.apply(BinaryOp.second, mask, right=tensor)
result = impl.apply(out.dtype, BinaryOp.second, mask, right=tensor)
else:
if out.ndims == 1: # Vector output
result = impl.build_iso_vector_from_indices(out.dtype, *out.shape, row_indices, tensor)
Expand All @@ -675,7 +609,7 @@ def assign(out: SparseTensor,
if tensor.shape != expected_input_shape:
raise GrbDimensionMismatch(f"input shape mismatch: {tensor.shape} != {expected_input_shape}")

result = impl.assign(tensor, row_indices, col_indices, *out.shape)
result = impl.assign(out.dtype, tensor, row_indices, col_indices, *out.shape)
if mask is not None:
result = impl.select_by_mask(result, mask, desc)
update(out, result, mask, accum, desc, row_indices=row_indices, col_indices=col_indices)
Loading