/*
 * Copyright 2016 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Optimize combinations of instructions
//

#include <algorithm>

#include <wasm.h>
#include <pass.h>
#include <wasm-s-parser.h>
#include <support/threads.h>
#include <ast_utils.h>

namespace wasm {

Name I32_EXPR  = "i32.expr",
     I64_EXPR  = "i64.expr",
     F32_EXPR  = "f32.expr",
     F64_EXPR  = "f64.expr",
     ANY_EXPR  = "any.expr";

// A pattern
struct Pattern {
  Expression* input;
  Expression* output;

  Pattern(Expression* input, Expression* output) : input(input), output(output) {}
};

// Database of patterns
struct PatternDatabase {
  Module wasm;

  char* input;

  std::map<Expression::Id, std::vector<Pattern>> patternMap; // root expression id => list of all patterns for it TODO optimize more

  PatternDatabase() {
    // generate module
    input = strdup(
      #include "OptimizeInstructions.wast.processed"
    );
    try {
      SExpressionParser parser(input);
      Element& root = *parser.root;
      SExpressionWasmBuilder builder(wasm, *root[0]);
      // parse module form
      auto* func = wasm.getFunction("patterns");
      auto* body = func->body->cast<Block>();
      for (auto* item : body->list) {
        auto* pair = item->cast<Block>();
        patternMap[pair->list[0]->_id].emplace_back(pair->list[0], pair->list[1]);
      }
    } catch (ParseException& p) {
      p.dump(std::cerr);
      Fatal() << "error in parsing wasm binary";
    }
  }

  ~PatternDatabase() {
    free(input);
  };
};

static PatternDatabase* database = nullptr;

struct DatabaseEnsurer {
  DatabaseEnsurer() {
    assert(!database);
    database = new PatternDatabase;
  }
};

// Check for matches and apply them
struct Match {
  Module& wasm;
  Pattern& pattern;

  Match(Module& wasm, Pattern& pattern) : wasm(wasm), pattern(pattern) {}

  std::vector<Expression*> wildcards; // id in i32.any(id) etc. => the expression it represents in this match

  // Comparing/checking

  // Check if we can match to this pattern, updating ourselves with the info if so
  bool check(Expression* seen) {
    // compare seen to the pattern input, doing a special operation for our "wildcards"
    assert(wildcards.size() == 0);
    return ExpressionAnalyzer::flexibleEqual(pattern.input, seen, *this);
  }

  bool compare(Expression* subInput, Expression* subSeen) {
    CallImport* call = subInput->dynCast<CallImport>();
    if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return false;
    Index index = call->operands[0]->cast<Const>()->value.geti32();
    // handle our special functions
    auto checkMatch = [&](WasmType type) {
      if (type != none && subSeen->type != type) return false;
      while (index >= wildcards.size()) {
        wildcards.push_back(nullptr);
      }
      if (!wildcards[index]) {
        // new wildcard
        wildcards[index] = subSeen; // NB: no need to copy
        return true;
      } else {
        // We are seeing this index for a second or later time, check it matches
        return ExpressionAnalyzer::equal(subSeen, wildcards[index]);
      };
    };
    if (call->target == I32_EXPR) {
      if (checkMatch(i32)) return true;
    } else if (call->target == I64_EXPR) {
      if (checkMatch(i64)) return true;
    } else if (call->target == F32_EXPR) {
      if (checkMatch(f32)) return true;
    } else if (call->target == F64_EXPR) {
      if (checkMatch(f64)) return true;
    } else if (call->target == ANY_EXPR) {
      if (checkMatch(none)) return true;
    }
    return false;
  }

  // Applying/copying

  // Apply the match, generate an output expression from the matched input, performing substitutions as necessary
  Expression* apply() {
    return ExpressionManipulator::flexibleCopy(pattern.output, wasm, *this);
  }

  // When copying a wildcard, perform the substitution.
  // TODO: we can reuse nodes, not copying a wildcard when it appears just once, and we can reuse other individual nodes when they are discarded anyhow.
  Expression* copy(Expression* curr) {
    CallImport* call = curr->dynCast<CallImport>();
    if (!call || call->operands.size() != 1 || call->operands[0]->type != i32 || !call->operands[0]->is<Const>()) return nullptr;
    Index index = call->operands[0]->cast<Const>()->value.geti32();
    // handle our special functions
    if (call->target == I32_EXPR || call->target == I64_EXPR || call->target == F32_EXPR || call->target == F64_EXPR || call->target == ANY_EXPR) {
      return ExpressionManipulator::copy(wildcards.at(index), wasm);
    }
    return nullptr;
  }
};

// Main pass class
struct OptimizeInstructions : public WalkerPass<PostWalker<OptimizeInstructions, UnifiedExpressionVisitor<OptimizeInstructions>>> {
  bool isFunctionParallel() override { return true; }

  Pass* create() override { return new OptimizeInstructions; }

  void prepareToRun(PassRunner* runner, Module* module) override {
    static DatabaseEnsurer ensurer;
  }

  void visitExpression(Expression* curr) {
    // we may be able to apply multiple patterns, one may open opportunities that look deeper NB: patterns must not have cycles
    while (1) {
      auto* handOptimized = handOptimize(curr);
      if (handOptimized) {
        curr = handOptimized;
        replaceCurrent(curr);
        continue;
      }
      auto iter = database->patternMap.find(curr->_id);
      if (iter == database->patternMap.end()) return;
      auto& patterns = iter->second;
      bool more = false;
      for (auto& pattern : patterns) {
        Match match(*getModule(), pattern);
        if (match.check(curr)) {
          curr = match.apply();
          replaceCurrent(curr);
          more = true;
          break; // exit pattern for loop, return to main while loop
        }
      }
      if (!more) break;
    }
  }

  // Optimizations that don't yet fit in the pattern DSL, but could be eventually maybe
  Expression* handOptimize(Expression* curr) {
    if (auto* binary = curr->dynCast<Binary>()) {
      // pattern match a load of 8 bits and a sign extend using a shl of 24 then shr_s of 24 as well, etc.
      if (binary->op == BinaryOp::ShrSInt32 && binary->right->is<Const>()) {
        auto shifts = binary->right->cast<Const>()->value.geti32();
        if (shifts == 24 || shifts == 16) {
          auto* left = binary->left->dynCast<Binary>();
          if (left && left->op == ShlInt32 && left->right->is<Const>() && left->right->cast<Const>()->value.geti32() == shifts) {
            auto* load = left->left->dynCast<Load>();
            if (load && ((load->bytes == 1 && shifts == 24) || (load->bytes == 2 && shifts == 16))) {
              load->signed_ = true;
              return load;
            }
          }
        }
      } else if (binary->op == EqInt32) {
        if (auto* c = binary->right->dynCast<Const>()) {
          if (c->value.geti32() == 0) {
            // equal 0 => eqz
            return Builder(*getModule()).makeUnary(EqZInt32, binary->left);
          }
        }
        if (auto* c = binary->left->dynCast<Const>()) {
          if (c->value.geti32() == 0) {
            // equal 0 => eqz
            return Builder(*getModule()).makeUnary(EqZInt32, binary->right);
          }
        }
      } else if (binary->op == AndInt32) {
        if (auto* right = binary->right->dynCast<Const>()) {
          if (right->type == i32) {
            auto mask = right->value.geti32();
            // and with -1 does nothing (common in asm.js output)
            if (mask == -1) {
              return binary->left;
            }
            // small loads do not need to be masted, the load itself masks
            if (auto* load = binary->left->dynCast<Load>()) {
              if ((load->bytes == 1 && mask == 0xff) ||
                  (load->bytes == 2 && mask == 0xffff)) {
                load->signed_ = false;
                return load;
              }
            }
          }
        }
      }
    } else if (auto* unary = curr->dynCast<Unary>()) {
      // de-morgan's laws
      if (unary->op == EqZInt32) {
        if (auto* inner = unary->value->dynCast<Binary>()) {
          switch (inner->op) {
            case EqInt32:  inner->op = NeInt32;  return inner;
            case NeInt32:  inner->op = EqInt32;  return inner;
            case LtSInt32: inner->op = GeSInt32; return inner;
            case LtUInt32: inner->op = GeUInt32; return inner;
            case LeSInt32: inner->op = GtSInt32; return inner;
            case LeUInt32: inner->op = GtUInt32; return inner;
            case GtSInt32: inner->op = LeSInt32; return inner;
            case GtUInt32: inner->op = LeUInt32; return inner;
            case GeSInt32: inner->op = LtSInt32; return inner;
            case GeUInt32: inner->op = LtUInt32; return inner;

            case EqInt64:  inner->op = NeInt64;  return inner;
            case NeInt64:  inner->op = EqInt64;  return inner;
            case LtSInt64: inner->op = GeSInt64; return inner;
            case LtUInt64: inner->op = GeUInt64; return inner;
            case LeSInt64: inner->op = GtSInt64; return inner;
            case LeUInt64: inner->op = GtUInt64; return inner;
            case GtSInt64: inner->op = LeSInt64; return inner;
            case GtUInt64: inner->op = LeUInt64; return inner;
            case GeSInt64: inner->op = LtSInt64; return inner;
            case GeUInt64: inner->op = LtUInt64; return inner;

            case EqFloat32: inner->op = NeFloat32; return inner;
            case NeFloat32: inner->op = EqFloat32; return inner;

            case EqFloat64: inner->op = NeFloat64; return inner;
            case NeFloat64: inner->op = EqFloat64; return inner;

            default: {}
          }
        }
      }
    } else if (auto* set = curr->dynCast<SetGlobal>()) {
      // optimize out a set of a get
      auto* get = set->value->dynCast<GetGlobal>();
      if (get && get->name == set->name) {
        ExpressionManipulator::nop(curr);
      }
    } else if (auto* iff = curr->dynCast<If>()) {
      iff->condition = optimizeBoolean(iff->condition);
      if (iff->ifFalse) {
        if (auto* unary = iff->condition->dynCast<Unary>()) {
          if (unary->op == EqZInt32) {
            // flip if-else arms to get rid of an eqz
            iff->condition = unary->value;
            std::swap(iff->ifTrue, iff->ifFalse);
          }
        }
      }
    } else if (auto* select = curr->dynCast<Select>()) {
      select->condition = optimizeBoolean(select->condition);
      auto* condition = select->condition->dynCast<Unary>();
      if (condition && condition->op == EqZInt32) {
        // flip select to remove eqz, if we can reorder
        EffectAnalyzer ifTrue(select->ifTrue);
        EffectAnalyzer ifFalse(select->ifFalse);
        if (!ifTrue.invalidates(ifFalse)) {
          select->condition = condition->value;
          std::swap(select->ifTrue, select->ifFalse);
        }
      }
    } else if (auto* br = curr->dynCast<Break>()) {
      if (br->condition) {
        br->condition = optimizeBoolean(br->condition);
      }
    } else if (auto* store = curr->dynCast<Store>()) {
      // stores of fewer bits truncates anyhow
      if (auto* value = store->value->dynCast<Binary>()) {
        if (value->op == AndInt32) {
          if (auto* right = value->right->dynCast<Const>()) {
            if (right->type == i32) {
              auto mask = right->value.geti32();
              if ((store->bytes == 1 && mask == 0xff) ||
                  (store->bytes == 2 && mask == 0xffff)) {
                store->value = value->left;
              }
            }
          }
        }
      }
    }
    return nullptr;
  }

private:

  Expression* optimizeBoolean(Expression* boolean) {
    auto* condition = boolean->dynCast<Unary>();
    if (condition && condition->op == EqZInt32) {
      auto* condition2 = condition->value->dynCast<Unary>();
      if (condition2 && condition2->op == EqZInt32) {
        // double eqz
        return condition2->value;
      }
    }
    return boolean;
  }
};

Pass *createOptimizeInstructionsPass() {
  return new OptimizeInstructions();
}

} // namespace wasm
