Efficient support for any channel_dimension for kNeon quantized kernels on ARM64.

PiperOrigin-RevId: 320184717
diff --git a/ruy/create_trmul_params.cc b/ruy/create_trmul_params.cc
index 6c20c7d..824de6d 100644
--- a/ruy/create_trmul_params.cc
+++ b/ruy/create_trmul_params.cc
@@ -42,11 +42,14 @@
     return true;
   }
 
-  if (RUY_PLATFORM_NEON_64) {
-    if (src[Side::kLhs].data_type == Type::Create<float>()) {
-      return false;
-    }
+#if RUY_PLATFORM_NEON_64
+  if (src[Side::kLhs].data_type == Type::Create<float>()) {
+    return false;
   }
+  if (path == Path::kNeon) {
+    return false;
+  }
+#endif
   // Ruy's optimized kernels currently only support the channel_dimension==kRow
   // case.
   if (channel_dimension != ChannelDimension::kRow) {
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index 211ff91..b6f6d31 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -99,7 +99,6 @@
 void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
   profiler::ScopeLabel label(
       "Kernel (kNeon, optimized for out-of-order cores)");
-
   CheckOffsetsInKernelParams8bit(params);
 
   const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
@@ -409,24 +408,33 @@
         "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
         "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
         "dup v9.4s, w3\n"   // create prod_zp_depth_vec
-        "add x5, x4, %x[row], lsl #2\n"
-        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "csel x4, x4, x5, eq\n"
-
-        "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
 
         // Now we load: bias data, LHS sums data, RHS sums data.
 
         // First, load the base pointers from the params.
         "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
 
-        "add x5, x1, %x[row], lsl #2\n"
+        // Determine the channel index: row, or col when channels are columns.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
+
+        // Offset the bias pointer by the channel index selected above.
+        "add x5, x1, x3, lsl #2\n"
+
+        // If there is no bias, use no offset, just address the passed zero
+        // data.
         "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
         "csel x1, x1, x5, eq\n"
 
         // Load 4 bias values.
         "ld1 {v14.4s}, [x1]\n"
 
+        // Load the multiplier_fixedpoint values.
+        "add x5, x4, x3, lsl #2\n"
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+        "csel x4, x4, x5, eq\n"
+        "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
         // Now that we know what LHS and RHS data the next iteration of the
         // main loop will need to load, we start loading the first 32 bytes of
         // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
@@ -446,10 +454,27 @@
 
         // Perform the bias-addition (per the above, we have just folded into
         // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+        // Jump based on channel dimension.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 6f\n"
+        // Case where channels are rows
         "add v16.4s, v16.4s, v14.4s\n"
         "add v17.4s, v17.4s, v14.4s\n"
         "add v18.4s, v18.4s, v14.4s\n"
         "add v19.4s, v19.4s, v14.4s\n"
+        "b 7f\n"
+
+        "6:\n"
+        // Case where channels are columns
+        "dup v20.4s, v14.s[0]\n"
+        "dup v21.4s, v14.s[1]\n"
+        "dup v22.4s, v14.s[2]\n"
+        "dup v23.4s, v14.s[3]\n"
+        "add v16.4s, v16.4s, v20.4s\n"
+        "add v17.4s, v17.4s, v21.4s\n"
+        "add v18.4s, v18.4s, v22.4s\n"
+        "add v19.4s, v19.4s, v23.4s\n"
+        "7:\n"
 
         "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
         "beq 401f\n"
@@ -499,13 +524,23 @@
 
         //Load the exponent part of the multiplier.
         "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+        // Determine the channel index: row, or col when channels are columns.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
+
         "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "add x5, x1, %x[row], lsl #2\n"
+        "add x5, x1, x3, lsl #2\n"
         "csel x1, x1, x5, eq\n"
 
         "ld1 {v14.4s}, [x1]\n"
 
         "smax v12.4s, v14.4s, v8.4s\n"
+        "smin v11.4s, v14.4s, v8.4s\n"
+
+        // Jump based on channel dimension.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 8f\n"
+        // Case where channels are rows
 
         // Apply the positive exponent part of the multiplier.
         "sshl v16.4s, v16.4s, v12.4s\n"
@@ -513,8 +548,6 @@
         "sshl v18.4s, v18.4s, v12.4s\n"
         "sshl v19.4s, v19.4s, v12.4s\n"
 
-        "smin v12.4s, v14.4s, v8.4s\n"
-
         // Apply the fixed-point part of the multiplier.
         "sqrdmulh v16.4s, v16.4s, v15.4s\n"
         "sqrdmulh v17.4s, v17.4s, v15.4s\n"
@@ -522,10 +555,41 @@
         "sqrdmulh v19.4s, v19.4s, v15.4s\n"
 
         // Apply the negative exponent part of the multiplier.
-        "srshl v16.4s, v16.4s, v12.4s\n"
-        "srshl v17.4s, v17.4s, v12.4s\n"
-        "srshl v18.4s, v18.4s, v12.4s\n"
-        "srshl v19.4s, v19.4s, v12.4s\n"
+        "srshl v16.4s, v16.4s, v11.4s\n"
+        "srshl v17.4s, v17.4s, v11.4s\n"
+        "srshl v18.4s, v18.4s, v11.4s\n"
+        "srshl v19.4s, v19.4s, v11.4s\n"
+        "b 9f\n"
+
+        "8:\n"
+        // Case where channels are columns
+
+        // Apply the positive exponent part of the multiplier.
+        "dup v20.4s, v12.s[0]\n"
+        "dup v21.4s, v12.s[1]\n"
+        "dup v22.4s, v12.s[2]\n"
+        "dup v23.4s, v12.s[3]\n"
+        "sshl v16.4s, v16.4s, v20.4s\n"
+        "sshl v17.4s, v17.4s, v21.4s\n"
+        "sshl v18.4s, v18.4s, v22.4s\n"
+        "sshl v19.4s, v19.4s, v23.4s\n"
+
+        // Apply the fixed-point part of the multiplier.
+        "sqrdmulh v16.4s, v16.4s, v15.s[0]\n"
+        "sqrdmulh v17.4s, v17.4s, v15.s[1]\n"
+        "sqrdmulh v18.4s, v18.4s, v15.s[2]\n"
+        "sqrdmulh v19.4s, v19.4s, v15.s[3]\n"
+
+        // Apply the negative exponent part of the multiplier.
+        "dup v20.4s, v11.s[0]\n"
+        "dup v21.4s, v11.s[1]\n"
+        "dup v22.4s, v11.s[2]\n"
+        "dup v23.4s, v11.s[3]\n"
+        "srshl v16.4s, v16.4s, v20.4s\n"
+        "srshl v17.4s, v17.4s, v21.4s\n"
+        "srshl v18.4s, v18.4s, v22.4s\n"
+        "srshl v19.4s, v19.4s, v23.4s\n"
+        "9:\n"
 
         "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
@@ -2091,21 +2155,30 @@
         "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
         "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
         "dup v9.4s, w3\n"   // create prod_zp_depth_vec
-        "add x5, x4, %x[row], lsl #2\n"
-        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "csel x4, x4, x5, eq\n"
-
-        "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
 
         "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
-        "add x5, x1, %x[row], lsl #2\n"
 
+        // Determine the channel index: row, or col when channels are columns.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
+
+        // Offset the bias pointer by the channel index selected above.
+        "add x5, x1, x3, lsl #2\n"
+
+        // If there is no bias, use no offset, just address the passed zero
+        // data.
         "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
         "csel x1, x1, x5, eq\n"
 
         // Load 4 bias values.
         "ld1 {v14.4s}, [x1]\n"
 
+        // Load the multiplier_fixedpoint values.
+        "add x5, x4, x3, lsl #2\n"
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
+        "csel x4, x4, x5, eq\n"
+        "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
+
         // Now that we know what LHS and RHS data the next iteration of the
         // main loop will need to load, we start loading the first 32 bytes of
         // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
@@ -2118,6 +2191,11 @@
 
         // Perform the bias-addition (per the above, we have just folded into
         // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+        // Jump based on channel dimension.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 6f\n"
+        // Case where channels are rows
+
         "add v16.4s, v16.4s, v14.4s\n"
         "ldr d1, [%[lhs_ptr], #16]\n"
         "add v17.4s, v17.4s, v14.4s\n"
@@ -2130,6 +2208,27 @@
         "ldr d6, [%[rhs_ptr], #32]\n"
         "ldr d7, [%[rhs_ptr], #48]\n"
 
+        "b 7f\n"
+
+        "6:\n"
+        // Case where channels are columns
+        "dup v20.4s, v14.s[0]\n"
+        "ldr d1, [%[lhs_ptr], #16]\n"
+        "dup v21.4s, v14.s[1]\n"
+        "ldr d2, [%[lhs_ptr], #32]\n"
+        "dup v22.4s, v14.s[2]\n"
+        "ldr d3, [%[lhs_ptr], #48]\n"
+        "dup v23.4s, v14.s[3]\n"
+        "ldr d4, [%[rhs_ptr], #0]\n"
+        "add v16.4s, v16.4s, v20.4s\n"
+        "ldr d5, [%[rhs_ptr], #16]\n"
+        "add v17.4s, v17.4s, v21.4s\n"
+        "ldr d6, [%[rhs_ptr], #32]\n"
+        "add v18.4s, v18.4s, v22.4s\n"
+        "ldr d7, [%[rhs_ptr], #48]\n"
+        "add v19.4s, v19.4s, v23.4s\n"
+        "7:\n"
+
         "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
         "beq 401f\n"
         "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
@@ -2176,16 +2275,26 @@
         // multiplied by a multiplier that has a fixed-point component and an
         // exponent component.
 
+        // Determine the channel index: row, or col when channels are columns.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "csel w3, %w[row], %w[col], eq\n"
 
         "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
         "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "add x5, x1, %x[row], lsl #2\n"
+        "add x5, x1, x3, lsl #2\n"
         "csel x1, x1, x5, eq\n"
 
         "ld1 {v14.4s}, [x1]\n"
 
         "smax v12.4s, v14.4s, v8.4s\n"
         "ldr x1, [%[lhs_ptr], #8]\n"
+        "smin v11.4s, v14.4s, v8.4s\n"
+
+        // Jump based on channel dimension.
+        "tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 8f\n"
+        // Case where channels are rows
+
 
         // Apply the positive exponent part of the multiplier.
         "sshl v16.4s, v16.4s, v12.4s\n"
@@ -2196,7 +2305,6 @@
         "ldr x4, [%[lhs_ptr], #56]\n"
         "sshl v19.4s, v19.4s, v12.4s\n"
 
-        "smin v12.4s, v14.4s, v8.4s\n"
 
         // Apply the fixed-point part of the multiplier.
         "ins v0.d[1], x1\n"
@@ -2213,10 +2321,54 @@
         "sqrdmulh v19.4s, v19.4s, v15.4s\n"
 
         // Apply the negative exponent part of the multiplier.
-        "srshl v16.4s, v16.4s, v12.4s\n"
-        "srshl v17.4s, v17.4s, v12.4s\n"
-        "srshl v18.4s, v18.4s, v12.4s\n"
-        "srshl v19.4s, v19.4s, v12.4s\n"
+        "srshl v16.4s, v16.4s, v11.4s\n"
+        "srshl v17.4s, v17.4s, v11.4s\n"
+        "srshl v18.4s, v18.4s, v11.4s\n"
+        "srshl v19.4s, v19.4s, v11.4s\n"
+
+        "b 9f\n"
+
+        "8:\n"
+        // Case where channels are columns
+
+        // Apply the positive exponent part of the multiplier.
+        "dup v20.4s, v12.s[0]\n"
+        "ldr x2, [%[lhs_ptr], #24]\n"
+        "ldr x3, [%[lhs_ptr], #40]\n"
+        "dup v21.4s, v12.s[1]\n"
+        "ldr x4, [%[lhs_ptr], #56]\n"
+        "dup v22.4s, v12.s[2]\n"
+        "ins v0.d[1], x1\n"
+        "dup v23.4s, v12.s[3]\n"
+        "ldr x1, [%[rhs_ptr], #8]\n"
+        "sshl v16.4s, v16.4s, v20.4s\n"
+        "ins v1.d[1], x2\n"
+        "sshl v17.4s, v17.4s, v21.4s\n"
+        "ldr x2, [%[rhs_ptr], #24]\n"
+        "sshl v18.4s, v18.4s, v22.4s\n"
+        "ins v2.d[1], x3\n"
+        "sshl v19.4s, v19.4s, v23.4s\n"
+        "ldr x3, [%[rhs_ptr], #40]\n"
+
+        // Apply the fixed-point part of the multiplier.
+        "sqrdmulh v16.4s, v16.4s, v15.s[0]\n"
+        "ins v3.d[1], x4\n"
+        "sqrdmulh v17.4s, v17.4s, v15.s[1]\n"
+        "ldr x4, [%[rhs_ptr], #56]\n"
+        "sqrdmulh v18.4s, v18.4s, v15.s[2]\n"
+        "sqrdmulh v19.4s, v19.4s, v15.s[3]\n"
+
+        // Apply the negative exponent part of the multiplier.
+        "dup v20.4s, v11.s[0]\n"
+        "dup v21.4s, v11.s[1]\n"
+        "dup v22.4s, v11.s[2]\n"
+        "dup v23.4s, v11.s[3]\n"
+        "srshl v16.4s, v16.4s, v20.4s\n"
+        "srshl v17.4s, v17.4s, v21.4s\n"
+        "srshl v18.4s, v18.4s, v22.4s\n"
+        "srshl v19.4s, v19.4s, v23.4s\n"
+
+        "9:\n"
 
         "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
         "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"