@@ -87,6 +87,22 @@ module {
8787 return %0 : tensor <8 x8 xf32 >
8888 }
8989
90+ func.func @add_coo_coo_out_coo (%arga: tensor <8 x8 xf32 , #SortedCOO >,
91+ %argb: tensor <8 x8 xf32 , #SortedCOO >)
92+ -> tensor <8 x8 xf32 , #SortedCOO > {
93+ %init = tensor.empty () : tensor <8 x8 xf32 , #SortedCOO >
94+ %0 = linalg.generic #trait
95+ ins (%arga , %argb: tensor <8 x8 xf32 , #SortedCOO >,
96+ tensor <8 x8 xf32 , #SortedCOO >)
97+ outs (%init: tensor <8 x8 xf32 , #SortedCOO >) {
98+ ^bb (%a: f32 , %b: f32 , %x: f32 ):
99+ %0 = arith.addf %a , %b : f32
100+ linalg.yield %0 : f32
101+ } -> tensor <8 x8 xf32 , #SortedCOO >
102+ return %0 : tensor <8 x8 xf32 , #SortedCOO >
103+ }
104+
105+
90106 func.func @add_coo_dense (%arga: tensor <8 x8 xf32 >,
91107 %argb: tensor <8 x8 xf32 , #SortedCOO >)
92108 -> tensor <8 x8 xf32 > {
@@ -149,17 +165,21 @@ module {
149165 %C3 = call @add_coo_coo (%COO_A , %COO_B ) : (tensor <8 x8 xf32 , #SortedCOO >,
150166 tensor <8 x8 xf32 , #SortedCOO >)
151167 -> tensor <8 x8 xf32 >
168+ %COO_RET = call @add_coo_coo_out_coo (%COO_A , %COO_B ) : (tensor <8 x8 xf32 , #SortedCOO >,
169+ tensor <8 x8 xf32 , #SortedCOO >)
170+ -> tensor <8 x8 xf32 , #SortedCOO >
171+ %C4 = sparse_tensor.convert %COO_RET : tensor <8 x8 xf32 , #SortedCOO > to tensor <8 x8 xf32 >
152172 //
153173 // Verify computed matrix C.
154174 //
155- // CHECK-COUNT-3 : ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
156- // CHECK-NEXT-COUNT-3 : ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
157- // CHECK-NEXT-COUNT-3 : ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
158- // CHECK-NEXT-COUNT-3 : ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
159- // CHECK-NEXT-COUNT-3 : ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
160- // CHECK-NEXT-COUNT-3 : ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
161- // CHECK-NEXT-COUNT-3 : ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
162- // CHECK-NEXT-COUNT-3 : ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
175+ // CHECK-COUNT-4 : ( 8.8, 4.8, 6.8, 4.8, 8.8, 6.1, 14.8, 16.8 )
176+ // CHECK-NEXT-COUNT-4 : ( 4.4, 4.4, 4.4, 8.4, 8.4, 12.4, 16.4, 16.4 )
177+ // CHECK-NEXT-COUNT-4 : ( 8.8, 4.8, 6.8, 8.8, 8.8, 12.8, 14.8, 15.8 )
178+ // CHECK-NEXT-COUNT-4 : ( 4.3, 5.3, 6.3, 8.3, 8.3, 12.3, 14.3, 16.3 )
179+ // CHECK-NEXT-COUNT-4 : ( 4.5, 4.5, 6.5, 8.5, 8.5, 12.5, 14.5, 16.5 )
180+ // CHECK-NEXT-COUNT-4 : ( 9.9, 4.9, 6.9, 8.9, 8.9, 12.9, 15.9, 16.9 )
181+ // CHECK-NEXT-COUNT-4 : ( 12.1, 6.1, 5.1, 9.1, 9.1, 13.1, 15.1, 17.1 )
182+ // CHECK-NEXT-COUNT-4 : ( 15.4, 5.4, 7.4, 5.4, 11.4, 10.4, 11.4, 9.4 )
163183 //
164184 %f0 = arith.constant 0.0 : f32
165185 scf.for %i = %c0 to %c8 step %c1 {
@@ -169,9 +189,12 @@ module {
169189 : tensor <8 x8 xf32 >, vector <8 xf32 >
170190 %v3 = vector.transfer_read %C3 [%i , %c0 ], %f0
171191 : tensor <8 x8 xf32 >, vector <8 xf32 >
192+ %v4 = vector.transfer_read %C4 [%i , %c0 ], %f0
193+ : tensor <8 x8 xf32 >, vector <8 xf32 >
172194 vector.print %v1 : vector <8 xf32 >
173195 vector.print %v2 : vector <8 xf32 >
174196 vector.print %v3 : vector <8 xf32 >
197+ vector.print %v4 : vector <8 xf32 >
175198 }
176199
177200 // Release resources.
@@ -181,6 +204,7 @@ module {
181204 bufferization.dealloc_tensor %CSR_A : tensor <8 x8 xf32 , #CSR >
182205 bufferization.dealloc_tensor %COO_A : tensor <8 x8 xf32 , #SortedCOO >
183206 bufferization.dealloc_tensor %COO_B : tensor <8 x8 xf32 , #SortedCOO >
207+ bufferization.dealloc_tensor %COO_RET : tensor <8 x8 xf32 , #SortedCOO >
184208
185209
186210 return
0 commit comments