Merge pull request #2416 from eugeneo:unroller-const-input
PiperOrigin-RevId: 707971695
diff --git a/hwy/contrib/unroller/unroller-inl.h b/hwy/contrib/unroller/unroller-inl.h
index 4ed8c25..e5d9661 100644
--- a/hwy/contrib/unroller/unroller-inl.h
+++ b/hwy/contrib/unroller/unroller-inl.h
@@ -63,11 +63,11 @@
Y_VEC YInitImpl() { return hn::Zero(d_out); }
- X_VEC Load(const ptrdiff_t idx, IN_T* from) {
+ X_VEC Load(const ptrdiff_t idx, const IN_T* from) {
return me()->LoadImpl(idx, from);
}
- X_VEC LoadImpl(const ptrdiff_t idx, IN_T* from) {
+ X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) {
return hn::LoadU(d_in, from + idx);
}
@@ -77,11 +77,13 @@
// | o | o | o | x | x | x | x | x |
// example places = -3
// | x | x | x | x | x | o | o | o |
- X_VEC MaskLoad(const ptrdiff_t idx, IN_T* from, const ptrdiff_t places) {
+ X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from,
+ const ptrdiff_t places) {
return me()->MaskLoadImpl(idx, from, places);
}
- X_VEC MaskLoadImpl(const ptrdiff_t idx, IN_T* from, const ptrdiff_t places) {
+ X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from,
+ const ptrdiff_t places) {
auto mask = hn::FirstN(d_in, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
d_in,
@@ -181,19 +183,19 @@
Y_VEC YInitImpl() { return hn::Zero(d_out); }
- X0_VEC Load0(const ptrdiff_t idx, IN0_T* from) {
+ X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) {
return me()->Load0Impl(idx, from);
}
- X0_VEC Load0Impl(const ptrdiff_t idx, IN0_T* from) {
+ X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) {
return hn::LoadU(d_in0, from + idx);
}
- X1_VEC Load1(const ptrdiff_t idx, IN1_T* from) {
+ X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) {
return me()->Load1Impl(idx, from);
}
- X1_VEC Load1Impl(const ptrdiff_t idx, IN1_T* from) {
+ X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) {
return hn::LoadU(d_in1, from + idx);
}
@@ -203,11 +205,12 @@
// | o | o | o | x | x | x | x | x |
// example places = -3
// | x | x | x | x | x | o | o | o |
- X0_VEC MaskLoad0(const ptrdiff_t idx, IN0_T* from, const ptrdiff_t places) {
+ X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from,
+ const ptrdiff_t places) {
return me()->MaskLoad0Impl(idx, from, places);
}
- X0_VEC MaskLoad0Impl(const ptrdiff_t idx, IN0_T* from,
+ X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_in0, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
@@ -218,12 +221,12 @@
return hn::MaskedLoad(mask, d_in0, from + idx);
}
- hn::Vec<I1T> MaskLoad1(const ptrdiff_t idx, IN1_T* from,
+ hn::Vec<I1T> MaskLoad1(const ptrdiff_t idx, const IN1_T* from,
const ptrdiff_t places) {
return me()->MaskLoad1Impl(idx, from, places);
}
- hn::Vec<I1T> MaskLoad1Impl(const ptrdiff_t idx, IN1_T* from,
+ hn::Vec<I1T> MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d_in1, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
@@ -284,7 +287,7 @@
};
template <class FUNC, typename IN_T, typename OUT_T>
-inline void Unroller(FUNC& f, IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y,
+inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y,
const ptrdiff_t n) {
auto xx = f.X0Init();
auto yy = f.YInit();
diff --git a/hwy/contrib/unroller/unroller_test.cc b/hwy/contrib/unroller/unroller_test.cc
index 7a13825..50f2671 100644
--- a/hwy/contrib/unroller/unroller_test.cc
+++ b/hwy/contrib/unroller/unroller_test.cc
@@ -148,7 +148,7 @@
hn::Vec<DI> YInitImpl() { return hn::Set(di, TI{-1}); }
- hn::Vec<D> MaskLoadImpl(const ptrdiff_t idx, T* from,
+ hn::Vec<D> MaskLoadImpl(const ptrdiff_t idx, const T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
@@ -236,7 +236,7 @@
hn::Vec<TT> YInitImpl() { return hn::Set(d, HighestValue<T>()); }
- hn::Vec<TT> MaskLoadImpl(const ptrdiff_t idx, T* from,
+ hn::Vec<TT> MaskLoadImpl(const ptrdiff_t idx, const T* from,
const ptrdiff_t places) {
auto mask = hn::FirstN(d, static_cast<size_t>(places));
auto maskneg = hn::Not(hn::FirstN(
@@ -452,7 +452,9 @@
FindUnit<T> cvtfn(ConvertScalarTo<T>(num - 1));
MakeSigned<T> idx = 0;
- Unroller(cvtfn, a, &idx, static_cast<ptrdiff_t>(num));
+ // Explicitly test that the input pointer can be const.
+ const T* const_a = a;
+ Unroller(cvtfn, const_a, &idx, static_cast<ptrdiff_t>(num));
HWY_ASSERT(static_cast<MakeUnsigned<T>>(idx) < num);
HWY_ASSERT(a[idx] == ConvertScalarTo<T>(num - 1));