Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 36d73d7

Browse files
committedDec 9, 2024·
Add Arm®v9-A architecture SME SGEMM kernels
Add implementation of SGEMM based on the Arm®v9-A architecture Scalable Matrix Extension (SME) [1], using the Arm C Language Extensions (ACLE) [2]. Add SME2 compute & packing kernels for SGEMM and enable them under the ARMV9SME target. The compute kernel performs outer products on panels of A and B, accumulating into 2x2 inner blocks of C via the SME two-dimensional architectural register, ZA. The non-transpose packing kernel performs a copy into a contiguous buffer using SVE loads & stores in Streaming SVE mode. Streaming SVE is an execution mode introduced by SME that supports execution of SVE code with the SME defined vector length, known as the Streaming SVE vector length (SVL). The transpose packing kernel performs on-the-fly transposition by utilizing horizontal & vertical tile slice access to the SME ZA register. Includes an update to the driver to account for expanded inner block. Note: this places the ARMV9SME target in WIP state. It is functional for SGEMM, and all GEMM tests are passing. Other BLAS3 routines have not been updated to match the larger kernel size, so SYMM/TRMM tests are currently expected to fail in this WIP state. [1] https://developer.arm.com/documentation/109246/0100/SME-Overview/SME-and-SME2 [2] https://arm-software.github.io/acle/main/acle.html
1 parent df2b2cf commit 36d73d7

File tree

6 files changed

+304
-0
lines changed

6 files changed

+304
-0
lines changed
 

‎CONTRIBUTORS.md

+3
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,6 @@ In chronological order:
229229

230230
* Christopher Daley <https://github.com/cdaley>
231231
* [2024-01-24] Optimize GEMV forwarding on ARM64 systems
232+
233+
* Aymen Qader <aymen.qader@arm.com>
234+
* [2024-12-09] Add Arm®v9-A architecture SME2 SGEMM kernels

‎driver/level3/level3.c

+3
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
345345
#if defined(SKYLAKEX) || defined(COOPERLAKE) || defined(SAPPHIRERAPIDS)
346346
/* the current AVX512 s/d/c/z GEMM kernel requires n>=6*GEMM_UNROLL_N to achieve best performance */
347347
if (min_jj >= 6*GEMM_UNROLL_N) min_jj = 6*GEMM_UNROLL_N;
348+
#elif defined(ARMV9SME) && !defined(DOUBLE) && !defined(COMPLEX)
349+
/* the current SME SGEMM kernel requires n>=8*GEMM_UNROLL_N to achieve best performance */
350+
if (min_jj >= 8*GEMM_UNROLL_N) min_jj = 8*GEMM_UNROLL_N;
348351
#else
349352
if (min_jj >= 3*GEMM_UNROLL_N) min_jj = 3*GEMM_UNROLL_N;
350353
else

‎kernel/arm64/KERNEL.ARMV9SME

+7
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
11
include $(KERNELDIR)/KERNEL.ARMV8SVE
2+
3+
SGEMMKERNEL = sgemm_kernel_sme.c
4+
5+
SGEMMINCOPY = sgemm_ncopy_sme.c
6+
SGEMMITCOPY = sgemm_tcopy_sme.c
7+
SGEMMONCOPY = sgemm_ncopy_sme.c
8+
SGEMMOTCOPY = sgemm_tcopy_sme.c

‎kernel/arm64/sgemm_kernel_sme.c

