blob: 890da7461f7ea17dfeb3c005a33e899a8cdf70b6 [file] [log] [blame]
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef LANG_ID_NN_PARAMS_H_
#define LANG_ID_NN_PARAMS_H_
#include "base.h"
#include "embedding_network_params.h"
#include "float16.h"
namespace chrome_lang_id {
class LangIdNNParams : public EmbeddingNetworkParams {
public:
~LangIdNNParams() override {}
// Access methods for embeddings:
int embeddings_size() const override { return 6; }
int embeddings_num_rows(int i) const override {
return kEmbeddingsNumRows[i];
}
int embeddings_num_cols(int i) const override {
return kEmbeddingsNumCols[i];
}
const void *embeddings_weights(int i) const override {
return embeddings_weights_[i];
}
QuantizationType embeddings_quant_type(int i) const override {
return QuantizationType::UINT8;
}
const float16 *embeddings_quant_scales(int i) const override {
return embeddings_quant_scales_[i];
}
// Access methods for hidden:
int hidden_size() const override { return 1; }
int hidden_num_rows(int i) const override { return kHiddenNumRows[i]; }
int hidden_num_cols(int i) const override { return kHiddenNumCols[i]; }
const void *hidden_weights(int i) const override {
return hidden_weights_[i];
}
// Access methods for hidden_bias:
int hidden_bias_size() const override { return 1; }
int hidden_bias_num_rows(int i) const override {
return kHiddenBiasNumRows[i];
}
int hidden_bias_num_cols(int i) const override {
return kHiddenBiasNumCols[i];
}
const void *hidden_bias_weights(int i) const override {
return hidden_bias_weights_[i];
}
// Access methods for softmax:
int softmax_size() const override { return 1; }
int softmax_num_rows(int i) const override { return kSoftmaxNumRows[i]; }
int softmax_num_cols(int i) const override { return kSoftmaxNumCols[i]; }
const void *softmax_weights(int i) const override {
return softmax_weights_[i];
}
// Access methods for softmax_bias:
int softmax_bias_size() const override { return 1; }
int softmax_bias_num_rows(int i) const override {
return kSoftmaxBiasNumRows[i];
}
int softmax_bias_num_cols(int i) const override {
return kSoftmaxBiasNumCols[i];
}
const void *softmax_bias_weights(int i) const override {
return softmax_bias_weights_[i];
}
// Access methods for embedding_dim:
int embedding_dim_size() const override { return 6; }
int32 embedding_dim(int i) const override { return kEmbeddingDimValues[i]; }
// Access methods for embedding_num_features:
int embedding_num_features_size() const override { return 6; }
int32 embedding_num_features(int i) const override {
return kEmbeddingNumFeaturesValues[i];
}
// Access methods for embedding_features_domain_size:
int embedding_features_domain_size_size() const override { return 6; }
int32 embedding_features_domain_size(int i) const override {
return kEmbeddingFeaturesDomainSizeValues[i];
}
// Access methods for concat_offset:
int concat_offset_size() const override { return 6; }
int32 concat_offset(int i) const override { return kConcatOffsetValues[i]; }
// Access methods for concat_layer_size:
bool has_concat_layer_size() const override { return true; }
int32 concat_layer_size() const override { return 80; }
// Access methods for is_precomputed:
bool has_is_precomputed() const override { return false; }
bool is_precomputed() const override { return false; }
private:
// Private fields for embeddings:
static const int kEmbeddingsNumRows[];
static const int kEmbeddingsNumCols[];
static const uint8 kEmbeddingsWeights0[];
static const uint8 kEmbeddingsWeights1[];
static const uint8 kEmbeddingsWeights2[];
static const uint8 kEmbeddingsWeights3[];
static const uint8 kEmbeddingsWeights4[];
static const uint8 kEmbeddingsWeights5[];
const void *embeddings_weights_[6] = {
kEmbeddingsWeights0, kEmbeddingsWeights1, kEmbeddingsWeights2,
kEmbeddingsWeights3, kEmbeddingsWeights4, kEmbeddingsWeights5};
static const float16 kEmbeddingsQuantScales0[];
static const float16 kEmbeddingsQuantScales1[];
static const float16 kEmbeddingsQuantScales2[];
static const float16 kEmbeddingsQuantScales3[];
static const float16 kEmbeddingsQuantScales4[];
static const float16 kEmbeddingsQuantScales5[];
const float16 *embeddings_quant_scales_[6] = {
kEmbeddingsQuantScales0, kEmbeddingsQuantScales1,
kEmbeddingsQuantScales2, kEmbeddingsQuantScales3,
kEmbeddingsQuantScales4, kEmbeddingsQuantScales5};
// Private fields for hidden:
static const int kHiddenNumRows[];
static const int kHiddenNumCols[];
static const float kHiddenWeights0[];
const void *hidden_weights_[1] = {kHiddenWeights0};
// Private fields for hidden_bias:
static const int kHiddenBiasNumRows[];
static const int kHiddenBiasNumCols[];
static const float kHiddenBiasWeights0[];
const void *hidden_bias_weights_[1] = {kHiddenBiasWeights0};
// Private fields for softmax:
static const int kSoftmaxNumRows[];
static const int kSoftmaxNumCols[];
static const float kSoftmaxWeights0[];
const void *softmax_weights_[1] = {kSoftmaxWeights0};
// Private fields for softmax_bias:
static const int kSoftmaxBiasNumRows[];
static const int kSoftmaxBiasNumCols[];
static const float kSoftmaxBiasWeights0[];
const void *softmax_bias_weights_[1] = {kSoftmaxBiasWeights0};
// Private fields for embedding_dim:
static const int32 kEmbeddingDimValues[];
// Private fields for embedding_num_features:
static const int32 kEmbeddingNumFeaturesValues[];
// Private fields for embedding_features_domain_size:
static const int32 kEmbeddingFeaturesDomainSizeValues[];
// Private fields for concat_offset:
static const int32 kConcatOffsetValues[];
}; // class LangIdNNParams
} // namespace chrome_lang_id
#endif // LANG_ID_NN_PARAMS_H_