Add efficient support for any channel_dimension (channels as rows or as columns) to the kNeon quantized kernels on ARM64.
PiperOrigin-RevId: 320184717
diff --git a/ruy/create_trmul_params.cc b/ruy/create_trmul_params.cc
index 6c20c7d..824de6d 100644
--- a/ruy/create_trmul_params.cc
+++ b/ruy/create_trmul_params.cc
@@ -42,11 +42,14 @@
return true;
}
- if (RUY_PLATFORM_NEON_64) {
- if (src[Side::kLhs].data_type == Type::Create<float>()) {
- return false;
- }
+#if RUY_PLATFORM_NEON_64
+ if (src[Side::kLhs].data_type == Type::Create<float>()) {
+ return false;
}
+ if (path == Path::kNeon) {
+ return false;
+ }
+#endif
// Ruy's optimized kernels currently only support the channel_dimension==kRow
// case.
if (channel_dimension != ChannelDimension::kRow) {
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index 211ff91..b6f6d31 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -99,7 +99,6 @@
void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
profiler::ScopeLabel label(
"Kernel (kNeon, optimized for out-of-order cores)");
-
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
@@ -409,24 +408,33 @@
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
- "add x5, x4, %x[row], lsl #2\n"
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
- "csel x4, x4, x5, eq\n"
-
- "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
// Now we load: bias data, LHS sums data, RHS sums data.
// First, load the base pointers from the params.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
- "add x5, x1, %x[row], lsl #2\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
+ // Load the multiplier_fixedpoint values.
+ "add x5, x4, x3, lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
@@ -446,10 +454,27 @@
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v14.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v14.4s\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v20.4s, v14.s[0]\n"
+ "dup v21.4s, v14.s[1]\n"
+ "dup v22.4s, v14.s[2]\n"
+ "dup v23.4s, v14.s[3]\n"
+ "add v16.4s, v16.4s, v20.4s\n"
+ "add v17.4s, v17.4s, v21.4s\n"
+ "add v18.4s, v18.4s, v22.4s\n"
+ "add v19.4s, v19.4s, v23.4s\n"
+ "7:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
@@ -499,13 +524,23 @@
//Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
- "add x5, x1, %x[row], lsl #2\n"
+ "add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smax v12.4s, v14.4s, v8.4s\n"
+ "smin v11.4s, v14.4s, v8.4s\n"
+
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 8f\n"
+ // Case where channels are rows
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v12.4s\n"
@@ -513,8 +548,6 @@
"sshl v18.4s, v18.4s, v12.4s\n"
"sshl v19.4s, v19.4s, v12.4s\n"
- "smin v12.4s, v14.4s, v8.4s\n"
-
// Apply the fixed-point part of the multiplier.
"sqrdmulh v16.4s, v16.4s, v15.4s\n"
"sqrdmulh v17.4s, v17.4s, v15.4s\n"
@@ -522,10 +555,41 @@
"sqrdmulh v19.4s, v19.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
- "srshl v16.4s, v16.4s, v12.4s\n"
- "srshl v17.4s, v17.4s, v12.4s\n"
- "srshl v18.4s, v18.4s, v12.4s\n"
- "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v11.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v11.4s\n"
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns
+
+ // Apply the positive exponent part of the multiplier.
+ "dup v20.4s, v12.s[0]\n"
+ "dup v21.4s, v12.s[1]\n"
+ "dup v22.4s, v12.s[2]\n"
+ "dup v23.4s, v12.s[3]\n"
+ "sshl v16.4s, v16.4s, v20.4s\n"
+ "sshl v17.4s, v17.4s, v21.4s\n"
+ "sshl v18.4s, v18.4s, v22.4s\n"
+ "sshl v19.4s, v19.4s, v23.4s\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqrdmulh v16.4s, v16.4s, v15.s[0]\n"
+ "sqrdmulh v17.4s, v17.4s, v15.s[1]\n"
+ "sqrdmulh v18.4s, v18.4s, v15.s[2]\n"
+ "sqrdmulh v19.4s, v19.4s, v15.s[3]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "dup v20.4s, v11.s[0]\n"
+ "dup v21.4s, v11.s[1]\n"
+ "dup v22.4s, v11.s[2]\n"
+ "dup v23.4s, v11.s[3]\n"
+ "srshl v16.4s, v16.4s, v20.4s\n"
+ "srshl v17.4s, v17.4s, v21.4s\n"
+ "srshl v18.4s, v18.4s, v22.4s\n"
+ "srshl v19.4s, v19.4s, v23.4s\n"
+ "9:\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
@@ -2091,21 +2155,30 @@
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
- "add x5, x4, %x[row], lsl #2\n"
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
- "csel x4, x4, x5, eq\n"
-
- "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
- "add x5, x1, %x[row], lsl #2\n"
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
+
+ // Offset the bias pointer as needed given the current row, col.
+ "add x5, x1, x3, lsl #2\n"
+
+ // If there is no bias, use no offset, just address the passed zero
+ // data.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
+ // Load the multiplier_fixedpoint values.
+ "add x5, x4, x3, lsl #2\n"
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+ "csel x4, x4, x5, eq\n"
+ "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
@@ -2118,6 +2191,11 @@
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 6f\n"
+ // Case where channels are rows
+
"add v16.4s, v16.4s, v14.4s\n"
"ldr d1, [%[lhs_ptr], #16]\n"
"add v17.4s, v17.4s, v14.4s\n"
@@ -2130,6 +2208,27 @@
"ldr d6, [%[rhs_ptr], #32]\n"
"ldr d7, [%[rhs_ptr], #48]\n"
+ "b 7f\n"
+
+ "6:\n"
+ // Case where channels are columns
+ "dup v20.4s, v14.s[0]\n"
+ "ldr d1, [%[lhs_ptr], #16]\n"
+ "dup v21.4s, v14.s[1]\n"
+ "ldr d2, [%[lhs_ptr], #32]\n"
+ "dup v22.4s, v14.s[2]\n"
+ "ldr d3, [%[lhs_ptr], #48]\n"
+ "dup v23.4s, v14.s[3]\n"
+ "ldr d4, [%[rhs_ptr], #0]\n"
+ "add v16.4s, v16.4s, v20.4s\n"
+ "ldr d5, [%[rhs_ptr], #16]\n"
+ "add v17.4s, v17.4s, v21.4s\n"
+ "ldr d6, [%[rhs_ptr], #32]\n"
+ "add v18.4s, v18.4s, v22.4s\n"
+ "ldr d7, [%[rhs_ptr], #48]\n"
+ "add v19.4s, v19.4s, v23.4s\n"
+ "7:\n"
+
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
@@ -2176,16 +2275,26 @@
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
+ // Determine the channel index.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "csel w3, %w[row], %w[col], eq\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
- "add x5, x1, %x[row], lsl #2\n"
+ "add x5, x1, x3, lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smax v12.4s, v14.4s, v8.4s\n"
"ldr x1, [%[lhs_ptr], #8]\n"
+ "smin v11.4s, v14.4s, v8.4s\n"
+
+ // Jump based on channel dimension.
+ "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+ "bne 8f\n"
+ // Case where channels are rows
+
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v12.4s\n"
@@ -2196,7 +2305,6 @@
"ldr x4, [%[lhs_ptr], #56]\n"
"sshl v19.4s, v19.4s, v12.4s\n"
- "smin v12.4s, v14.4s, v8.4s\n"
// Apply the fixed-point part of the multiplier.
"ins v0.d[1], x1\n"
@@ -2213,10 +2321,54 @@
"sqrdmulh v19.4s, v19.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
- "srshl v16.4s, v16.4s, v12.4s\n"
- "srshl v17.4s, v17.4s, v12.4s\n"
- "srshl v18.4s, v18.4s, v12.4s\n"
- "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v11.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v11.4s\n"
+
+ "b 9f\n"
+
+ "8:\n"
+ // Case where channels are columns
+
+ // Apply the positive exponent part of the multiplier.
+ "dup v20.4s, v12.s[0]\n"
+ "ldr x2, [%[lhs_ptr], #24]\n"
+ "ldr x3, [%[lhs_ptr], #40]\n"
+ "dup v21.4s, v12.s[1]\n"
+ "ldr x4, [%[lhs_ptr], #56]\n"
+ "dup v22.4s, v12.s[2]\n"
+ "ins v0.d[1], x1\n"
+ "dup v23.4s, v12.s[3]\n"
+ "ldr x1, [%[rhs_ptr], #8]\n"
+ "sshl v16.4s, v16.4s, v20.4s\n"
+ "ins v1.d[1], x2\n"
+ "sshl v17.4s, v17.4s, v21.4s\n"
+ "ldr x2, [%[rhs_ptr], #24]\n"
+ "sshl v18.4s, v18.4s, v22.4s\n"
+ "ins v2.d[1], x3\n"
+ "sshl v19.4s, v19.4s, v23.4s\n"
+ "ldr x3, [%[rhs_ptr], #40]\n"
+
+ // Apply the fixed-point part of the multiplier.
+ "sqrdmulh v16.4s, v16.4s, v15.s[0]\n"
+ "ins v3.d[1], x4\n"
+ "sqrdmulh v17.4s, v17.4s, v15.s[1]\n"
+ "ldr x4, [%[rhs_ptr], #56]\n"
+ "sqrdmulh v18.4s, v18.4s, v15.s[2]\n"
+ "sqrdmulh v19.4s, v19.4s, v15.s[3]\n"
+
+ // Apply the negative exponent part of the multiplier.
+ "dup v20.4s, v11.s[0]\n"
+ "dup v21.4s, v11.s[1]\n"
+ "dup v22.4s, v11.s[2]\n"
+ "dup v23.4s, v11.s[3]\n"
+ "srshl v16.4s, v16.4s, v20.4s\n"
+ "srshl v17.4s, v17.4s, v21.4s\n"
+ "srshl v18.4s, v18.4s, v22.4s\n"
+ "srshl v19.4s, v19.4s, v23.4s\n"
+
+ "9:\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"