Look at local sets with casts in OptimizeCasts.

For instance, the following code:
(local.set $a (ref.cast .. (local.get $ref)))
(local.get $a)

can be converted to:
(local.set $a (local.tee $temp (ref.cast .. (local.get $ref))))
(local.get $temp)

Currently, the optimization is applied only when the immediate child of
the local.set is a cast; it does not yet look through other expressions
(such as a block whose fallthrough value is a cast) to find one. Handling
those cases is left as a TODO in the pass.
diff --git a/src/passes/OptimizeCasts.cpp b/src/passes/OptimizeCasts.cpp
index 8541101..b7970cc 100644
--- a/src/passes/OptimizeCasts.cpp
+++ b/src/passes/OptimizeCasts.cpp
@@ -75,10 +75,8 @@
 // TODO: Move casts earlier in a basic block as well, at least in traps-never-
 //       happen mode where we can assume they never fail.
 // TODO: Look past individual basic blocks?
-// TODO: Look at LocalSet as well and not just Get. That would add some overlap
-//       with the other passes mentioned above, but once we do things like
-//       moving casts earlier as in the other TODO, we'd be doing uniquely
-//       useful things with LocalSet here.
+// TODO: When looking at local.sets, check fallthrough values/descendants for
+//       casts instead of just the immediate child.
 //
 
 #include "ir/linear-execution.h"
@@ -104,8 +102,13 @@
   // This is tracked in each basic block, and cleared between them.
   std::unordered_map<Index, Expression*> mostCastedGets;
 
-  // For each most-downcasted local.get, a vector of other local.gets that could
-  // be replaced with gets of the downcasted value.
+  // Map local indices to the current downcasting of local.set to those indices.
+  //
+  // Also tracked in each basic block and cleared between them.
+  std::unordered_map<Index, Expression*> curCastedSets;
+
+  // For each most-downcasted local.get or local.set, a vector of other
+  // local.gets that could be replaced with gets of the downcasted value.
   //
   // This is tracked until the end of the entire function, and contains the
   // information we need to optimize later. That is, entries here are things we
