// blob: ba4f341820e836ebc0cf30b6ab7559f46be13763 [file] [log] [blame]
// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
#include <cstddef>
#include <tuple>

#include "../internal/kernel_default.h"
#include "../public/map.h"
#include "../public/output_stages.h"
#include "multi_thread_gemm.h"
namespace gemmlowp {
// Primary TransposeImpl template: the fallback used for every type that has
// no dedicated specialization. Such a type is treated as unaffected by
// transposition, so DstType is T itself and Run() returns a copy of its
// argument.
template <typename T>
struct TransposeImpl {
  using DstType = T;
  static DstType Run(const T& value) { return value; }
};
template <typename T>
using TransposeType = typename TransposeImpl<T>::DstType;
template <typename T>
TransposeType<T> Transpose(const T& t) {
return TransposeImpl<T>::Run(t);
}
// Compile-time mapping from a MapOrder to the opposite storage order:
// RowMajor yields ColMajor, and any other value yields RowMajor.
template <MapOrder Order>
struct TransposeMapOrder {
  static constexpr MapOrder Value =
      (Order != MapOrder::RowMajor) ? MapOrder::RowMajor : MapOrder::ColMajor;
};
// Compile-time mapping from a VectorShape to the opposite shape:
// Row yields Col, and any other value yields Row.
template <VectorShape Shape>
struct TransposeVectorShape {
  static constexpr VectorShape Value =
      (Shape != VectorShape::Row) ? VectorShape::Row : VectorShape::Col;
};
// Transposition of a VectorMap: the result has the opposite VectorShape and
// is constructed from the source's data pointer and size.
template <typename Scalar, VectorShape Shape>
struct TransposeImpl<VectorMap<Scalar, Shape>> {
  using SrcType = VectorMap<Scalar, Shape>;
  static constexpr VectorShape TransposedShape =
      TransposeVectorShape<Shape>::Value;
  using DstType = VectorMap<Scalar, TransposedShape>;

  static DstType Run(const SrcType& src) {
    DstType transposed(src.data(), src.size());
    return transposed;
  }
};
// Transposition of a MatrixMap: the result has the opposite MapOrder and is
// constructed from the source's data pointer and stride, with the row and
// column counts swapped.
template <typename Scalar, MapOrder Order>
struct TransposeImpl<MatrixMap<Scalar, Order>> {
  using SrcType = MatrixMap<Scalar, Order>;
  static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
  using DstType = MatrixMap<Scalar, TransposedOrder>;

  static DstType Run(const SrcType& src) {
    // The source's column count becomes the row count of the result, and
    // vice versa.
    DstType transposed(src.data(), src.cols(), src.rows(), src.stride());
    return transposed;
  }
};
// Transposition of an OutputStageQuantizeDownInt32ToUint8ScalePC stage: the
// result is the same stage instantiated for the opposite VectorShape.
// result_offset and result_mult_int are transposed; result_shift is copied
// unchanged.
template <VectorShape Shape>
struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
  using SrcType = OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>;
  // constexpr, matching the VectorMap / MatrixMap specializations of
  // TransposeImpl.
  static constexpr VectorShape TransposedShape =
      TransposeVectorShape<Shape>::Value;
  using DstType = OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape>;

  static DstType Run(const SrcType& src) {
    DstType dst;
    dst.result_shift = src.result_shift;
    dst.result_offset = Transpose(src.result_offset);
    dst.result_mult_int = Transpose(src.result_mult_int);
    return dst;
  }
};
// Transposition of an OutputStageScaleInt32ByFixedPointAndExponentPC stage:
// the result is the same stage instantiated for the opposite VectorShape.
// result_fixedpoint_multiplier and result_exponent are transposed;
// result_offset_after_shift is copied unchanged.
template <VectorShape Shape>
struct TransposeImpl<OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>> {
  using SrcType = OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>;
  // constexpr, matching the VectorMap / MatrixMap specializations of
  // TransposeImpl.
  static constexpr VectorShape TransposedShape =
      TransposeVectorShape<Shape>::Value;
  using DstType =
      OutputStageScaleInt32ByFixedPointAndExponentPC<TransposedShape>;

