blob: adb3d015646d968d303b51b23b61b7a17c6c172b [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/enterprise/profile_management/saml_response_parser.h"
#include <string>
#include <vector>
#include "base/base64.h"
#include "base/containers/flat_map.h"
#include "base/task/sequenced_task_runner.h"
#include "content/public/browser/navigation_handle.h"
#include "content/public/browser/navigation_throttle.h"
#include "mojo/public/cpp/system/simple_watcher.h"
#include "services/data_decoder/public/cpp/data_decoder.h"
namespace profile_management {
namespace {
// Dictionary key under which a node of the parsed XML value tree stores the
// list of its child nodes. Used by the depth-first searches below.
constexpr char kChildrenKey[] = "children";
// Walks `dict` depth-first looking for the first node whose "attributes.name"
// entry is exactly "SAMLResponse", and returns that node's "attributes.value"
// string.
// Returns a pointer to the string stored in `dict`, or nullptr when no such
// node exists (or when the matching node has no string "attributes.value").
const std::string* GetSamlResponseFromDict(const base::Value::Dict& dict) {
  // Does this node itself carry the SAMLResponse?
  if (const std::string* name = dict.FindStringByDottedPath("attributes.name");
      name && *name == "SAMLResponse") {
    return dict.FindStringByDottedPath("attributes.value");
  }

  // Otherwise descend into every dictionary child, in order, and stop at the
  // first hit.
  if (const base::Value::List* nested = dict.FindList(kChildrenKey)) {
    for (const auto& node : *nested) {
      if (node.is_dict()) {
        if (const std::string* found =
                GetSamlResponseFromDict(node.GetDict())) {
          return found;
        }
      }
    }
  }
  return nullptr;
}
// Finds the value of the attribute named `attribute` by doing a depth-first
// search in `dict`.
//
// A node matches when its "attributes.Name" string equals `attribute`. For a
// matching node, its children tagged "AttributeValue" are examined in order;
// for the first such child that has a children list containing a dictionary,
// the "text" entry of that first dictionary is returned (which may be nullptr
// if it has no string "text" entry). If the node does not match, or none of
// its "AttributeValue" children produced a result, the search recurses into
// the node's dictionary children.
//
// Returns a pointer to the string stored in `dict`, or nullptr if nothing was
// found.
const std::string* GetAttributeValue(const base::Value::Dict& dict,
                                     const std::string& attribute) {
  // Without children this node can neither hold an AttributeValue nor be
  // descended into, whether or not its name matches.
  const base::Value::List* children = dict.FindList(kChildrenKey);
  if (!children) {
    return nullptr;
  }

  const std::string* attribute_name =
      dict.FindStringByDottedPath("attributes.Name");
  if (attribute_name && *attribute_name == attribute) {
    for (const auto& attribute_child : *children) {
      // Only dictionaries can be queried by key; skip anything else, as the
      // recursive loop below does.
      if (!attribute_child.is_dict()) {
        continue;
      }
      const base::Value::Dict& attribute_child_dict =
          attribute_child.GetDict();

      const std::string* tag = attribute_child_dict.FindString("tag");
      if (!tag || *tag != "AttributeValue") {
        continue;
      }

      const base::Value::List* attribute_value_children =
          attribute_child_dict.FindList(kChildrenKey);
      if (!attribute_value_children) {
        continue;
      }

      // The first dictionary child decides the result, even when it carries
      // no "text" entry.
      for (const auto& attribute_value_item : *attribute_value_children) {
        if (attribute_value_item.is_dict()) {
          return attribute_value_item.GetDict().FindString("text");
        }
      }
    }
  }

  // No value found on this node: keep searching deeper.
  for (const auto& child : *children) {
    if (!child.is_dict()) {
      continue;
    }
    const std::string* value = GetAttributeValue(child.GetDict(), attribute);
    if (value) {
      return value;
    }
  }
  return nullptr;
}
} // namespace
// Stores the requested attribute names, the response body handle and the
// completion callback, then starts watching the body pipe. OnBodyReady() is
// invoked (through a weak pointer) once the pipe is readable or its peer has
// been closed.
SAMLResponseParser::SAMLResponseParser(
    std::vector<std::string>&& attributes,
    const mojo::DataPipeConsumerHandle& body,
    base::OnceCallback<void(base::flat_map<std::string, std::string>)> callback)
    : attributes_(std::move(attributes)),
      body_(body),
      body_consumer_watcher_(FROM_HERE,
                             mojo::SimpleWatcher::ArmingPolicy::MANUAL,
                             base::SequencedTaskRunner::GetCurrentDefault()),
      callback_(std::move(callback)) {
  // Signals of interest: data available, or the producer went away.
  const auto watched_signals =
      MOJO_HANDLE_SIGNAL_READABLE | MOJO_HANDLE_SIGNAL_PEER_CLOSED;
  auto on_body_ready = base::BindRepeating(&SAMLResponseParser::OnBodyReady,
                                           weak_ptr_factory_.GetWeakPtr());

  body_consumer_watcher_.Watch(body_, watched_signals,
                               std::move(on_body_ready));
  // The watcher uses manual arming, so it has to be armed explicitly.
  body_consumer_watcher_.ArmOrNotify();
}
// Defaulted: no teardown is performed here beyond destroying the members.
SAMLResponseParser::~SAMLResponseParser() = default;
// Watcher callback for the response body pipe. Queries how many bytes are
// available, peeks at them and hands them to the out-of-process XML parser;
// the parsed result is delivered to GetSamlResponse().
//
// If the body cannot be read (the query reports a failed precondition, or the
// peek itself fails), `callback_` is run with an empty map so that the caller
// is always notified, matching the failure handling in GetSamlResponse() and
// GetAttributesFromSAMLResponse().
//
// The MojoResult passed by the watcher is ignored; the state of the pipe is
// determined by the query below.
void SAMLResponseParser::OnBodyReady(MojoResult) {
  uint32_t num_bytes = 0;
  const MojoResult query_result =
      body_.ReadData(nullptr, &num_bytes, MOJO_READ_DATA_FLAG_QUERY);

  std::string response;
  bool body_read = false;
  switch (query_result) {
    case MOJO_RESULT_OK: {
      response.assign(num_bytes, '\0');
      // Read with the PEEK flag, and only trust the buffer if the read
      // actually succeeded.
      const MojoResult peek_result = body_.ReadData(
          response.data(), &num_bytes, MOJO_READ_DATA_FLAG_PEEK);
      if (peek_result == MOJO_RESULT_OK) {
        // `num_bytes` now holds the number of bytes actually copied.
        response.resize(num_bytes);
        body_read = true;
      }
      break;
    }
    case MOJO_RESULT_FAILED_PRECONDITION:
      DCHECK_EQ(num_bytes, 0u);
      break;
    case MOJO_RESULT_SHOULD_WAIT:
      // Nothing to read yet: re-arm and wait for the next notification.
      body_consumer_watcher_.ArmOrNotify();
      return;
    default:
      NOTREACHED();
      return;
  }

  // Stop watching for response body changes.
  body_consumer_watcher_.Cancel();

  if (!body_read) {
    // Kept as the last statement touching members, in case the callback
    // tears down this parser.
    std::move(callback_).Run(base::flat_map<std::string, std::string>());
    return;
  }

  data_decoder::DataDecoder::ParseXmlIsolated(
      response,
      data_decoder::mojom::XmlParser::WhitespaceBehavior::kPreserveSignificant,
      base::BindOnce(&SAMLResponseParser::GetSamlResponse,
                     weak_ptr_factory_.GetWeakPtr()));
}
// Receives the parsed response body. Extracts the SAMLResponse value from it,
// base64-decodes that value and sends the decoded text to the XML parser, whose
// result is delivered to GetAttributesFromSAMLResponse().
// If parsing failed, the result is not a dictionary, no SAMLResponse is present
// or it is not valid base64, `callback_` is run with an empty map instead.
void SAMLResponseParser::GetSamlResponse(
    data_decoder::DataDecoder::ValueOrError value_or_error) {
  std::string decoded_saml;
  bool have_decoded_saml = false;

  if (value_or_error.has_value() && value_or_error.value().is_dict()) {
    const std::string* encoded_saml =
        GetSamlResponseFromDict(value_or_error.value().GetDict());
    // Short-circuits: decoding is only attempted when a value was found.
    have_decoded_saml =
        encoded_saml && base::Base64Decode(*encoded_saml, &decoded_saml);
  }

  if (!have_decoded_saml) {
    std::move(callback_).Run(base::flat_map<std::string, std::string>());
    return;
  }

  auto on_parsed =
      base::BindOnce(&SAMLResponseParser::GetAttributesFromSAMLResponse,
                     weak_ptr_factory_.GetWeakPtr());
  data_decoder::DataDecoder::ParseXmlIsolated(
      decoded_saml,
      data_decoder::mojom::XmlParser::WhitespaceBehavior::kPreserveSignificant,
      std::move(on_parsed));
}
// Receives the parsed (decoded) SAML response. Looks up every name in
// `attributes_` and reports the ones that were found, keyed by attribute name,
// through `callback_`. A parse failure or a non-dictionary result yields an
// empty map.
void SAMLResponseParser::GetAttributesFromSAMLResponse(
    data_decoder::DataDecoder::ValueOrError value_or_error) {
  base::flat_map<std::string, std::string> found_attributes;

  if (value_or_error.has_value() && value_or_error.value().is_dict()) {
    const base::Value::Dict& root = value_or_error.value().GetDict();
    for (const auto& name : attributes_) {
      // Attributes missing from the response are simply left out of the map.
      if (const std::string* text = GetAttributeValue(root, name)) {
        found_attributes[name] = *text;
      }
    }
  }

  std::move(callback_).Run(std::move(found_attributes));
}
} // namespace profile_management