@@ -114,22 +117,50 @@
 
   static void doNoteNonLinear(BestCastFinder* self, Expression** currp) {
     self->mostCastedGets.clear();
+    self->curCastedSets.clear();
   }
 
   void visitLocalSet(LocalSet* curr) {
     // Clear any information about this local; it has a new value here.
     mostCastedGets.erase(curr->index);
+
+    // TODO: This only checks the immediate child for a cast; extend it to also
+    //       look at fallthrough values and descendants.
+    if (curr->value->dynCast<RefAs>() || curr->value->dynCast<RefCast>()) {
+      curCastedSets[curr->index] = curr->value;
+    } else {
+      // If the local.set doesn't use a cast, get rid of any old cast information
+      curCastedSets.erase(curr->index);
+    }
   }
 
   void visitLocalGet(LocalGet* curr) {
-    auto iter = mostCastedGets.find(curr->index);
-    if (iter != mostCastedGets.end()) {
-      auto* bestCast = iter->second;
+    auto getIter = mostCastedGets.find(curr->index);
+    auto setIter = curCastedSets.find(curr->index);
+
+    if (getIter != mostCastedGets.end()) {
+      auto* bestCast = getIter->second;
+      if (setIter != curCastedSets.end()) {
+        // Prefer the cast in the local.set when it is at least as refined as
+        // the best cast of a local.get: the set always executes before any
+        // get that reads the value it wrote to this index.
+        if (bestCast->type == setIter->second->type ||
+            Type::isSubType(setIter->second->type, bestCast->type)) {
+          bestCast = setIter->second;
+        }
+      }
+
       if (curr->type != bestCast->type &&
           Type::isSubType(bestCast->type, curr->type)) {
         // The best cast has a more refined type, note that we want to use it.
         lessCastedGets[bestCast].push_back(curr);
       }
+    } else if (setIter != curCastedSets.end()) {
+      auto* setCast = setIter->second;
+      if (curr->type != setCast->type &&
+          Type::isSubType(setCast->type, curr->type)) {
+        lessCastedGets[setCast].push_back(curr);
+      }
     }
   }
 
diff --git a/test/lit/passes/optimize-casts.wast b/test/lit/passes/optimize-casts.wast
index 27e38f7..26163e8 100644
--- a/test/lit/passes/optimize-casts.wast
+++ b/test/lit/passes/optimize-casts.wast
@@ -261,6 +261,16 @@
   ;; CHECK-NEXT:  (drop
   ;; CHECK-NEXT:   (local.get $1)
   ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (local.set $x
+  ;; CHECK-NEXT:   (block (result (ref $A))
+  ;; CHECK-NEXT:    (ref.cast $A
+  ;; CHECK-NEXT:     (call $get)
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $x)
+  ;; CHECK-NEXT:  )
   ;; CHECK-NEXT: )
   (func $fallthrough (param $x (ref struct))
     (drop
@@ -274,9 +284,21 @@
     (drop
       (local.get $x)
     )
+    (local.set $x
+      ;; Cannot look through for sets at the moment
+      (block (result (ref $A))
+        (ref.cast $A
+          (call $get)
+        )
+      )
+    )
+    (drop
+      (local.get $x)
+    )
   )
 
   ;; CHECK:      (func $past-basic-block (type $ref|struct|_=>_none) (param $x (ref struct))
+  ;; CHECK-NEXT:  (local $1 (ref $A))
   ;; CHECK-NEXT:  (drop
   ;; CHECK-NEXT:   (ref.cast $A
   ;; CHECK-NEXT:    (local.get $x)
@@ -289,6 +311,23 @@
   ;; CHECK-NEXT:  (drop
   ;; CHECK-NEXT:   (local.get $x)
   ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (local.set $x
+  ;; CHECK-NEXT:   (local.tee $1
+  ;; CHECK-NEXT:    (ref.cast $A
+  ;; CHECK-NEXT:     (local.get $x)
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $1)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (if
+  ;; CHECK-NEXT:   (i32.const 0)
+  ;; CHECK-NEXT:   (return)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $x)
+  ;; CHECK-NEXT:  )
   ;; CHECK-NEXT: )
   (func $past-basic-block (param $x (ref struct))
     (drop
@@ -305,6 +344,22 @@
     (drop
       (local.get $x)
     )
+    (local.set $x
+      (ref.cast $A
+        (local.get $x)
+      )
+    )
+    (drop
+      (local.get $x)
+    )
+    ;; Same behaviour for sets.
+    (if
+      (i32.const 0)
+      (return)
+    )
+    (drop
+      (local.get $x)
+    )
   )
 
   ;; CHECK:      (func $multiple (type $ref|struct|_ref|struct|_=>_none) (param $x (ref struct)) (param $y (ref struct))
@@ -387,6 +442,165 @@
     )
   )
 
+  ;; CHECK:      (func $check-set-basic (type $ref|$A|_ref?|$A|_=>_none) (param $x (ref $A)) (param $y (ref null $A))
+  ;; CHECK-NEXT:  (local $a (ref struct))
+  ;; CHECK-NEXT:  (local $b structref)
+  ;; CHECK-NEXT:  (local $4 (ref $A))
+  ;; CHECK-NEXT:  (local $5 (ref $B))
+  ;; CHECK-NEXT:  (local.set $a
+  ;; CHECK-NEXT:   (ref.as_non_null
+  ;; CHECK-NEXT:    (local.get $x)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $x)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (local.set $b
+  ;; CHECK-NEXT:   (local.tee $4
+  ;; CHECK-NEXT:    (ref.as_non_null
+  ;; CHECK-NEXT:     (local.get $y)
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $4)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.tee $a
+  ;; CHECK-NEXT:    (local.tee $5
+  ;; CHECK-NEXT:     (ref.cast $B
+  ;; CHECK-NEXT:      (local.get $x)
+  ;; CHECK-NEXT:     )
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $5)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT: )
+  (func $check-set-basic (param $x (ref $A)) (param $y (ref null $A))
+    (local $a (ref struct))
+    (local $b (ref null struct))
+    ;; Param is already non-nullable, so set won't do anything
+    (local.set $a
+      (ref.as_non_null
+        (local.get $x)
+      )
+    )
+    (drop
+      (local.get $x)
+    )
+    (local.set $b
+      (ref.as_non_null
+        (local.get $y)
+      )
+    )
+    (drop
+      (local.get $b)
+    )
+    (drop
+      (local.tee $a
+        (ref.cast $B
+          (local.get $x)
+        )
+      )
+    )
+    (drop
+      (local.get $a)
+    )
+  )
+
+  ;; CHECK:      (func $check-set-uses-most-casted (type $none_=>_none)
+  ;; CHECK-NEXT:  (local $a (ref struct))
+  ;; CHECK-NEXT:  (local $1 (ref $B))
+  ;; CHECK-NEXT:  (local $2 (ref $A))
+  ;; CHECK-NEXT:  (local $3 (ref $B))
+  ;; CHECK-NEXT:  (local.set $a
+  ;; CHECK-NEXT:   (local.tee $1
+  ;; CHECK-NEXT:    (ref.cast $B
+  ;; CHECK-NEXT:     (call $get)
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $1)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (ref.cast $B
+  ;; CHECK-NEXT:    (local.get $1)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (ref.cast $B
+  ;; CHECK-NEXT:    (local.get $1)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (local.set $a
+  ;; CHECK-NEXT:   (local.tee $2
+  ;; CHECK-NEXT:    (ref.cast $A
+  ;; CHECK-NEXT:     (call $get)
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.get $2)
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (local.tee $3
+  ;; CHECK-NEXT:    (ref.cast $B
+  ;; CHECK-NEXT:     (local.get $2)
+  ;; CHECK-NEXT:    )
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT:  (drop
+  ;; CHECK-NEXT:   (ref.cast $B
+  ;; CHECK-NEXT:    (local.get $3)
+  ;; CHECK-NEXT:   )
+  ;; CHECK-NEXT:  )
+  ;; CHECK-NEXT: )
+  (func $check-set-uses-most-casted
+    (local $a (ref struct))
+    (local.set $a
+      (ref.cast $B
+        (call $get)
+      )
+    )
+    (drop
+      (local.get $a)
+    )
+    (drop
+      ;; This will use the value from the cast in the above local.set
+      ;; since both casts are equally specific
+      (ref.cast $B
+        (local.get $a)
+      )
+    )
+    (drop
+      (ref.cast $A
+        (local.get $a)
+      )
+    )
+    (local.set $a
+      (ref.cast $A
+        (call $get)
+      )
+    )
+    (drop
+      (local.get $a)
+    )
+    (drop
+      ;; This cast is more specific than the one in the set, so it will be used henceforth
+      (ref.cast $B
+        (local.get $a)
+      )
+    )
+    (drop
+      (ref.cast $A
+        (local.get $a)
+      )
+    )
+  )
+
   ;; CHECK:      (func $get (type $none_=>_ref|struct|) (result (ref struct))
   ;; CHECK-NEXT:  (unreachable)
   ;; CHECK-NEXT: )