[aarch64] add sbgemm inner product op #1768


Merged
merged 3 commits into uxlfoundation:main from aarch64_sbgemm_inner_product on Feb 7, 2024

Conversation

snadampal
Contributor

Description

Added sbgemm inner product op to enable PyTorch torch.compile() and bf16 fastmath kernels to work together on aarch64.

Fixes # (github issue)

Checklist

General

  • [x] Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
    Same pass rate as on the main branch.
  • Have you formatted the code using clang-format?

Performance improvements

  1. torch.compile() with fp32 inner product kernel showed up to 5.8x perf improvement on AWS c7g instance
  2. torch.compile() with bf16 inner product kernel showed another 1.25x perf improvement on AWS c7g instance
  • Have you submitted performance data that demonstrates performance improvements?

New features

  • Have you published an RFC for the new feature?
  • Was the RFC approved?
  • Have you added relevant tests?

Bug fixes

  • Have you included information on how to reproduce the issue (either in a github issue or in this PR)?
  • Have you added relevant regression tests?

RFC PR

  • Does RFC document follow the template?
  • Have you added a link to the rendered document?

@snadampal
Contributor Author

Hi @jondea , @nSircombe , @cfRod , I'd appreciate it if any of you could review this PR. Thank you!

@jondea
Contributor

jondea commented Dec 14, 2023

Hi @snadampal , thanks for the patch! I'm just looking into this now. Just so I can test it, what benchdnn run lines are you using to test, and in particular which ones do you want to enable with this patch?

@snadampal
Contributor Author

Hi @jondea , I have tested the oneDNN inner product tests in fastmath and normal mode to make sure there are no regressions. The main addition in this PR is to allow the bfloat16 weights format at primitive creation and to reuse the matmul sgemm operator. This is possible with fixed-format kernels, which in fast-math mode accept bf16 weights for the gemm operators.
I've also tested BERT model inference using torch.compile() and oneDNN fastmath mode (DNNL_DEFAULT_FPMATH_MODE=BF16).
Now I'm considering whether I need to add any new test scenarios to oneDNN.
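For reference, the fast-math mode mentioned above is enabled process-wide through an environment variable; a minimal sketch (the script name below is a placeholder, not part of this PR):

```shell
# Enable oneDNN fast-math mode (bf16 computation for fp32 primitives)
# for the whole process; "bert_inference.py" is a hypothetical script.
export DNNL_DEFAULT_FPMATH_MODE=BF16
# With ONEDNN_VERBOSE=1 the primitive log should show attr-fpmath:bf16
# on the gemm/reorder lines, confirming the fast-math kernels are used:
# ONEDNN_VERBOSE=1 python bert_inference.py
```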

@snadampal
Contributor Author

snadampal commented Jan 12, 2024

Hi @jondea , while I check whether new unit tests need to be added for the sbgemm ops, could you please review and comment? Thank you!

cc: @cfRod , @milpuz01

@vpirogov vpirogov added this to the v3.4 milestone Jan 12, 2024
@jondea
Contributor

jondea commented Jan 15, 2024

Sorry for the slow reply. It looks great in principle, I'm glad it's such a simple change to enable this!

My only concern is removing the check that the memory format is any. I thought that oneDNN expects you to change the memory format only if it is set to any? For reference, acl_utils::reorder_to_weight_format mutates the memory descriptor.

On the PyTorch side, are you planning to pass any or a specific format?

@snadampal
Contributor Author

snadampal commented Jan 15, 2024

Hi @jondea , thanks for the review. I know ACL expects format_any so that it can reorder and pack the weights into the custom format the ACL gemm understands. However, with torch.compile() and weights pre-packing done before primitive creation, the weights are already reordered by ACL (even before the primitive is created), so they are no longer in format::any; instead they are in whatever format ACL decided. But your concern is valid, because the same primitive is used for the compile as well as the non-compile path.

From the PyTorch side, I will pass format::any for the non-compile path, and the reordered format (whatever ACL returns as the dst format for the reordered/packed weights) if the graph was compiled and uses pre-packed weights.

So, how about we check for either format_any or the ACL custom format (blocked:AB8a4b::f0 or blocked:Ab8a::f0) during primitive creation to make it more controlled?

I'm suggesting the above blocked formats because this is what I see during the ACL reorders.
For fastmath bf16 weights:
onednn_verbose,primitive,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:ab::f0 dst_bf16::blocked:AB8a4b::f0,attr-scratchpad:user attr-fpmath:bf16 ,,768x768,0.138184

for non fastmath fp32 weights:
onednn_verbose,primitive,exec,cpu,reorder,jit:blk,undef,src_f32::blocked:ab::f0 dst_f32::blocked:Ab8a::f0,attr-scratchpad:user ,,768x768,0.158936
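The src/dst formats in those two verbose lines can be pulled out mechanically; here is a small sketch that parses lines of this exact shape (it follows the field layout of the two examples above, not the full ONEDNN_VERBOSE grammar):

```python
# Minimal parser for the onednn_verbose reorder lines quoted above,
# extracting the src and dst memory formats. This assumes the field
# layout of the two example lines, not a general verbose-log format.
def parse_reorder_formats(verbose_line):
    fields = verbose_line.split(",")
    # fields[7] holds "src_<dt>::<fmt> dst_<dt>::<fmt>" for these lines
    src, dst = fields[7].split()
    return src.split("::", 1)[1], dst.split("::", 1)[1]

line = ("onednn_verbose,primitive,exec,cpu,reorder,jit:uni,undef,"
        "src_f32::blocked:ab::f0 dst_bf16::blocked:AB8a4b::f0,"
        "attr-scratchpad:user attr-fpmath:bf16 ,,768x768,0.138184")
```

Running `parse_reorder_formats(line)` on the bf16 example yields the `blocked:ab::f0` source and `blocked:AB8a4b::f0` destination formats discussed above.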

@jondea
Contributor

jondea commented Jan 15, 2024

If you wanted a quick workaround, you could pass in format any from PyTorch when you re-init the primitive, even though you know it's going to be that format.

Nevertheless, this change does make sense in the long term. I think instead of hard coding it in oneDNN, it makes sense to check that the format passed into init matches the one returned from ACL. That would be more flexible to future changes in ACL.

@snadampal snadampal force-pushed the aarch64_sbgemm_inner_product branch from 600bb98 to 9745998 on January 18, 2024
@snadampal
Contributor Author

Hi @jondea , given this is the right direction, I extended the supported formats to any and blocked. Of course, we can debate whether blocked always guarantees a format supported by ACL; for that, I think we can rely on one of the ACL config checks to return a not-supported error.

I didn't want to work around this in ideep, as it would be too hacky a solution; since there won't be any on-the-fly reorders in the compiled case, we would need to hack the weights md format in multiple places, during create as well as exec, which I don't think is a good option.

@snadampal
Contributor Author

I'm trying to add a few tests for ip in sbgemm mode. Could someone please point me to documentation on how to add benchdnn unit tests for new configurations? Thank you!

@jondea
Contributor

jondea commented Jan 19, 2024

I think we need a bit more logic inside init_conf_ip. Currently it is set up for just format_tag::any, and so with your changes it will accept any or any blocked format. I think in the case of blocked you also need to check that the format produced by acl_utils::reorder_to_weight_format matches the one passed in.

To be concrete, I think as it stands, this will fail benchdnn if you run with a blocked format that is not supported by ACL. Whereas it should return unimplemented and pass to ref.
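The acceptance logic being proposed here can be sketched in Python (all names in this sketch, FORMAT_ANY, reorder_to_weight_format, the status strings, and the tags, are illustrative stand-ins for the oneDNN/ACL internals, not the real API):

```python
# Illustrative sketch of the init-time weights-format check discussed
# above. Names and tags are hypothetical; the tags mirror the verbose
# output quoted earlier in the thread.
FORMAT_ANY = "any"

def reorder_to_weight_format(fastmath_bf16):
    # ACL decides the packed layout it wants for the weights.
    return "blocked:AB8a4b" if fastmath_bf16 else "blocked:Ab8a"

def init_conf_ip(weights_format, fastmath_bf16=False):
    acl_format = reorder_to_weight_format(fastmath_bf16)
    if weights_format == FORMAT_ANY:
        return "success", acl_format    # oneDNN adopts ACL's layout
    if weights_format == acl_format:
        return "success", acl_format    # pre-packed weights accepted
    return "unimplemented", None        # unsupported blocked format:
                                        # fall back to the ref impl
```

With this shape, a benchdnn run that requests a blocked format ACL does not produce falls through to the reference implementation instead of failing.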

@snadampal snadampal force-pushed the aarch64_sbgemm_inner_product branch 2 times, most recently from 052924f to d0389c8 on January 27, 2024
@snadampal
Contributor Author

Hi @jondea , I addressed your feedback about the error checks for the blocked format and also added unit tests. Please note that to support the unit test with the --wtag=Ab8a blocked layout I had to set --allow-enum-tags-only=0, and hence I'm passing allow_all_flags=true to set_default_params(). Please let me know if you see any issues with it.
This is how I tested the blocked format:
./benchdnn --ip --mode=P --engine=cpu --allow-enum-tags-only=0 --batch=inputs/ip/test_ip_acl

@jondea
Contributor

Thanks for the changes! I just have two minor comments.

    if ((weights_format_kind_received == format_kind::blocked)
            && !(dnnl_memory_desc_equal(
                    &weights_md_received, &weights_md_))) {
        return status::unimplemented;
    }

This comment was marked as resolved.

@snadampal snadampal force-pushed the aarch64_sbgemm_inner_product branch from 0eee71e to d66926a on January 29, 2024
@jondea
Contributor

Looks good to me, thank you for the contribution!

@snadampal
Contributor Author

Looks like the CI failures are not related to this PR, I have seen CI failing for other PRs too, with the same errors.
@vpirogov , Is this a known issue?

The following tests FAILED:
	180 - test_benchdnn_modeC_self_smoke_cpu (Failed)
Errors while running CTest
        Start 180: test_benchdnn_modeC_self_smoke_cpu

180: Test command: /home/vsts/work/1/s/build/tests/benchdnn/benchdnn "--mode=C" "-v1" "--engine=cpu" "--self" "--batch=test_self_smoke"
180: Working Directory: /home/vsts/work/1/s/build/tests/benchdnn
180: Test timeout computed to be: 10000000
180: check_simple_enums() ...
180: check_attr2str() ...
180: check_attr() ...
180: [int self::check_attr():109] '(entry.policy) == (policy_t::PER_OC)' FAILED ==>  0 != 1
180: check_post_ops2str() ...
180: check_str2post_ops() ...
180: check_tags() ...
180: check_trim_tags() ...
180: check_skip_impl() ...
180: check_status_change() ...
180: [   0][0] exp_f32:           0 exp:           0 got:          -1 diff:       1 rdiff:       1
180: [   1][1] exp_f32:           1 exp:           1 got:           0 diff:       1 rdiff:       1
180: [   2][2] exp_f32:           2 exp:           2 got:           1 diff:       1 rdiff:     0.5
180: [   3][3] exp_f32:           3 exp:           3 got:           2 diff:       1 rdiff:0.333333
180: [   4][4] exp_f32:           4 exp:           4 got:           3 diff:       1 rdiff:    0.25
180: [   5][5] exp_f32:           5 exp:           5 got:           4 diff:       1 rdiff:     0.2
180: [   6][6] exp_f32:           6 exp:           6 got:           5 diff:       1 rdiff:0.166667
180: [   7][7] exp_f32:           7 exp:           7 got:           6 diff:       1 rdiff:0.142857
180: [   8][8] exp_f32:           8 exp:           8 got:           7 diff:       1 rdiff:   0.125
180: [   9][9] exp_f32:           9 exp:           9 got:           8 diff:       1 rdiff:0.111111
180: [COMPARE_STATS]: trh=0 max_diff:       1 max_rdiff:       1
180: [   0][0] exp_f32:           0 exp:           0 got:          -1 diff:       1 rdiff:       1
180: [   1][1] exp_f32:           1 exp:           1 got:           0 diff:       1 rdiff:       1

180: check_norm() ...
180: check_diff_norm() ...
180: check_compare_norm() ...
180: [L0] = 1
180: [L1] exp:      55 got:      65 diff:      10 rel_diff:0.181818
180: [L2] exp: 19.6214 got: 22.4722 diff: 3.16228 rel_diff:0.161165
180: [L8] exp:      10 got:      11 diff:       1 rel_diff:     0.1
180: check_compare() ...
180: [  99][99] exp_f32:           1 exp:           1 got:         100 diff:      99 rdiff:      99
180: [COMPARE_STATS]: trh=98 max_diff:      99 max_rdiff:      99
180: tests:18 passed:17 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
180: total: 0.03s;
180/184 Test #180: test_benchdnn_modeC_self_smoke_cpu ......................***Failed    0.06 sec

@igorsafo
Contributor

> Looks like the CI failures are not related to this PR, I have seen CI failing for other PRs too, with the same errors. @vpirogov , Is this a known issue?

Hi @snadampal Thanks for notifying, we are aware of this issue and a fix is submitted into the internal repository.

@snadampal
Contributor Author

@igorsafo , thanks for confirming. I guess CI will be triggered for this PR once the internal issue is fixed, right?
This PR has been approved and I think it's ready for merge post CI runs. Please let me know if there are any additional steps. thank you!

@igorsafo
Contributor

> @igorsafo , thanks for confirming. I guess CI will be triggered for this PR once the internal issue is fixed, right? This PR has been approved and I think it's ready for merge post CI runs. Please let me know if there are any additional steps. thank you!

Once the fix is landed we will need to rebase this branch on top of the latest main to include the fix, otherwise CI will stay red. I will notify once the fix is available.

@igorsafo
Contributor

@snadampal The fixes landed into the main. Could you please rebase the branch?

with weights pre-packing enabled in torch.compile(),
the weights come already reordered and in oneDNN format,
so, allowing format_kind::blocked as one of the supported
formats for acl inner product primitive.
@snadampal snadampal force-pushed the aarch64_sbgemm_inner_product branch from d66926a to ffdbe04 on January 31, 2024
@snadampal
Contributor Author

Hi @igorsafo , done, rebased the PR.

@snadampal
Contributor Author

Hi @igorsafo , please let me know if anything else needs to be covered for this PR to get merged. Thank you.

@snadampal
Contributor Author

Hi @igorsafo , we planned this PR for oneDNN 3.4. It was tagged for v3.4 but is not yet merged; it's been approved and CI is passing. Could you please prioritize it?

@igorsafo igorsafo merged commit 31d213d into uxlfoundation:main Feb 7, 2024
@snadampal
Contributor Author

thank you!

@igorsafo
Contributor

igorsafo commented Feb 7, 2024

Hi @snadampal , sorry for the wait. The changes were promoted into main and backported into rls-v3.4.
Thank you for the contribution!
