blob: 89823f14e485da95cdaef6ad63fcc4c34c5ca3e4 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#include "utils/regex-match.h"
#include <memory>
#include "annotator/types.h"
#ifndef TC3_DISABLE_LUA
#include "utils/lua-utils.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "third_party/lua/lua/v5_4/src/lauxlib.h"
#include "third_party/lua/lua/v5_4/src/lualib.h"
#ifdef __cplusplus
}
#endif
#endif
namespace libtextclassifier3 {
namespace {
#ifndef TC3_DISABLE_LUA
// Lua environment used to post-verify a regex match with a script snippet.
// The context string and the capturing groups of the match are made
// available to the snippet as lua globals.
class LuaVerifier : public LuaEnvironment {
 public:
  // Builds a verifier and initializes its lua environment.
  // Returns nullptr if the initialization fails.
  // Only references to `context` and `verifier_code` are kept (and the raw
  // `matcher` pointer), so all three have to outlive the returned verifier.
  static std::unique_ptr<LuaVerifier> Create(
      const std::string& context, const std::string& verifier_code,
      const UniLib::RegexMatcher* matcher);

  // Runs the verifier snippet and stores its boolean outcome in `result`.
  // Returns false if the snippet cannot be loaded, fails to run, or does not
  // yield a boolean; `result` is only written on success.
  bool Verify(bool* result);

 private:
  // Instances are only created through `Create`.
  explicit LuaVerifier(const std::string& context,
                       const std::string& verifier_code,
                       const UniLib::RegexMatcher* matcher)
      : context_(context),
        verifier_code_(verifier_code),
        matcher_(matcher) {}

  // Sets up the lua state and the globals visible to the snippet.
  bool Initialize();

  // Provides details of a capturing group to lua.
  int GetCapturingGroup();

  // Text the regex was matched against (not owned).
  const std::string& context_;
  // Lua source of the verification snippet (not owned).
  const std::string& verifier_code_;
  // Matcher holding the match under verification (not owned).
  const UniLib::RegexMatcher* matcher_;
};
// Prepares the lua state for a verification run: loads the default libraries
// and registers the `context` and `match` globals.
// Returns true if the setup completed with LUA_OK.
bool LuaVerifier::Initialize() {
  const auto setup = [this]() -> int {
    LoadDefaultLibraries();

    // `context`: the text the match was found in.
    PushString(context_);
    lua_setglobal(state_, "context");

    // `match`: array-like view on the capturing groups. An entry `match[i]`
    // describes the ith capturing group with the fields
    //   * `begin`: span start
    //   * `end`: span end
    //   * `text`: the text
    // Lookups are answered by GetCapturingGroup.
    PushLazyObject(&LuaVerifier::GetCapturingGroup);
    lua_setglobal(state_, "match");

    return LUA_OK;
  };

  // Run the setup in protected mode so that a failure is reported as a status
  // instead of triggering a lua panic.
  return RunProtected(setup) == LUA_OK;
}
// Factory for LuaVerifier: constructs the object and initializes its lua
// environment. Returns nullptr (after logging) if the initialization fails.
std::unique_ptr<LuaVerifier> LuaVerifier::Create(
    const std::string& context, const std::string& verifier_code,
    const UniLib::RegexMatcher* matcher) {
  // The constructor is private, so the instance is allocated with `new` here
  // and handed to a unique_ptr right away.
  std::unique_ptr<LuaVerifier> instance(
      new LuaVerifier(context, verifier_code, matcher));
  if (instance->Initialize()) {
    return instance;
  }
  TC3_LOG(ERROR) << "Could not initialize lua environment.";
  return nullptr;
}
int LuaVerifier::GetCapturingGroup() {
if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
<< lua_type(state_, /*idx=*/-1);
lua_error(state_);
return 0;
}
const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
int status = UniLib::RegexMatcher::kNoError;
const CodepointSpan span = {matcher_->Start(group_id, &status),
matcher_->End(group_id, &status)};
std::string text = matcher_->Group(group_id, &status).ToUTF8String();
if (status != UniLib::RegexMatcher::kNoError) {
TC3_LOG(ERROR) << "Could not extract span from capturing group.";
lua_error(state_);
return 0;
}
lua_newtable(state_);
lua_pushinteger(state_, span.first);
lua_setfield(state_, /*idx=*/-2, "begin");
lua_pushinteger(state_, span.second);
lua_setfield(state_, /*idx=*/-2, "end");
PushString(text);
lua_setfield(state_, /*idx=*/-2, "text");
return 1;
}
bool LuaVerifier::Verify(bool* result) {
if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
/*name=*/nullptr) != LUA_OK) {
TC3_LOG(ERROR) << "Could not load verifier snippet.";
return false;
}
if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
TC3_LOG(ERROR) << "Could not run verifier snippet.";
return false;
}
if (RunProtected(
[this, result] {
if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
TC3_LOG(ERROR) << "Unexpected verification result type: "
<< lua_type(state_, /*idx=*/-1);
lua_error(state_);
return LUA_ERRRUN;
}
*result = lua_toboolean(state_, /*idx=*/-1);
return LUA_OK;
},
/*num_args=*/1) != LUA_OK) {
TC3_LOG(ERROR) << "Could not read lua result.";
return false;
}
return true;
}
#endif // TC3_DISABLE_LUA
} // namespace
// Returns the UTF-8 text of capturing group `group_id` of `matcher`.
// An unset Optional is returned if the group lookup reports an error or if
// the group text is empty.
Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
                                            const int group_id) {
  int lookup_status = UniLib::RegexMatcher::kNoError;
  std::string text = matcher->Group(group_id, &lookup_status).ToUTF8String();

  const bool has_text =
      lookup_status == UniLib::RegexMatcher::kNoError && !text.empty();
  if (has_text) {
    return Optional<std::string>(text);
  }
  return Optional<std::string>();
}
// Post-verifies a regex match by running a lua verifier snippet on it.
//
// `context` is the text the match was found in, `matcher` holds the match and
// `lua_verifier_code` is the lua source of the snippet.
//
// Returns the boolean verdict of the snippet. Returns false if the verifier
// cannot be created or the verification itself fails, and also when lua
// support is compiled out via TC3_DISABLE_LUA.
bool VerifyMatch(const std::string& context,
                 const UniLib::RegexMatcher* matcher,
                 const std::string& lua_verifier_code) {
  bool status = false;
#ifndef TC3_DISABLE_LUA
  auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
  if (verifier == nullptr) {
    TC3_LOG(ERROR) << "Could not create verifier.";
    return false;
  }
  if (!verifier->Verify(&status)) {
    // The verifier exists at this point, so report the verification failure
    // rather than repeating the creation error.
    TC3_LOG(ERROR) << "Could not verify match.";
    return false;
  }
#endif  // TC3_DISABLE_LUA
  return status;
}
} // namespace libtextclassifier3