  static DstType Run(const SrcType& src) {
    DstType dst;
    dst.result_fixedpoint_multiplier =
        Transpose(src.result_fixedpoint_multiplier);
    dst.result_exponent = Transpose(src.result_exponent);
    dst.result_offset_after_shift = src.result_offset_after_shift;
    return dst;
  }
};
// Transposition of an OutputStageBiasAddition stage: the result is the same
// stage parameterized on the transposed vector-map type, holding the
// transposed bias_vector.
template <typename VectorMapType>
struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
  using SrcType = OutputStageBiasAddition<VectorMapType>;
  using TransposedVectorMapType = TransposeType<VectorMapType>;
  using DstType = OutputStageBiasAddition<TransposedVectorMapType>;

  static DstType Run(const SrcType& src) {
    DstType transposed_stage;
    transposed_stage.bias_vector = Transpose(src.bias_vector);
    return transposed_stage;
  }
};
// TODO(benoitjacob) - does anyone understand C++ variadic templates?
// How to use them to implement TransposeTuple? Note: there are lots
// of answers on StackOverflow but they seem to all involve either
// C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
template <typename T0>
std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
return std::make_tuple(Transpose(std::get<0>(t)));
}
template <typename T0, typename T1>
std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
const std::tuple<T0, T1>& t) {
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
}
template <typename T0, typename T1, typename T2>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
TransposeTuple(const std::tuple<T0, T1, T2>& t) {
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
Transpose(std::get<2>(t)));
}
template <typename T0, typename T1, typename T2, typename T3>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
TransposeType<T3>>
TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
}
template <typename T0, typename T1, typename T2, typename T3, typename T4>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
TransposeType<T3>, TransposeType<T4>>
TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
Transpose(std::get<4>(t)));
}
template <typename T0, typename T1, typename T2, typename T3, typename T4,
typename T5>
std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
}
// DispatchGemmShape: dispatches a GEMM call according to the shape of its
// result.
//
// - If the result has zero rows or zero columns, or the depth (lhs.cols()) is
//   zero, this returns without calling into the GEMM implementation.
// - If the result has fewer rows than columns, the whole problem is
//   transposed (lhs and rhs swap roles; both operands, the result map, both
//   offsets and every output pipeline stage are transposed) and dispatched
//   again.
// - Otherwise the call is forwarded to MultiThreadGemm with the
//   DefaultKernel selected for BitDepthParams.
template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
          MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
          typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
          typename GemmContextType>
void DispatchGemmShape(GemmContextType* context,
                       const MatrixMap<const InputScalar, LhsOrder>& lhs,
                       const MatrixMap<const InputScalar, RhsOrder>& rhs,
                       MatrixMap<OutputScalar, ResultOrder>* result,
                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
                       const OutputPipelineType& output_pipeline) {
  assert(lhs.cols() == rhs.rows());

  const int result_rows = result->rows();
  const int result_cols = result->cols();
  const int accum_depth = lhs.cols();

  // Vacuous GEMM: return early to avoid having to deal with zero sizes below.
  const bool is_vacuous =
      result_rows == 0 || result_cols == 0 || accum_depth == 0;
  if (is_vacuous) {
    return;
  }

  if (result_rows >= result_cols) {
    // The result already has at least as many rows as columns: run the GEMM
    // as given.
    using Kernel = DefaultKernel<BitDepthParams>;
    MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
                    BitDepthParams>(context, Kernel(), lhs, rhs, result,
                                    lhs_offset, rhs_offset, output_pipeline);
    return;
  }

  // Fewer rows than columns: dispatch the transposed problem instead. The
  // transposed result then has more rows than columns, so this recursion
  // happens at most once.
  auto transposed_result_map = Transpose(*result);
  DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
      context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
      Transpose(rhs_offset), Transpose(lhs_offset),
      TransposeTuple(output_pipeline));
}
} // end namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_