/* Copyright 2019 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include "ruy/common.h"
#include "ruy/kernel.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
namespace ruy {
#if RUY_PLATFORM_NEON_64 && RUY_OPT(ASM)
#define RUY_ASM_LABEL_STORE_UINT8 91
#define RUY_ASM_LABEL_STORE_INT8 92
#define RUY_ASM_LABEL_STORE_INT16 93
#define RUY_ASM_LABEL_STORE_INT32 94
#define RUY_ASM_LABEL_AFTER_STORE 99
#define RUY_OFFSET_BIAS 0
#define RUY_OFFSET_LHS_SUMS 8
#define RUY_OFFSET_RHS_SUMS 16
#define RUY_OFFSET_LHS_BASE_PTR 24
#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
#define RUY_OFFSET_RHS_BASE_PTR 48
#define RUY_OFFSET_DST_BASE_PTR 56
#define RUY_OFFSET_LHS_ZERO_POINT 64
#define RUY_OFFSET_RHS_ZERO_POINT 68
#define RUY_OFFSET_DST_ZERO_POINT 72
#define RUY_OFFSET_PROD_ZP_DEPTH 76
#define RUY_OFFSET_START_ROW 80
#define RUY_OFFSET_START_COL 84
#define RUY_OFFSET_LAST_ROW 88
#define RUY_OFFSET_LAST_COL 92
#define RUY_OFFSET_DST_ROWS 96
#define RUY_OFFSET_DST_COLS 100
#define RUY_OFFSET_LHS_STRIDE 104
#define RUY_OFFSET_RHS_STRIDE 108
#define RUY_OFFSET_DST_STRIDE 112
#define RUY_OFFSET_DEPTH 116
#define RUY_OFFSET_CLAMP_MIN 120
#define RUY_OFFSET_CLAMP_MAX 124
#define RUY_OFFSET_FLAGS 128
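// These offsets must match the layout of the KernelParams8bit struct: the
// asm below reads each field by its byte offset, as in
//   "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// The function below static_asserts that the offsets and the struct layout
// stay in sync.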
template <typename Params>
void CheckOffsetsInKernelParams8bit(const Params&) {
static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
"");
static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
"");
static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
"");
static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
"");
static_assert(offsetof(Params, multiplier_fixedpoint) ==
RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
"");
static_assert(
offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
"");
static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
}
// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
//
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) {
profiler::ScopeLabel label(
"Kernel (kNeon, optimized for out-of-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// v4 -- v7 from RHS:
//
// int8 RHS 16x4 block
// /-----------------------------------------|
// |v4.b[0] ... v7.b[0] |
// | ... ... |
// |v4.b[15] ... v7.b[15] |
// \-----------------------------------------/
// int8 LHS 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
// |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
// |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
// |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
//
// No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
// optimization for this kernel.
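// A note on the accumulation scheme used below: each smull/smlal2 pair
// leaves, in every int16 lane of the v8 -- v15 temporaries, the sum of
// exactly two int8*int8 products, e.g.
//   v8.h[i] = v0.b[i]*v4.b[i] + v0.b[i+8]*v4.b[i+8].
// Provided at least one of the two factors avoids the value -128 (the usual
// assumption behind this "fast int8 trick", cf. the gemmlowp kernel linked
// above), such a sum is bounded by 2*127*128 = 32512 and fits in int16
// without overflow. The following sadalp instructions then widen and
// accumulate these int16 pairs into the int32 accumulators v16 -- v31.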
asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 64 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Some multiplications and 16-bit accumulation were already done above,
// so we start right away in the middle.
"sadalp v16.4s, v8.8h\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"sadalp v24.4s, v8.8h\n"
"smull v8.8h, v0.8b, v4.8b\n"
"sadalp v25.4s, v9.8h\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
"smull v9.8h, v1.8b, v4.8b\n"
"sadalp v26.4s, v10.8h\n"
"smull v10.8h, v2.8b, v4.8b\n"
"sadalp v27.4s, v11.8h\n"
"smull v11.8h, v3.8b, v4.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v5.8b\n"
"sadalp v29.4s, v13.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v30.4s, v14.8h\n"
"smull v14.8h, v2.8b, v5.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Each iteration of this loop advances by 16 levels of depth.
"add w1, w1, #16\n"
// Loop termination condition
"cmp w1, w12\n"
"blt 2b\n"
"79:\n"
"sadalp v16.4s, v8.8h\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v24.4s, v8.8h\n"
"sadalp v25.4s, v9.8h\n"
"sadalp v26.4s, v10.8h\n"
"sadalp v27.4s, v11.8h\n"
"sadalp v28.4s, v12.8h\n"
"sadalp v29.4s, v13.8h\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 4x4 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x4 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Reduce 32bit accumulators horizontally.
"addp v16.4s, v16.4s, v17.4s\n"
"addp v18.4s, v18.4s, v19.4s\n"
"addp v20.4s, v20.4s, v21.4s\n"
"addp v22.4s, v22.4s, v23.4s\n"
"addp v24.4s, v24.4s, v25.4s\n"
"addp v26.4s, v26.4s, v27.4s\n"
"addp v28.4s, v28.4s, v29.4s\n"
"addp v30.4s, v30.4s, v31.4s\n"
// Reduce 32bit accumulators horizontally, second pass
// (each pass adds pairwise; we need to add 4-wise).
"addp v16.4s, v16.4s, v18.4s\n"
"addp v17.4s, v20.4s, v22.4s\n"
"addp v18.4s, v24.4s, v26.4s\n"
"addp v19.4s, v28.4s, v30.4s\n"
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done; however,
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
RUY_MAKE_ZERO(v8)
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"add x5, x4, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
// Now we load: bias data, LHS sums data, RHS sums data.
// First, load the base pointers from the params.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
"add x5, x1, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 64 bytes of
// each of LHS and RHS, into v0 -- v7, as we don't need those registers
// anymore in the rest of the work on the current block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
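// Concretely, writing lhs(r,k), rhs(k,c) for the stored int8 values and
// z_lhs, z_rhs for their zero points, the value we want is
//   sum_k (lhs(r,k) - z_lhs) * (rhs(k,c) - z_rhs)
//     =   sum_k lhs(r,k) * rhs(k,c)      (the accumulators in v16 -- v19)
//       - z_lhs * rhs_sums(c)            (subtracted below if HAS_RHS_SUMS)
//       - z_rhs * lhs_sums(r)            (subtracted below if HAS_LHS_SUMS)
//       + depth * z_lhs * z_rhs          (prod_zp_depth, folded in here).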
"add v14.4s, v14.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v14.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v14.4s\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[1]\n"
"mls v18.4s, v10.4s, v14.s[2]\n"
"mls v19.4s, v10.4s, v14.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"ld1 {v11.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v11.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v11.4s\n"
// If the destination is int32, it means the user wants the raw
// accumulators, and there is no need for us to down-quantize the values.
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
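// Roughly, the down-quantization below computes
//   dst = rounding_right_shift(
//             sqrdmulh(acc << max(exponent, 0), multiplier_fixedpoint),
//             -min(exponent, 0))
// (before the destination zero point and clamping are applied): the
// smax/sshl pair applies any positive exponent as a left shift before the
// fixed-point multiplication, and the srshl further down applies any
// negative exponent as a rounding right shift after it.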
// Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, %x[row], lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smax v12.4s, v14.4s, v8.4s\n"
"sshl v16.4s, v16.4s, v12.4s\n"
"sshl v17.4s, v17.4s, v12.4s\n"
"sshl v18.4s, v18.4s, v12.4s\n"
"sshl v19.4s, v19.4s, v12.4s\n"
"smin v12.4s, v14.4s, v8.4s\n"
// Apply the fixed-point part of the multiplier.
"sqrdmulh v16.4s, v16.4s, v15.4s\n"
"sqrdmulh v17.4s, v17.4s, v15.4s\n"
"sqrdmulh v18.4s, v18.4s, v15.4s\n"
"sqrdmulh v19.4s, v19.4s, v15.4s\n"
// We have some rounding division-by-power-of-two to do. This should
// always use "round to nearest". We allow for some
// freedom in how ties are broken, to strike a good compromise of
// performance on given hardware vs. perfect agreement of results
// across hardware.
//
// When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
// defined tie-breaks to help performance. On NEON, this means that we
// can just use the NEON rounding instructions, such as srshl. They
// happen to be breaking ties upward.
//
// When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
// break-ties-away-from zero, as described in Appendix B of
// https://arxiv.org/pdf/1712.05877.pdf
// When we wrote that, we thought that it would be less biased than the
// NEON upward tie-breaks, and we had observed some improvement on some
// model. However, it is only less biased for data centered at zero, which
// was likely the case in that model but is not always the case. If we
// wanted something more consistently unbiased, we should try breaking
// ties to nearest-even.
#if !RUY_OPT(NATIVE_ROUNDING)
// Fix up values to be right-shifted, so that the (round to nearest,
// break ties upward) behavior of srshl applied to these fixed-up
// values, produces the same result as the desired (round to nearest,
// break ties away from zero) behavior on the original values.
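// For example, take an accumulator value of -5 right-shifted by 1 (the v12
// lane is then -1): plain srshl computes (-5 + 1) >> 1 = -2, i.e. -2.5 with
// the tie broken upward. The fixup below detects "negative value and
// nonzero shift" via the sign bit of (v16 & v12), subtracts 1 through the
// sqadd of -1, and srshl then computes (-6 + 1) >> 1 = -3, i.e. -2.5 with
// the tie broken away from zero.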
"and v8.16b, v16.16b, v12.16b\n"
"and v9.16b, v17.16b, v12.16b\n"
"and v14.16b, v18.16b, v12.16b\n"
"and v15.16b, v19.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sshr v9.4s, v9.4s, #31\n"
"sshr v14.4s, v14.4s, #31\n"
"sshr v15.4s, v15.4s, #31\n"
"sqadd v16.4s, v16.4s, v8.4s\n"
"sqadd v17.4s, v17.4s, v9.4s\n"
"sqadd v18.4s, v18.4s, v14.4s\n"
"sqadd v19.4s, v19.4s, v15.4s\n"
#endif
// At this point we have reduced the problem of correctly implementing
// rounding divide-by-power-of-two, to what the SRSHL instruction can
// do.
"srshl v16.4s, v16.4s, v12.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v12.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"
"add v17.8h, v17.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"sqxtun2 v16.16b, v17.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
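// For example, with a 10x10 destination and row == 8, col == 8, only a 2x2
// corner of this 4x4 block lies inside the matrix (w1 == 2, w2 == 2 below),
// so the slow dst_tmp_buf path is taken.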
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"
"add v17.8h, v17.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"sqxtn2 v16.16b, v17.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.4h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Load the clamp_min, clamp_max bounds
"ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
// Compute how much of the 4x4 block of destination 16bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[1], [x3], #2\n"
"st1 {v16.h}[2], [x3], #2\n"
"st1 {v16.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[5], [x3], #2\n"
"st1 {v16.h}[6], [x3], #2\n"
"st1 {v16.h}[7], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[1], [x3], #2\n"
"st1 {v17.h}[2], [x3], #2\n"
"st1 {v17.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[5], [x3], #2\n"
"st1 {v17.h}[6], [x3], #2\n"
"st1 {v17.h}[7], [x3], #2\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// At this point, v20 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Compute how much of the 4x4 block of destination 32bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, there are some 4x4 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
"str q18, [%[dst_tmp_buf], #32]\n"
"str q19, [%[dst_tmp_buf], #48]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v16.s}[1], [x3], #4\n"
"st1 {v16.s}[2], [x3], #4\n"
"st1 {v16.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v17.s}[1], [x3], #4\n"
"st1 {v17.s}[2], [x3], #4\n"
"st1 {v17.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v18.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v18.s}[1], [x3], #4\n"
"st1 {v18.s}[2], [x3], #4\n"
"st1 {v18.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v19.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v19.s}[1], [x3], #4\n"
"st1 {v19.s}[2], [x3], #4\n"
"st1 {v19.s}[3], [x3], #4\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Reload some params --- we have used x5 -- x7 for a few other things
// since we last loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #4\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #4\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of
// RHS cols == 1.
// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
// since these are 64-bit, out-of-order and without dotprod support.
void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) {
profiler::ScopeLabel label(
"Kernel (kNeon, optimized for out-of-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v19 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// v4 from RHS:
//
// int8 RHS 16x1 block
// /-----------|
// |v4.b[0] |
// | ... |
// |v4.b[15] |
// \-----------/
// int8 LHS 4x16 block
// /---------------------\ /-----------|
// |v0.b[0] ... v0.b[15] | |v16.4s |
// |v1.b[0] ... v1.b[15] | |v17.4s |
// |v2.b[0] ... v2.b[15] | |v18.4s |
// |v3.b[0] ... v3.b[15] | |v19.4s |
// \---------------------/ \-----------/
// int32 accumulators 4x1 block
//
// No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
// optimization for this kernel.
asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 64 bytes of LHS data and 16 bytes of RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"add %[rhs_ptr], %[rhs_ptr], #48\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Some multiplications and 16-bit accumulation were already done above,
// so we start right away in the middle.
"sadalp v16.4s, v8.8h\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"add %[rhs_ptr], %[rhs_ptr], #48\n"
"sadalp v17.4s, v9.8h\n"
"sadalp v18.4s, v10.8h\n"
"sadalp v19.4s, v11.8h\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
// Each iteration of this loop advances by 16 levels of depth.
"add w1, w1, #16\n"
// Loop termination condition
"cmp w1, w12\n"
"blt 2b\n"
"79:\n"
"sadalp v16.4s, v8.8h\n"
"sadalp v17.4s, v9.8h\n"
"sadalp v18.4s, v10.8h\n"
"sadalp v19.4s, v11.8h\n"
// End of accumulation. The registers v16 -- v19 contain the final
// int32 accumulator values of the current 4x1 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x1 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Reduce 32bit accumulators horizontally.
"addp v16.4s, v16.4s, v17.4s\n"
"addp v18.4s, v18.4s, v19.4s\n"
// Reduce 32bit accumulators horizontally, second pass
// (each pass adds pairwise; we need to add 4-wise).
"addp v16.4s, v16.4s, v18.4s\n"
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done; however,
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
// (still multiply column stride by 4 due to packing)
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
RUY_MAKE_ZERO(v8)
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"add x5, x4, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
// Now we load: bias data, LHS sums data, RHS sums data.
// First, load the base pointers from the params.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
"add x5, x1, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 64 bytes of
// LHS into v0 -- v3 and the first 16 bytes of RHS into v4, as we don't
// need those registers anymore in the rest of the work on the current
// block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
"add %[rhs_ptr], %[rhs_ptr], #48\n"
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point),
// See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"add v14.4s, v14.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
// (all four 32-bit accumulators are in v16 at this point)
"add v16.4s, v16.4s, v14.4s\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"ld1 {v11.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
// If the destination is int32, it means the user wants the raw
// accumulators, and there is no need for us to down-quantize the values.
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values from them.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
// Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, %x[row], lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smax v12.4s, v14.4s, v8.4s\n"
"sshl v16.4s, v16.4s, v12.4s\n"
"smin v12.4s, v14.4s, v8.4s\n"
// Apply the fixed-point part of the multiplier.
"sqrdmulh v16.4s, v16.4s, v15.4s\n"
// We have some rounding division-by-power-of-two to do. This should
// always use "round to nearest". We allow for some
// freedom in how ties are broken, to strike a good compromise of
// performance on given hardware vs. perfect agreement of results
// across hardware.
//
// When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation
// defined tie-breaks to help performance. On NEON, this means that we
// can just use the NEON rounding instructions, such as srshl. They
// happen to be breaking ties upward.
//
// When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
// break-ties-away-from zero, as described in Appendix B of
// https://arxiv.org/pdf/1712.05877.pdf
// When we wrote that, we thought that it would be less biased than the
// NEON upward tie-breaks, and we had observed some improvement on some
// model. However, it is only less biased for data centered at zero, which
// was likely the case in that model but is not always the case. If we
// wanted something more consistently unbiased, we should try breaking
// ties to nearest-even.
#if !RUY_OPT(NATIVE_ROUNDING)
// Fix up values to be right-shifted, so that the (round to nearest,
// break ties upward) behavior of srshl applied to these fixed-up
// values, produces the same result as the desired (round to nearest,
// break ties away from zero) behavior on the original values.
"and v8.16b, v16.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sqadd v16.4s, v16.4s, v8.4s\n"
#endif
// At this point we have reduced the problem of correctly implementing
// rounding divide-by-power-of-two, to what the SRSHL instruction can
// do.
"srshl v16.4s, v16.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
// After this instruction, all data is in lower half (64-bits) of v16
"sqxtn v16.4h, v16.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
// Now all data is in the first 32-bits of v16
"sqxtun v16.8b, v16.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
// Compute how much of the 4x1 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x1 block fit
"csel w1, w1, w3, le\n"
// Test if w1==4, i.e. if all of the 4x1 block fits.
"cmp w1, w3\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x1 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x1 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
// After this, all values for output are in the lower half (64 bits) of v16.
"sqxtn v16.4h, v16.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
// At this point, we only need 4 lowest 8-bit values in v16.
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
// Compute how much of the 4x1 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x1 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Test if w1==4, i.e. if all of the 4x1 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.4h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
// After this instruction, all data is in lower half of v16.
"sqxtn v16.4h, v16.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
// Load the clamp_min, clamp_max bounds
"ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
// Compute how much of the 4x1 block of destination 16bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Test if w1==4, i.e. if all of the 4x1 block fits.
"cmp w1, w3\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[0], [x3], #2\n"
"st1 {v16.h}[1], [x3], #2\n"
"st1 {v16.h}[2], [x3], #2\n"
"st1 {v16.h}[3], [x3], #2\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need for clamp by min/max.
// Compute how much of the 4x1 block of destination 32bit values that
// we have computed fits in the destination matrix. Typically, all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x1, there are some 4x1 blocks along the boundaries that do
// not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Test if w1==4 i.e. if all of the 4x1 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x1 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.s}[0], [x3], #4\n"
"st1 {v16.s}[1], [x3], #4\n"
"st1 {v16.s}[2], [x3], #4\n"
"st1 {v16.s}[3], [x3], #4\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #4\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #4\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19");
}
// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs.
// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
// the original Cortex-A55, since these are 64-bit and do not support dotprod.
//
// While this kernel does not have a direct equivalent in gemmlowp, it was
// developed based on insights that David Mansell at ARM shared with their
// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
// comments. Specifically, see this comment about tuning for Cortex-A53:
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) {
profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v3 are used to load int8 data from LHS and
// v4 -- v7 from RHS:
//
// int8 RHS 16x4 block
// /-----------------------------------------|
// |v4.b[0] ... v7.b[0] |
// | ... ... |
// |v4.b[15] ... v7.b[15] |
// \-----------------------------------------/
// int8 LHS 4x16 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s |
// |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s |
// |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s |
// |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s |
// \---------------------/ \-----------------------------------------/
// int32 accumulators 4x4 block
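// Compared to the out-of-order kernel above, the main difference in the asm
// below is instruction scheduling: 128-bit loads are split into 64-bit
// "ldr d" / "ldr x" pairs recombined with "ins", and loads are interleaved
// with the arithmetic. This tends to suit narrow in-order pipelines such as
// Cortex-A53/A55; see the gemmlowp comment linked above for the underlying
// analysis.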
asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
RUY_MAKE_ZERO(v16)
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
RUY_MAKE_ZERO(v17)
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
RUY_MAKE_ZERO(v18)
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
RUY_MAKE_ZERO(v19)
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v20)
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
RUY_MAKE_ZERO(v21)
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
RUY_MAKE_ZERO(v22)
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
RUY_MAKE_ZERO(v23)
// Load the first 64 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v24)
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v25)
"ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v26)
"ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
RUY_MAKE_ZERO(v27)
"ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v28)
"ld1 {v5.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v29)
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v30)
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
"smull v11.8h, v3.8b, v4.8b\n"
"smull v12.8h, v0.8b, v5.8b\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Some multiplications and 16-bit accumulation were already done above,
// so we start right away in the middle.
"sadalp v16.4s, v8.8h\n"
"ldr d4, [%[rhs_ptr], #0]\n"
"smull v8.8h, v0.8b, v6.8b\n"
"ldr x7, [%[rhs_ptr], #8]\n"
"sadalp v17.4s, v9.8h\n"
"ldr d5, [%[rhs_ptr], #16]\n"
"smull v9.8h, v1.8b, v6.8b\n"
"ldr x8, [%[rhs_ptr], #24]\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"add %[lhs_ptr], %[lhs_ptr], #64\n"
"smull v11.8h, v3.8b, v6.8b\n"
"add %[rhs_ptr], %[rhs_ptr], #64\n"
"sadalp v20.4s, v12.8h\n"
// Each iteration of this loop advances by 16 levels of depth.
"add w1, w1, #16\n"
"smull v12.8h, v0.8b, v7.8b\n"
// Loop termination condition
"cmp w1, w12\n"
"sadalp v21.4s, v13.8h\n"
"ldr x3, [%[lhs_ptr], #-56]\n"
"smull v13.8h, v1.8b, v7.8b\n"
"ldr x4, [%[lhs_ptr], #-40]\n"
"sadalp v22.4s, v14.8h\n"
"ldr x5, [%[lhs_ptr], #-24]\n"
"smull v14.8h, v2.8b, v7.8b\n"
"ldr x6, [%[lhs_ptr], #-8]\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"ldr x9, [%[rhs_ptr], #-24]\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"ldr d6, [%[rhs_ptr], #-32]\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"ldr d0, [%[lhs_ptr], #-64]\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"ldr d1, [%[lhs_ptr], #-48]\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"ins v4.d[1], x7\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"ins v5.d[1], x8\n"
"ldr d2, [%[lhs_ptr], #-32]\n"
"ins v0.d[1], x3\n"
"sadalp v24.4s, v8.8h\n"
"ldr d3, [%[lhs_ptr], #-16]\n"
"ins v1.d[1], x4\n"
"smull v8.8h, v0.8b, v4.8b\n"
"ins v2.d[1], x5\n"
"sadalp v25.4s, v9.8h\n"
"ins v3.d[1], x6\n"
"smull v9.8h, v1.8b, v4.8b\n"
"ldr d7, [%[rhs_ptr], #-16]\n"
"sadalp v26.4s, v10.8h\n"
"ldr x10, [%[rhs_ptr], #-8]\n"
"smull v10.8h, v2.8b, v4.8b\n"
"sadalp v27.4s, v11.8h\n"
"smull v11.8h, v3.8b, v4.8b\n"
"sadalp v28.4s, v12.8h\n"
"smull v12.8h, v0.8b, v5.8b\n"
"sadalp v29.4s, v13.8h\n"
"smull v13.8h, v1.8b, v5.8b\n"
"sadalp v30.4s, v14.8h\n"
"smull v14.8h, v2.8b, v5.8b\n"
"sadalp v31.4s, v15.8h\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"ins v6.d[1], x9\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"ins v7.d[1], x10\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
"blt 2b\n"
"79:\n"
"sadalp v16.4s, v8.8h\n"
"smull v8.8h, v0.8b, v6.8b\n"
"sadalp v17.4s, v9.8h\n"
"smull v9.8h, v1.8b, v6.8b\n"
"sadalp v18.4s, v10.8h\n"
"smull v10.8h, v2.8b, v6.8b\n"
"sadalp v19.4s, v11.8h\n"
"smull v11.8h, v3.8b, v6.8b\n"
"sadalp v20.4s, v12.8h\n"
"smull v12.8h, v0.8b, v7.8b\n"
"sadalp v21.4s, v13.8h\n"
"smull v13.8h, v1.8b, v7.8b\n"
"sadalp v22.4s, v14.8h\n"
"smull v14.8h, v2.8b, v7.8b\n"
"sadalp v23.4s, v15.8h\n"
"smull v15.8h, v3.8b, v7.8b\n"
// Multiply-accumulate second-half, again into the same
// 16bit local accumulator registers. This is where we
// take advantage of having int8 instead of uint8 and therefore
// being able to accumulate two products into int16.
"smlal2 v8.8h, v0.16b, v6.16b\n"
"smlal2 v9.8h, v1.16b, v6.16b\n"
"smlal2 v10.8h, v2.16b, v6.16b\n"
"smlal2 v11.8h, v3.16b, v6.16b\n"
"smlal2 v12.8h, v0.16b, v7.16b\n"
"smlal2 v13.8h, v1.16b, v7.16b\n"
"smlal2 v14.8h, v2.16b, v7.16b\n"
"smlal2 v15.8h, v3.16b, v7.16b\n"
"sadalp v24.4s, v8.8h\n"
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"sadalp v25.4s, v9.8h\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"sadalp v26.4s, v10.8h\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"sadalp v27.4s, v11.8h\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"sadalp v28.4s, v12.8h\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"sadalp v29.4s, v13.8h\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"sadalp v30.4s, v14.8h\n"
"sadalp v31.4s, v15.8h\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 4x4 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 4x4 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for next block as early as
// possible) and instruction pipelining level (some of the next-block
// work can dual-issue with some of the final work on the current
// block).
// Reduce 32bit accumulators horizontally.
"addp v16.4s, v16.4s, v17.4s\n"
"addp v18.4s, v18.4s, v19.4s\n"
"addp v20.4s, v20.4s, v21.4s\n"
"addp v22.4s, v22.4s, v23.4s\n"
"addp v24.4s, v24.4s, v25.4s\n"
"addp v26.4s, v26.4s, v27.4s\n"
"addp v28.4s, v28.4s, v29.4s\n"
"addp v30.4s, v30.4s, v31.4s\n"
// Reduce 32bit accumulators horizontally, second pass
// (each pass adds pairwise; we need to add 4-wise).
"addp v16.4s, v16.4s, v18.4s\n"
"addp v17.4s, v20.4s, v22.4s\n"
"addp v18.4s, v24.4s, v26.4s\n"
"addp v19.4s, v28.4s, v30.4s\n"
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done; however,
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
RUY_MAKE_ZERO(v8)
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"add x5, x4, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
"add x5, x1, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 4 bias values.
"ld1 {v14.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into the low halves of v0 -- v7 (the high halves
// are filled in by the ins instructions further below), as we don't need
// v0 -- v7 anymore in the rest of the work on the current block.
// Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point);
// see the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf
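// For reference, the identity being applied here and in the blocks below is
// equation (7) of the paper above. With Z1 = lhs_zero_point and
// Z2 = rhs_zero_point, and acc[r][c] = sum_d lhs[r][d] * rhs[d][c] being the
// raw int8 accumulator:
//   sum_d (lhs[r][d] - Z1) * (rhs[d][c] - Z2)
//     = acc[r][c]
//       + depth * Z1 * Z2    (prod_zp_depth, folded into the bias here)
//       - Z1 * rhs_sums[c]   (subtracted under RUY_ASM_FLAG_HAS_RHS_SUMS)
//       - Z2 * lhs_sums[r]   (subtracted under RUY_ASM_FLAG_HAS_LHS_SUMS)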
"add v14.4s, v14.4s, v9.4s\n"
"ldr d0, [%[lhs_ptr], #0]\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
"add v16.4s, v16.4s, v14.4s\n"
"ldr d1, [%[lhs_ptr], #16]\n"
"add v17.4s, v17.4s, v14.4s\n"
"ldr d2, [%[lhs_ptr], #32]\n"
"add v18.4s, v18.4s, v14.4s\n"
"ldr d3, [%[lhs_ptr], #48]\n"
"add v19.4s, v19.4s, v14.4s\n"
"ldr d4, [%[rhs_ptr], #0]\n"
"ldr d5, [%[rhs_ptr], #16]\n"
"ldr d6, [%[rhs_ptr], #32]\n"
"ldr d7, [%[rhs_ptr], #48]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[1]\n"
"mls v18.4s, v10.4s, v14.s[2]\n"
"mls v19.4s, v10.4s, v14.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"ld1 {v11.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v11.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v11.4s\n"
// If the destination is int32, it means the user is asking for the raw
// accumulators, so there is no need for us to downquantize the values.
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
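// Concretely, the code below applies the following to each int32
// accumulator x (using gemmlowp's terminology for the operations):
//   x <<= max(multiplier_exponent, 0);                               // sshl
//   x = SaturatingRoundingDoublingHighMul(x, multiplier_fixedpoint); // sqrdmulh
//   x = RoundingDivideByPOT(x, -min(multiplier_exponent, 0));        // srshl (+ fixup)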
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, %x[row], lsl #2\n"
"csel x1, x1, x5, eq\n"
"ld1 {v14.4s}, [x1]\n"
"smax v12.4s, v14.4s, v8.4s\n"
"ldr x1, [%[lhs_ptr], #8]\n"
"sshl v16.4s, v16.4s, v12.4s\n"
"ldr x2, [%[lhs_ptr], #24]\n"
"sshl v17.4s, v17.4s, v12.4s\n"
"ldr x3, [%[lhs_ptr], #40]\n"
"sshl v18.4s, v18.4s, v12.4s\n"
"ldr x4, [%[lhs_ptr], #56]\n"
"sshl v19.4s, v19.4s, v12.4s\n"
"smin v12.4s, v14.4s, v8.4s\n"
// Apply the fixed-point part of the multiplier.
"ins v0.d[1], x1\n"
"ldr x1, [%[rhs_ptr], #8]\n"
"sqrdmulh v16.4s, v16.4s, v15.4s\n"
"ins v1.d[1], x2\n"
"ldr x2, [%[rhs_ptr], #24]\n"
"sqrdmulh v17.4s, v17.4s, v15.4s\n"
"ins v2.d[1], x3\n"
"ldr x3, [%[rhs_ptr], #40]\n"
"sqrdmulh v18.4s, v18.4s, v15.4s\n"
"ins v3.d[1], x4\n"
"ldr x4, [%[rhs_ptr], #56]\n"
"sqrdmulh v19.4s, v19.4s, v15.4s\n"
// We have some rounding division-by-power-of-two to do. This should
// always use "round to nearest". We allow for some
// freedom in how ties are broken, to strike a good compromise of
// performance on given hardware vs. perfect agreement of results
// across hardware.
//
// When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation-
// defined tie-breaks to help performance. On NEON, this means that we
// can just use the NEON rounding instructions, such as srshl. They
// happen to break ties upward.
//
// When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
// break-ties-away-from-zero, as described in Appendix B of
// https://arxiv.org/pdf/1712.05877.pdf
// When we wrote that, we thought it would be less biased than the NEON
// upward tie-breaking, and we had observed some improvement on some
// models. However, it is only less biased for data centered at zero,
// which was likely the case in those models, but is not always the case.
// If we wanted something more consistently unbiased, we should try
// breaking ties to nearest-even.
#if !RUY_OPT(NATIVE_ROUNDING)
// Fix up values to be right-shifted, so that the (round to nearest,
// break ties upward) behavior of srshl applied to these fixed-up
// values produces the same result as the desired (round to nearest,
// break ties away from zero) behavior on the original values.
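// Reference-only scalar equivalent of the fixup + srshl sequence (v12 holds
// the negative right-shift amounts, so x & v12 has its sign bit set only
// when x < 0 and there is an actual right shift to perform; helper names
// below are illustrative, not real APIs):
//   int32_t fixup = (x < 0 && shift < 0) ? -1 : 0;  // and + sshr #31
//   x = SaturatingAdd(x, fixup);                    // sqadd
//   x = RoundingRightShift(x, -shift);              // srshl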
"and v8.16b, v16.16b, v12.16b\n"
"and v9.16b, v17.16b, v12.16b\n"
"and v14.16b, v18.16b, v12.16b\n"
"and v15.16b, v19.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sshr v9.4s, v9.4s, #31\n"
"sshr v14.4s, v14.4s, #31\n"
"sshr v15.4s, v15.4s, #31\n"
"sqadd v16.4s, v16.4s, v8.4s\n"
"sqadd v17.4s, v17.4s, v9.4s\n"
"sqadd v18.4s, v18.4s, v14.4s\n"
"sqadd v19.4s, v19.4s, v15.4s\n"
#endif
// At this point we have reduced the problem of correctly implementing
// rounding divide-by-power-of-two, to what the SRSHL instruction can
// do.
"srshl v16.4s, v16.4s, v12.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v12.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
"ins v4.d[1], x1\n"
"sqxtn v16.4h, v16.4s\n"
"ins v5.d[1], x2\n"
"sqxtn2 v16.8h, v17.4s\n"
"ins v6.d[1], x3\n"
"sqxtn v17.4h, v18.4s\n"
"ins v7.d[1], x4\n"
RUY_MAKE_ZERO(v18)
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v19)
// Add the destination zero point
"add %[lhs_ptr], %[lhs_ptr], #64\n"
"dup v14.8h, v13.h[4]\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
"add v16.8h, v16.8h, v14.8h\n"
RUY_MAKE_ZERO(v21)
"add v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v22)
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
RUY_MAKE_ZERO(v23)
"sqxtun2 v16.16b, v17.8h\n"
RUY_MAKE_ZERO(v24)
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
RUY_MAKE_ZERO(v25)
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
RUY_MAKE_ZERO(v26)
"dup v14.16b, w2\n" // clamp_min
RUY_MAKE_ZERO(v27)
"dup v15.16b, w3\n" // clamp_max
RUY_MAKE_ZERO(v28)
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
RUY_MAKE_ZERO(v29)
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
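// In scalar terms (reference only), the logic below is:
//   rows_fit = min(4, dst_rows - row);
//   cols_fit = min(4, dst_cols - col);
//   if (rows_fit == 4 && cols_fit == 4) {
//     // fast path (label 30): store the block directly to dst.
//   } else {
//     // slow path: store to dst_tmp_buf, then copy its rows_fit x cols_fit
//     // top-left corner to dst element by element (labels 50/51).
//   }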
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
"ins v4.d[1], x1\n"
"sqxtn v16.4h, v16.4s\n"
"ins v5.d[1], x2\n"
"sqxtn2 v16.8h, v17.4s\n"
"ins v6.d[1], x3\n"
"sqxtn v17.4h, v18.4s\n"
"ins v7.d[1], x4\n"
RUY_MAKE_ZERO(v18)
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v19)
// Add the destination zero point
"add %[lhs_ptr], %[lhs_ptr], #64\n"
"dup v14.8h, v13.h[4]\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
"add v16.8h, v16.8h, v14.8h\n"
RUY_MAKE_ZERO(v21)
"add v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v22)
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
RUY_MAKE_ZERO(v23)
"sqxtn2 v16.16b, v17.8h\n"
RUY_MAKE_ZERO(v24)
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
RUY_MAKE_ZERO(v25)
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
RUY_MAKE_ZERO(v26)
"dup v14.16b, w2\n" // clamp_min
RUY_MAKE_ZERO(v27)
"dup v15.16b, w3\n" // clamp_max
RUY_MAKE_ZERO(v28)
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
RUY_MAKE_ZERO(v29)
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 8bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"st1 {v16.16b}, [%[dst_tmp_buf]]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #4\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[0], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[1], [x3], #1\n"
"st1 {v16.b}[2], [x3], #1\n"
"st1 {v16.b}[3], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[4], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[5], [x3], #1\n"
"st1 {v16.b}[6], [x3], #1\n"
"st1 {v16.b}[7], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[8], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[9], [x3], #1\n"
"st1 {v16.b}[10], [x3], #1\n"
"st1 {v16.b}[11], [x3], #1\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.b}[12], [x3], #1\n"
"add x4, x4, x11\n"
"st1 {v16.b}[13], [x3], #1\n"
"st1 {v16.b}[14], [x3], #1\n"
"st1 {v16.b}[15], [x3], #1\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #4\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.4h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"ins v4.d[1], x1\n"
"sqxtn v16.4h, v16.4s\n"
"ins v5.d[1], x2\n"
"sqxtn2 v16.8h, v17.4s\n"
"ins v6.d[1], x3\n"
"sqxtn v17.4h, v18.4s\n"
"ins v7.d[1], x4\n"
RUY_MAKE_ZERO(v18)
"sqxtn2 v17.8h, v19.4s\n"
// At this point, v18 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v19)
"add %[lhs_ptr], %[lhs_ptr], #64\n"
RUY_MAKE_ZERO(v20)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
// Load the clamp_min, clamp_max bounds
"ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
RUY_MAKE_ZERO(v25)
"ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
RUY_MAKE_ZERO(v26)
"dup v14.8h, w2\n" // clamp_min
RUY_MAKE_ZERO(v27)
"dup v15.8h, w3\n" // clamp_max
RUY_MAKE_ZERO(v28)
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
RUY_MAKE_ZERO(v29)
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination 16bit values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldrh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[1], [x3], #2\n"
"st1 {v16.h}[2], [x3], #2\n"
"st1 {v16.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v16.h}[5], [x3], #2\n"
"st1 {v16.h}[6], [x3], #2\n"
"st1 {v16.h}[7], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[0], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[1], [x3], #2\n"
"st1 {v17.h}[2], [x3], #2\n"
"st1 {v17.h}[3], [x3], #2\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.h}[4], [x3], #2\n"
"add x4, x4, x11\n"
"st1 {v17.h}[5], [x3], #2\n"
"st1 {v17.h}[6], [x3], #2\n"
"st1 {v17.h}[7], [x3], #2\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
"ldr x1, [%[lhs_ptr], #8]\n"
"ldr x2, [%[lhs_ptr], #24]\n"
"ldr x3, [%[lhs_ptr], #40]\n"
"ldr x4, [%[lhs_ptr], #56]\n"
"ins v0.d[1], x1\n"
"ldr x1, [%[rhs_ptr], #8]\n"
"ins v1.d[1], x2\n"
"ldr x2, [%[rhs_ptr], #24]\n"
"ins v2.d[1], x3\n"
"ldr x3, [%[rhs_ptr], #40]\n"
"ins v3.d[1], x4\n"
"ldr x4, [%[rhs_ptr], #56]\n"
"ins v4.d[1], x1\n"
"ins v5.d[1], x2\n"
"ins v6.d[1], x3\n"
"ins v7.d[1], x4\n"
// Since the store type is the same as the accum type, there is no need
// for a downcast, nor for clamping to min/max.
// At this point, v20 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v20)
"add %[lhs_ptr], %[lhs_ptr], #64\n"
RUY_MAKE_ZERO(v21)
"add %[rhs_ptr], %[rhs_ptr], #64\n"
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
// Compute how much of the 4x4 block of destination int32 values that
// we have computed fits in the destination matrix. Typically all of
// it fits, but when the destination matrix shape is not a multiple
// of 4x4, some 4x4 blocks along the boundaries do not fit entirely.
"sub w1, %w[dst_rows], %w[row]\n"
RUY_MAKE_ZERO(v31)
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #4\n"
"cmp w1, #4\n"
// Compute w1 = how many rows of the 4x4 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #4\n"
// Compute w2 = how many cols of the 4x4 block fit
"csel w2, w2, w3, le\n"
// Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
"mov x4, %[dst_ptr]\n"
// Yes, all of the 4x4 block fits, go to fast path.
"beq 30f\n"
// Not all of the 4x4 block fits.
// Store to dst_tmp_buf
"str q16, [%[dst_tmp_buf], #0]\n"
"str q17, [%[dst_tmp_buf], #16]\n"
"str q18, [%[dst_tmp_buf], #32]\n"
"str q19, [%[dst_tmp_buf], #48]\n"
// Slow loop copying from dst_tmp_buf to dst.
"mov x3, %[dst_tmp_buf]\n"
"mov w6, #0\n"
"50:\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov w5, #0\n"
"51:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"b 31f\n"
"30:\n"
// Yes, all of the 4x4 block fits.
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v16.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v16.s}[1], [x3], #4\n"
"st1 {v16.s}[2], [x3], #4\n"
"st1 {v16.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v17.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v17.s}[1], [x3], #4\n"
"st1 {v17.s}[2], [x3], #4\n"
"st1 {v17.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v18.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v18.s}[1], [x3], #4\n"
"st1 {v18.s}[2], [x3], #4\n"
"st1 {v18.s}[3], [x3], #4\n"
RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n")
"mov x3, x4\n"
"st1 {v19.s}[0], [x3], #4\n"
"add x4, x4, x11\n"
"st1 {v19.s}[1], [x3], #4\n"
"st1 {v19.s}[2], [x3], #4\n"
"st1 {v19.s}[3], [x3], #4\n"
"31:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
"smull v8.8h, v0.8b, v4.8b\n"
"smull v9.8h, v1.8b, v4.8b\n"
"smull v10.8h, v2.8b, v4.8b\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"smull v11.8h, v3.8b, v4.8b\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"smull v12.8h, v0.8b, v5.8b\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"smull v13.8h, v1.8b, v5.8b\n"
"smull v14.8h, v2.8b, v5.8b\n"
"smull v15.8h, v3.8b, v5.8b\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
// Have we reached the end row?
"cmp %w[row], w7\n"
"smlal2 v8.8h, v0.16b, v4.16b\n"
"smlal2 v9.8h, v1.16b, v4.16b\n"
"smlal2 v10.8h, v2.16b, v4.16b\n"
"smlal2 v11.8h, v3.16b, v4.16b\n"
"smlal2 v12.8h, v0.16b, v5.16b\n"
"smlal2 v13.8h, v1.16b, v5.16b\n"
"smlal2 v14.8h, v2.16b, v5.16b\n"
"smlal2 v15.8h, v3.16b, v5.16b\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #4\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #4\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 16.
"mov w1, #16\n"
"ble 1b\n"
// clang-format on
: [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}
// Kernel taking advantage of the optional dotprod instruction.
// This is very similar to (and directly inspired by) this gemmlowp kernel
// which was contributed by David Mansell at ARM:
// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct
// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391
//
// Besides the ruy-ification, the main difference here is that we use an 8x8
// instead of a 12x8 width, so as to stick to power-of-two widths. This
// slightly narrower kernel layout is still wide enough to achieve high
// performance, although we haven't actually performed a real comparison to
// know exactly how this compares to ARM's aforementioned kernel.
//
// Relevant target CPUs for this kernel include ARM Cortex-A76,
// which is 64-bit, out-of-order and supports dotprod.
void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) {
profiler::ScopeLabel label(
"Kernel (kNeonDotprod, optimized for out-of-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v15 are used to load int8 data from LHS and
// RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and
// v3 are used to load a 4x8 block of RHS, like this:
//
// int8 RHS 4x8 block
// /-----------------------------------------|
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
// | ... ... |
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
// \-----------------------------------------/
// int8 LHS 8x4 block
// /---------------------\ /-----------------------------------------|
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
// | ... ... | | ... ... |
// |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
// |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
// | ... ... | | ... ... |
// |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// int32 accumulators 8x8 block
//
// In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
// is repeated 4 times, using 4x more registers for LHS and RHS; that is
// why, instead of using only v0 -- v3 for LHS and RHS, we use v0 -- v15 there.
//
// Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
// unused, and v8 -- v15 are used for loading parameters used for the
// post-accumulation part of the kernel.
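// Each sdot instruction below accumulates, into each 32-bit lane i of its
// destination, the dot product of bytes 4*i .. 4*i+3 of its first source
// with the 4 bytes selected by the lane index on its second source; so one
// sdot updates a 4-row x 1-column sliver of the accumulators for 4 levels
// of depth at once.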
asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
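        // The .word directives here and below emit the fixed 32-bit encodings
        // of the sdot instructions shown in the trailing comments, so this
        // file assembles even with toolchains that do not recognize the
        // dotprod extension; they behave exactly as the commented sdot
        // mnemonics.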
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Optional, maximally-streaming, partial-unrolling (4x unrolled)
// optimization of the kernel inner loop (over depth). For more
// comments, see the non-unrolled loop below after the #endif.
#if RUY_OPT(MAX_STREAMING)
"cmp w12, #32\n"
"blt 78f\n"
"ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
"mov w1, #16\n"
"and w3, w12, #-16\n"
"81:\n"
"add w1, w1, #16\n"
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ldr q0, [%[lhs_ptr], #0]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ldr q2, [%[rhs_ptr], #0]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ldr q1, [%[lhs_ptr], #16]\n"
".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
"ldr q3, [%[rhs_ptr], #16]\n"