Commit 9b0feb0

Fix for Type Check in Quantized CPU op_dequantize (#13174)
### Summary

When a model is quantized without delegating to a specific backend, the exported model relies on the operator library in `kernels/quantized/cpu/`. The essential operation of `op_dequantize` is `out = (in - offset) * scale`, where the offset is an integer type.

The offset is initially assumed to be a `uint64_t` (see [here](https://github.com/pytorch/executorch/blob/a44e4aca7cddf91e8ed7282a70d6c40493a50883/kernels/quantized/cpu/op_dequantize.cpp#L426)), but when it is used in the operation above it is cast down to a `uint32_t` (see [here](https://github.com/pytorch/executorch/blob/a44e4aca7cddf91e8ed7282a70d6c40493a50883/kernels/quantized/cpu/op_dequantize.cpp#L463)). The implicit assumption seems to be that the quantization offset is a `uint32_t` value and that the `uint64_t` declaration is simply a safeguard for future-proofing. In any event, the type check for the offset should allow the offset to be either `uint32_t` or `uint64_t`. This PR makes that change.

### Test plan

Tested with MobileNet V2 on the Arm backend. The quantized model runner initially crashed due to this check only allowing the offset to be `uint64_t`. On inspection, none of the offset values were larger than `UINT32_MAX`, so it should be safe to permit `uint32_t` offsets. With this change, the MobileNet V2 runner was able to complete.
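The narrowing described in the summary can be illustrated with a minimal sketch. This is not the ExecuTorch kernel: `fits_in_int32` and `dequantize_channel` are hypothetical helpers, and the sketch assumes the usual PyTorch convention that `ScalarType::Int` and `ScalarType::Long` (the names used in the diff) denote 32-bit and 64-bit signed integers, so it uses `int32_t` and `int64_t`.

```cpp
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper: true when a 64-bit offset survives narrowing to 32 bits.
inline bool fits_in_int32(int64_t v) {
  return v >= std::numeric_limits<int32_t>::min() &&
         v <= std::numeric_limits<int32_t>::max();
}

// Hypothetical helper: applies out = (in - zero_point) * scale to one channel.
// The offset arrives as 64-bit but is narrowed to 32-bit before use, which
// mirrors the cast described in the summary.
inline std::vector<float> dequantize_channel(
    const std::vector<int8_t>& in, int64_t zero_point, double scale) {
  const int32_t zp = static_cast<int32_t>(zero_point);  // lossless only if it fits
  std::vector<float> out;
  out.reserve(in.size());
  for (int8_t q : in) {
    out.push_back(static_cast<float>((static_cast<int32_t>(q) - zp) * scale));
  }
  return out;
}
```

If every offset in a model passes `fits_in_int32`, storing the offsets in a 32-bit tensor loses no information, which is the observation the test plan relies on.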
1 parent 0f70a5d commit 9b0feb0

1 file changed: +2 -1 lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 2 additions & 1 deletion
@@ -384,7 +384,8 @@ Tensor& dequantize_per_channel_out(
   if (opt_zero_points.has_value()) {
     auto zero_point = opt_zero_points.value();
     ET_CHECK_MSG(
-        zero_point.scalar_type() == ScalarType::Long,
+        zero_point.scalar_type() == ScalarType::Int ||
+            zero_point.scalar_type() == ScalarType::Long,
         "zero_point.scalar_type() %" PRId8 " is not integer type",
         static_cast<int8_t>(zero_point.scalar_type()));
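The relaxed check can be sketched outside ExecuTorch as follows. `DType`, `ZeroPoints`, and `zero_point_at` are hypothetical stand-ins introduced for illustration only. The diff shows just the check itself, so how the kernel subsequently reads the values is an assumption here: the sketch widens 32-bit storage to 64-bit on read.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a tensor dtype tag; only Int and Long are accepted.
enum class DType { Float, Int, Long };

// Hypothetical container: exactly one of the two vectors is populated.
struct ZeroPoints {
  DType dtype;
  std::vector<int32_t> i32;  // used when dtype == DType::Int
  std::vector<int64_t> i64;  // used when dtype == DType::Long
};

// Accepts either integer width (as the patched check does) and returns the
// per-channel zero point widened to 64 bits.
inline int64_t zero_point_at(const ZeroPoints& zp, std::size_t channel) {
  if (zp.dtype != DType::Int && zp.dtype != DType::Long) {
    throw std::invalid_argument("zero_point is not an integer type");
  }
  return zp.dtype == DType::Int ? static_cast<int64_t>(zp.i32.at(channel))
                                : zp.i64.at(channel);
}
```

Before the change, the equivalent of the `DType::Int` branch would have been rejected outright, which matches the crash described in the test plan.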
