[utils] Add IsInBounds(index, size, max) helper

This CL adds a helper function that simplifies a bounds check pattern
that appears repeatedly in the code.

R=clemensh@chromium.org

Change-Id: I8c617515b34eb2d262d58a239a29c1515de2d92d
Reviewed-on: https://chromium-review.googlesource.com/c/1417611
Commit-Queue: Ben Titzer <titzer@chromium.org>
Reviewed-by: Clemens Hammacher <clemensh@chromium.org>
Cr-Commit-Position: refs/heads/master@{#58892}
diff --git a/src/compiler/wasm-compiler.cc b/src/compiler/wasm-compiler.cc
index f35064b..c687346 100644
--- a/src/compiler/wasm-compiler.cc
+++ b/src/compiler/wasm-compiler.cc
@@ -3313,9 +3313,7 @@
     return index;
   }
 
-  const bool statically_oob = access_size > env_->max_memory_size ||
-                              offset > env_->max_memory_size - access_size;
-  if (statically_oob) {
+  if (!IsInBounds(offset, access_size, env_->max_memory_size)) {
     // The access will be out of bounds, even for the largest memory.
     TrapIfEq32(wasm::kTrapMemOutOfBounds, Int32Constant(0), 0, position);
     return mcgraph()->IntPtrConstant(0);
diff --git a/src/utils.h b/src/utils.h
index 21b6d76..74bbec5 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -73,6 +73,12 @@
                                  static_cast<unsigned_T>(lower_limit));
 }
 
+// Checks if [index, index+length) is in range [0, max). Note that this check
+// works even if {index+length} would wrap around.
+inline constexpr bool IsInBounds(size_t index, size_t length, size_t max) {
+  return length <= max && index <= (max - length);
+}
+
 // X must be a power of 2.  Returns the number of trailing zeros.
 template <typename T,
           typename = typename std::enable_if<std::is_integral<T>::value>::type>
diff --git a/src/wasm/baseline/liftoff-compiler.cc b/src/wasm/baseline/liftoff-compiler.cc
index 87924e2..7bf358c 100644
--- a/src/wasm/baseline/liftoff-compiler.cc
+++ b/src/wasm/baseline/liftoff-compiler.cc
@@ -14,6 +14,7 @@
 #include "src/macro-assembler-inl.h"
 #include "src/objects/smi.h"
 #include "src/tracing/trace-event.h"
+#include "src/utils.h"
 #include "src/wasm/baseline/liftoff-assembler.h"
 #include "src/wasm/function-body-decoder-impl.h"
 #include "src/wasm/function-compiler.h"