+187
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/***************************************************************************
2+
Copyright (c) 2024, The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
22+
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
25+
THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
28+
#include <arm_sme.h>
29+
30+
#include "common.h"
31+
32+
// Outer product kernel.
33+
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
34+
// This kernel is unpredicated, and assumes a full 2SVL x 2SVL block.
35+
__attribute__((always_inline)) inline void
36+
kernel_2x2(const float *A, const float *B, float *C, float alpha,
37+
size_t shared_dim, size_t a_step, size_t b_step, size_t c_step)
38+
__arm_out("za") __arm_streaming {
39+
const size_t svl = svcntw();
40+
41+
// Predicate set-up
42+
svbool_t ptrue = svptrue_b32();
43+
44+
// Load from C into ZA
45+
for (size_t i = 0; i < (svl >> 1); i++) {
46+
svld1_ver_za32(0, i, ptrue, &C[0 * svl + i * c_step]);
47+
svld1_ver_za32(1, i, ptrue, &C[1 * svl + i * c_step]);
48+
svld1_ver_za32(2, i, ptrue, &C[0 * svl + (i + svl) * c_step]);
49+
svld1_ver_za32(3, i, ptrue, &C[1 * svl + (i + svl) * c_step]);
50+
}
51+
52+
svfloat32_t alpha_vec = svdup_f32(alpha);
53+
54+
// Iterate through shared dimension (K)
55+
for (size_t k = 0; k < shared_dim; k++) {
56+
// Load column of A
57+
svfloat32x2_t cols_a = svld1_x2(svptrue_c32(), &A[k * a_step]);
58+
59+
// Load row of B
60+
svfloat32x2_t rows_b = svld1_x2(svptrue_c32(), &B[k * b_step]);
61+
62+
// Multiply B through by alpha
63+
svfloat32_t row_b_0 = svmul_x(ptrue, alpha_vec, svget2(rows_b, 0));
64+
svfloat32_t row_b_1 = svmul_x(ptrue, alpha_vec, svget2(rows_b, 1));
65+
66+
// Perform outer products
67+
svmopa_za32_m(0, ptrue, ptrue, svget2(cols_a, 0), row_b_0);
68+
svmopa_za32_m(1, ptrue, ptrue, svget2(cols_a, 1), row_b_0);
69+
svmopa_za32_m(2, ptrue, ptrue, svget2(cols_a, 0), row_b_1);
70+
svmopa_za32_m(3, ptrue, ptrue, svget2(cols_a, 1), row_b_1);
71+
}
72+
73+
// Store out to C from ZA
74+
for (size_t i = 0; i < (svl >> 1); i++) {
75+
// Store out one row of C per tile
76+
svst1_ver_za32(0, i, ptrue, &C[0 * svl + i * c_step]);
77+
svst1_ver_za32(1, i, ptrue, &C[1 * svl + i * c_step]);
78+
svst1_ver_za32(2, i, ptrue, &C[0 * svl + (i + svl) * c_step]);
79+
svst1_ver_za32(3, i, ptrue, &C[1 * svl + (i + svl) * c_step]);
80+
}
81+
}
82+
83+
// Outer product kernel.
84+
// Computes an SVL x SVL block of C, utilizing a single FP32 tile of ZA (ZA0).
85+
// This kernel is predicated, and can handle under-filled blocks.
86+
__attribute__((always_inline)) inline void
87+
kernel_1x1(const float *A, const float *B, float *C, float alpha,
88+
size_t shared_dim, size_t a_len, size_t a_step, size_t b_len,
89+
size_t b_step, size_t c_step, size_t c_rows, size_t c_cols)
90+
__arm_out("za") __arm_streaming {
91+
92+
// Predicate set-up
93+
svbool_t pg = svptrue_b32();
94+
svbool_t pg_a = svwhilelt_b32((size_t)0, a_len);
95+
svbool_t pg_b = svwhilelt_b32((size_t)0, b_len);
96+
svbool_t pg_c = svwhilelt_b32((size_t)0, c_rows);
97+
98+
// Load from C into ZA
99+
for (size_t i = 0; i < c_cols; i++) {
100+
svld1_ver_za32(0, i, pg_c, &C[i * c_step]);
101+
}
102+
103+
svfloat32_t alpha_vec = svdup_f32_z(pg_b, alpha);
104+
105+
// Iterate through shared dimension (K)
106+
for (size_t k = 0; k < shared_dim; k++) {
107+
// Load column of A
108+
svfloat32_t col_a = svld1(pg_a, &A[k * a_step]);
109+
// Load row of B
110+
svfloat32_t row_b = svld1(pg_b, &B[k * b_step]);
111+
// Multiply B through by alpha
112+
row_b = svmul_x(pg_b, alpha_vec, row_b);
113+
// Perform outer product
114+
svmopa_za32_m(0, pg, pg, col_a, row_b);
115+
}
116+
117+
// Store out to C from ZA
118+
for (size_t i = 0; i < c_cols; i++) {
119+
svst1_ver_za32(0, i, pg_c, &C[i * c_step]);
120+
}
121+
}
122+
123+
__arm_new("za") __arm_locally_streaming
124+
int CNAME(BLASLONG bm, BLASLONG bn, BLASLONG bk, FLOAT alpha0, FLOAT *ba,
125+
FLOAT *bb, FLOAT *C, BLASLONG ldc) {
126+
127+
const BLASLONG num_rows = bm;
128+
const BLASLONG num_cols = bn;
129+
130+
const FLOAT *a_ptr = ba;
131+
const FLOAT *b_ptr = bb;
132+
FLOAT *c_ptr = C;
133+
134+
const BLASLONG svl = svcntw();
135+
136+
const BLASLONG a_step = bm;
137+
const BLASLONG b_step = bn;
138+
const BLASLONG c_step = ldc;
139+
140+
// Block over rows of C (panels of A)
141+
BLASLONG row_idx = 0;
142+
143+
// 2x2 loop
144+
BLASLONG row_batch = 2 * svl;
145+
146+
// Block over row dimension of C
147+
for (; row_idx + row_batch <= num_rows; row_idx += row_batch) {
148+
BLASLONG col_idx = 0;
149+
BLASLONG col_batch = 2 * svl;
150+
151+
// Block over column dimension of C
152+
for (; col_idx + col_batch <= num_cols; col_idx += col_batch) {
153+
kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx],
154+
&c_ptr[row_idx + col_idx * c_step], alpha0, bk, a_step, b_step,
155+
c_step);
156+
}
157+
158+
// Handle under-filled blocks w/ 2x(1x1) kernels
159+
col_batch = 1 * svl;
160+
for (; col_idx < num_cols; col_idx += col_batch) {
161+
col_batch = MIN(col_batch, num_cols - col_idx);
162+
163+
kernel_1x1(&a_ptr[row_idx], &b_ptr[col_idx],
164+
&c_ptr[row_idx + col_idx * c_step], alpha0, bk, svl, a_step,
165+
col_batch, b_step, c_step, svl, col_batch);
166+
167+
kernel_1x1(&a_ptr[row_idx + svl], &b_ptr[col_idx],
168+
&c_ptr[(row_idx + svl) + col_idx * c_step], alpha0, bk, svl,
169+
a_step, col_batch, b_step, c_step, svl, col_batch);
170+
}
171+
}
172+
173+
// Handle under-filled blocks w/ 1x1 kernels
174+
row_batch = 1 * svl;
175+
for (; row_idx < num_rows; row_idx += row_batch) {
176+
row_batch = MIN(row_batch, num_rows - row_idx);
177+
// Block over column dimension of C
178+
BLASLONG col_batch = svl;
179+
for (BLASLONG col_idx = 0; col_idx < num_cols; col_idx += col_batch) {
180+
col_batch = MIN(col_batch, num_cols - col_idx);
181+
kernel_1x1(&a_ptr[row_idx], &b_ptr[col_idx],
182+
&c_ptr[row_idx + col_idx * c_step], alpha0, bk, row_batch,
183+
a_step, col_batch, b_step, c_step, row_batch, col_batch);
184+
}
185+
}
186+
return 0;
187+
}

‎kernel/arm64/sgemm_ncopy_sme.c

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/***************************************************************************
2+
Copyright (c) 2024, The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
22+
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
25+
THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
28+
#include <arm_sme.h>
29+
30+
#include "common.h"
31+
32+
// Transpose 1SVL x N panel of A into B
33+
__attribute__((always_inline)) inline static void
34+
transpose_panel(const FLOAT *a, FLOAT *b, BLASLONG rows, BLASLONG cols,
35+
BLASLONG a_step, BLASLONG b_step)
36+
__arm_out("za") __arm_streaming {
37+
BLASLONG col_batch = svcntsw();
38+
const svbool_t pg_a = svwhilelt_b32_u64(0, rows);
39+
40+
for (BLASLONG k = 0; k < cols; k += col_batch) {
41+
col_batch = MIN(col_batch, cols - k);
42+
for (BLASLONG col = 0; col < col_batch; col++) {
43+
svld1_ver_za32(0, col, pg_a, &a[(col + k) * a_step]);
44+
}
45+
46+
const svbool_t pg_b = svwhilelt_b32(k, cols);
47+
for (BLASLONG row = 0; row < rows; row++) {
48+
svst1_hor_za32(0, row, pg_b, &b[row * b_step + k]);
49+
}
50+
}
51+
}
52+
53+
__arm_new("za") __arm_locally_streaming
54+
int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, FLOAT *b) {
55+
const BLASLONG num_rows = m;
56+
BLASLONG row_batch = svcntsw();
57+
for (BLASLONG row_idx = 0; row_idx < num_rows; row_idx += row_batch) {
58+
// Transpose 1xSVL panel
59+
row_batch = MIN(row_batch, num_rows - row_idx);
60+
transpose_panel(&a[row_idx], &b[row_idx * n], row_batch, n, lda, n);
61+
}
62+
return 0;
63+
}

‎kernel/arm64/sgemm_tcopy_sme.c

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/***************************************************************************
2+
Copyright (c) 2024, The OpenBLAS Project
3+
All rights reserved.
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are
6+
met:
7+
1. Redistributions of source code must retain the above copyright
8+
notice, this list of conditions and the following disclaimer.
9+
2. Redistributions in binary form must reproduce the above copyright
10+
notice, this list of conditions and the following disclaimer in
11+
the documentation and/or other materials provided with the
12+
distribution.
13+
3. Neither the name of the OpenBLAS project nor the names of
14+
its contributors may be used to endorse or promote products
15+
derived from this software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
22+
GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
23+
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
24+
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
25+
THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
*****************************************************************************/
27+
28+
#include <arm_sve.h>
29+
30+
#include "common.h"
31+
32+
__arm_locally_streaming int CNAME(BLASLONG m, BLASLONG n, FLOAT *restrict a,
33+
BLASLONG lda, FLOAT *restrict b) {
34+
for (BLASLONG i = 0; i < m; i++) {
35+
for (BLASLONG j = 0; j < n; j += svcntw()) {
36+
svbool_t pg = svwhilelt_b32(j, n);
37+
svst1(pg, &b[i * n + j], svld1(pg, &a[i * lda + j]));
38+
}
39+
}
40+
return 0;
41+
}

0 commit comments

Comments
 (0)
Please sign in to comment.