blob: 1e5db6b76559873007d80867fd6ec34691771bd3 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// Explicit specializations of the TensorView<> member functions for every
// supported tensor data type (currently int64_t and double).
#include "ml/tensor_view.h"
namespace ml {
using ::chromeos::machine_learning::mojom::FloatList;
using ::chromeos::machine_learning::mojom::Int64List;
using ::chromeos::machine_learning::mojom::ValueList;
// Returns a mutable reference to the vector held in the INT64_LIST member of
// the tensor's data union.
// NOTE(review): presumably callers must ensure IsValidType() first, since the
// union getter is used unconditionally here -- confirm against the header.
template <>
std::vector<int64_t>& TensorView<int64_t>::GetValues() {
  auto& int64_list = tensor_->data->get_int64_list();
  return int64_list->value;
}
// Reports whether the tensor's data union currently holds its INT64_LIST
// member, i.e. whether it can be viewed as a tensor of int64_t.
template <>
bool TensorView<int64_t>::IsValidType() const {
  const auto active_tag = tensor_->data->which();
  return active_tag == ValueList::Tag::INT64_LIST;
}
// Switches the tensor's data union to a newly created Int64List and gives it
// an empty vector of values.
template <>
void TensorView<int64_t>::AllocateValues() {
  auto& value_list = *tensor_->data;
  value_list.set_int64_list(Int64List::New());
  // TODO(hidehiko): the explicit std::vector<>() assignment to `value` becomes
  // unnecessary once libmojo is uprev'ed. Remove it after the uprev.
  value_list.get_int64_list()->value = std::vector<int64_t>();
}
// Returns a mutable reference to the vector held in the FLOAT_LIST member of
// the tensor's data union.
// NOTE(review): presumably callers must ensure IsValidType() first, since the
// union getter is used unconditionally here -- confirm against the header.
template <>
std::vector<double>& TensorView<double>::GetValues() {
  auto& float_list = tensor_->data->get_float_list();
  return float_list->value;
}
// Reports whether the tensor's data union currently holds its FLOAT_LIST
// member, i.e. whether it can be viewed as a tensor of double.
template <>
bool TensorView<double>::IsValidType() const {
  const auto active_tag = tensor_->data->which();
  return active_tag == ValueList::Tag::FLOAT_LIST;
}
// Switches the tensor's data union to a newly created FloatList and gives it
// an empty vector of values.
template <>
void TensorView<double>::AllocateValues() {
  auto& value_list = *tensor_->data;
  value_list.set_float_list(FloatList::New());
  // TODO(hidehiko): the explicit std::vector<>() assignment to `value` becomes
  // unnecessary once libmojo is uprev'ed. Remove it after the uprev.
  value_list.get_float_list()->value = std::vector<double>();
}
} // namespace ml