@@ -1393,8 +1394,8 @@
   // (a jump to the trap was generated then); return false otherwise.
   bool BoundsCheckMem(FullDecoder* decoder, uint32_t access_size,
                       uint32_t offset, Register index, LiftoffRegList pinned) {
-    const bool statically_oob = access_size > env_->max_memory_size ||
-                                offset > env_->max_memory_size - access_size;
+    const bool statically_oob =
+        !IsInBounds(offset, access_size, env_->max_memory_size);
 
     if (!statically_oob &&
         (FLAG_wasm_no_bounds_checks || env_->use_trap_handler)) {
diff --git a/src/wasm/module-instantiate.cc b/src/wasm/module-instantiate.cc
index 7b69f31..04c0f3c 100644
--- a/src/wasm/module-instantiate.cc
+++ b/src/wasm/module-instantiate.cc
@@ -5,6 +5,7 @@
 #include "src/wasm/module-instantiate.h"
 #include "src/asmjs/asm-js.h"
 #include "src/property-descriptor.h"
+#include "src/utils.h"
 #include "src/wasm/js-to-wasm-wrapper-cache-inl.h"
 #include "src/wasm/module-compiler.h"
 #include "src/wasm/wasm-import-wrapper-cache-inl.h"
@@ -24,10 +25,6 @@
 byte* raw_buffer_ptr(MaybeHandle<JSArrayBuffer> buffer, int offset) {
   return static_cast<byte*>(buffer.ToHandleChecked()->backing_store()) + offset;
 }
-bool in_bounds(uint32_t offset, size_t size, size_t upper) {
-  return offset + size <= upper && offset + size >= offset;
-}
-
 }  // namespace
 
 // A helper class to simplify instantiating a module from a module object.
@@ -432,7 +429,7 @@
     DCHECK(elem_segment.table_index < table_instances_.size());
     uint32_t base = EvalUint32InitExpr(elem_segment.offset);
     size_t table_size = table_instances_[elem_segment.table_index].table_size;
-    if (!in_bounds(base, elem_segment.entries.size(), table_size)) {
+    if (!IsInBounds(base, elem_segment.entries.size(), table_size)) {
       thrower_->LinkError("table initializer is out of bounds");
       return {};
     }
@@ -444,7 +441,7 @@
   for (const WasmDataSegment& seg : module_->data_segments) {
     if (!seg.active) continue;
     uint32_t base = EvalUint32InitExpr(seg.dest_addr);
-    if (!in_bounds(base, seg.source.length(), instance->memory_size())) {
+    if (!IsInBounds(base, seg.source.length(), instance->memory_size())) {
       thrower_->LinkError("data segment is out of bounds");
       return {};
     }
@@ -623,7 +620,7 @@
     // Passive segments are not copied during instantiation.
     if (!segment.active) continue;
     uint32_t dest_offset = EvalUint32InitExpr(segment.dest_addr);
-    DCHECK(in_bounds(dest_offset, source_size, instance->memory_size()));
+    DCHECK(IsInBounds(dest_offset, source_size, instance->memory_size()));
     byte* dest = instance->memory_start() + dest_offset;
     const byte* src = wire_bytes.start() + segment.source.offset();
     memcpy(dest, src, source_size);
@@ -1464,7 +1461,7 @@
     uint32_t num_entries = static_cast<uint32_t>(elem_segment.entries.size());
     uint32_t index = elem_segment.table_index;
     TableInstance& table_instance = table_instances_[index];
-    DCHECK(in_bounds(base, num_entries, table_instance.table_size));
+    DCHECK(IsInBounds(base, num_entries, table_instance.table_size));
     for (uint32_t i = 0; i < num_entries; ++i) {
       uint32_t func_index = elem_segment.entries[i];
       const WasmFunction* function = &module_->functions[func_index];
diff --git a/src/wasm/wasm-interpreter.cc b/src/wasm/wasm-interpreter.cc
index cf58482..8e75ad2 100644
--- a/src/wasm/wasm-interpreter.cc
+++ b/src/wasm/wasm-interpreter.cc
@@ -1402,14 +1402,18 @@
 
   template <typename mtype>
   inline Address BoundsCheckMem(uint32_t offset, uint32_t index) {
-    size_t mem_size = instance_object_->memory_size();
-    if (sizeof(mtype) > mem_size) return kNullAddress;
-    if (offset > (mem_size - sizeof(mtype))) return kNullAddress;
-    if (index > (mem_size - sizeof(mtype) - offset)) return kNullAddress;
+    uint32_t effective_index = offset + index;
+    if (effective_index < index) {
+      return kNullAddress;  // wraparound => oob
+    }
+    if (!IsInBounds(effective_index, sizeof(mtype),
+                    instance_object_->memory_size())) {
+      return kNullAddress;  // oob
+    }
     // Compute the effective address of the access, making sure to condition
     // the index even in the in-bounds case.
     return reinterpret_cast<Address>(instance_object_->memory_start()) +
-           offset + (index & instance_object_->memory_mask());
+           (effective_index & instance_object_->memory_mask());
   }
 
   template <typename ctype, typename mtype>
diff --git a/src/wasm/wasm-objects.cc b/src/wasm/wasm-objects.cc
index 00342d8..77d370f 100644
--- a/src/wasm/wasm-objects.cc
+++ b/src/wasm/wasm-objects.cc
@@ -1415,6 +1415,7 @@
 namespace {
 void CopyTableEntriesImpl(Handle<WasmInstanceObject> instance, uint32_t dst,
                           uint32_t src, uint32_t count) {
+  DCHECK(IsInBounds(dst, count, instance->indirect_function_table_size()));
   if (src < dst) {
     for (uint32_t i = count; i > 0; i--) {
       auto to_entry = IndirectFunctionTableEntry(instance, dst + i - 1);
@@ -1439,8 +1440,8 @@
   CHECK_EQ(0, table_index);  // TODO(titzer): multiple tables in TableCopy
   if (count == 0) return true;  // no-op
   auto max = instance->indirect_function_table_size();
-  if (dst > max || count > (max - dst)) return false;  // out-of-bounds
-  if (src > max || count > (max - src)) return false;  // out-of-bounds
+  if (!IsInBounds(dst, count, max)) return false;
+  if (!IsInBounds(src, count, max)) return false;
   if (dst == src) return true;                         // no-op
 
   if (!instance->has_table_object()) {
diff --git a/test/unittests/utils-unittest.cc b/test/unittests/utils-unittest.cc
index 0a37e84..c8032d1 100644
--- a/test/unittests/utils-unittest.cc
+++ b/test/unittests/utils-unittest.cc
@@ -132,5 +132,73 @@
   EXPECT_FALSE(PassesFilter(CStrVector(""), CStrVector("a")));
 }
 
+TEST(UtilsTest, IsInBounds) {
+// for column consistency and terseness
+#define INB(x, y, z) EXPECT_TRUE(IsInBounds(x, y, z))
+#define OOB(x, y, z) EXPECT_FALSE(IsInBounds(x, y, z))
+  INB(0, 0, 1);
+  INB(0, 1, 1);
+  INB(1, 0, 1);
+
+  OOB(0, 2, 1);
+  OOB(2, 0, 1);
+
+  INB(0, 0, 2);
+  INB(0, 1, 2);
+  INB(0, 2, 2);
+
+  INB(0, 0, 2);
+  INB(1, 0, 2);
+  INB(2, 0, 2);
+
+  OOB(0, 3, 2);
+  OOB(3, 0, 2);
+
+  INB(0, 1, 2);
+  INB(1, 1, 2);
+
+  OOB(1, 2, 2);
+  OOB(2, 1, 2);
+
+  const size_t max = std::numeric_limits<size_t>::max();
+  const size_t half = max / 2;
+
+  // limit cases.
+  INB(0, 0, max);
+  INB(0, 1, max);
+  INB(1, 0, max);
+  INB(max, 0, max);
+  INB(0, max, max);
+  INB(max - 1, 0, max);
+  INB(0, max - 1, max);
+  INB(max - 1, 1, max);
+  INB(1, max - 1, max);
+
+  INB(half, half, max);
+  INB(half + 1, half, max);
+  INB(half, half + 1, max);
+
+  OOB(max, 0, 0);
+  OOB(0, max, 0);
+  OOB(max, 0, 1);
+  OOB(0, max, 1);
+  OOB(max, 0, 2);
+  OOB(0, max, 2);
+
+  OOB(max, 0, max - 1);
+  OOB(0, max, max - 1);
+
+  // wraparound cases.
+  OOB(max, 1, max);
+  OOB(1, max, max);
+  OOB(max - 1, 2, max);
+  OOB(2, max - 1, max);
+  OOB(half + 1, half + 1, max);
+  OOB(half + 1, half + 1, max);
+
+#undef INB
+#undef OOB
+}
+
 }  // namespace internal
 }  // namespace v8