blob: d8a84008224eba8c06ffec499a47a8be9e812040 [file] [log] [blame]
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "syzygy/sampler/sampled_module_cache.h"
#include <psapi.h>
#include "base/strings/stringprintf.h"
#include "syzygy/common/align.h"
#include "syzygy/common/com_utils.h"
#include "syzygy/common/path_util.h"
#include "syzygy/pe/pe_file.h"
#include "syzygy/trace/common/clock.h"
namespace sampler {
namespace {
// File-local shorthand for the module map type declared by
// SampledModuleCache::Process.
using ModuleMap = SampledModuleCache::Process::ModuleMap;
// Looks up the filename of |module| as loaded in |process| and stores the
// result of common::ConvertDevicePathToDrivePath on it in |path|. Returns true
// on success, false if the filename query or the path conversion fails.
bool GetModulePath(HANDLE process, HMODULE module, base::FilePath* path) {
  DCHECK(process != INVALID_HANDLE_VALUE);
  DCHECK(module != INVALID_HANDLE_VALUE);
  DCHECK(path != NULL);

  // Begin with a 1024 character buffer and keep doubling it until the whole
  // filename fits.
  std::vector<wchar_t> buffer(1024);
  for (;;) {
    const DWORD chars_copied = ::GetModuleFileNameExW(
        process, module, buffer.data(), buffer.size());
    if (chars_copied == 0) {
      DWORD error = ::GetLastError();
      LOG(ERROR) << "GetModuleFileNameExW failed: " << common::LogWe(error);
      return false;
    }

    // A completely filled buffer means the filename may have been cut short,
    // so retry with twice the room. Anything shorter than the buffer is the
    // entire filename.
    if (chars_copied >= buffer.size()) {
      buffer.resize(buffer.size() * 2);
      continue;
    }
    break;
  }

  const base::FilePath raw_path(buffer.data());
  return common::ConvertDevicePathToDrivePath(raw_path, path);
}
} // namespace
// Creates an empty cache. |log2_bucket_size| is remembered and handed to every
// module subsequently added; it must lie in the range [2, 31].
SampledModuleCache::SampledModuleCache(size_t log2_bucket_size)
    : log2_bucket_size_(log2_bucket_size),
      module_count_(0) {
  DCHECK_LE(2u, log2_bucket_size);
  DCHECK_GE(31u, log2_bucket_size);
}
SampledModuleCache::~SampledModuleCache() {
  // Tear everything down: once every module has been flagged as dead the sweep
  // removes all of them, and with them every process.
  MarkAllModulesDead();
  RemoveDeadModules();
}
// Adds |module_handle| of |process| to the cache, creating the bookkeeping for
// the process if this is the first module seen for its PID. On success returns
// true, sets |status| as reported by Process::AddModule and points |module| at
// the cached module. On failure returns false and leaves |module| NULL.
bool SampledModuleCache::AddModule(HANDLE process,
                                   HMODULE module_handle,
                                   ProfilingStatus* status,
                                   const Module** module) {
  DCHECK(process != INVALID_HANDLE_VALUE);
  DCHECK(status != NULL);
  DCHECK(module != NULL);

  *module = NULL;

  const DWORD pid = ::GetProcessId(process);

  // |new_proc| owns a freshly created process object until it has been
  // registered in |processes_|; any early return therefore cleans it up
  // automatically. It stays empty when the process is already known.
  std::unique_ptr<Process> new_proc;
  Process* proc = NULL;

  ProcessMap::iterator known = processes_.find(pid);
  if (known != processes_.end()) {
    proc = known->second;
  } else {
    // First module for this process: duplicate the caller's handle and build
    // a new process object around the duplicate.
    HANDLE own_handle = INVALID_HANDLE_VALUE;
    const BOOL duplicated = ::DuplicateHandle(
        ::GetCurrentProcess(), process, ::GetCurrentProcess(), &own_handle,
        0, FALSE, DUPLICATE_SAME_ACCESS);
    if (!duplicated || own_handle == INVALID_HANDLE_VALUE) {
      DWORD error = ::GetLastError();
      LOG(ERROR) << "Failed to duplicate handle to process " << pid << ": "
                 << common::LogWe(error);
      return false;
    }

    new_proc.reset(new Process(own_handle, pid));
    if (!new_proc->Init())
      return false;
    proc = new_proc.get();
  }
  DCHECK(proc != NULL);

  if (!proc->AddModule(module_handle, log2_bucket_size_, status, module))
    return false;
  DCHECK(*module != NULL);

  // Only modules that were not previously being profiled add to the total.
  if (*status == kProfilingStarted)
    ++module_count_;

  if (new_proc) {
    // Everything succeeded, so the new process can now be registered. Only
    // give up ownership once the insertion has gone through.
    processes_.insert(std::make_pair(pid, proc));
    new_proc.release();
  }

  return true;
}
// Flags every tracked process (and through Process::MarkDead, each of its
// modules) as dead.
void SampledModuleCache::MarkAllModulesDead() {
  for (auto& pid_and_process : processes_)
    pid_and_process.second->MarkDead();
}
void SampledModuleCache::RemoveDeadModules() {
ProcessMap::iterator proc_it = processes_.begin();
ProcessMap::iterator proc_it_next = proc_it;
while (proc_it != processes_.end()) {
++proc_it_next;
// Remove any dead modules from the process.
size_t old_module_count = proc_it->second->modules().size();
proc_it->second->RemoveDeadModules(dead_module_callback_);
size_t new_module_count = proc_it->second->modules().size();
// If the process itself is dead (contains no more profiling modules) then
// remove it.
if (!proc_it->second->alive()) {
Process* proc = proc_it->second;
delete proc;
processes_.erase(proc_it);
}
module_count_ += new_module_count;
DCHECK_LE(old_module_count, module_count_);
module_count_ -= old_module_count;
proc_it = proc_it_next;
}
}
// Wraps |process| (a handle to the process with id |pid|); the object starts
// out alive.
// NOTE(review): presumably takes ownership of the handle, since the cache
// passes in a duplicate that it never closes itself - confirm in the header.
SampledModuleCache::Process::Process(HANDLE process, DWORD pid)
    : process_(process),
      pid_(pid),
      alive_(true) {
  DCHECK(process != INVALID_HANDLE_VALUE);
}
SampledModuleCache::Process::~Process() {
  // Flag every module as dead and sweep them all out. A null callback is
  // passed, so nothing is reported for modules removed at this point.
  MarkDead();
  RemoveDeadModules(DeadModuleCallback());
}
// Initializes |process_info_| for |pid_|. Returns true on success, false
// otherwise.
bool SampledModuleCache::Process::Init() {
  return process_info_.Initialize(pid_);
}
// Starts, or continues, profiling |module_handle| within this process.
//
// If the module is already in the module map it is simply marked alive again
// (as is this process) and |status| is set to kProfilingContinued. Otherwise a
// new Module is created, initialized and started; once that has all succeeded
// it is inserted in the map and |status| is set to kProfilingStarted.
//
// |log2_bucket_size| must be in the range [2, 31]. On success returns true and
// points |module| at the tracked module. On failure returns false, leaves
// |module| NULL and leaves |status| untouched.
bool SampledModuleCache::Process::AddModule(HMODULE module_handle,
                                            size_t log2_bucket_size,
                                            ProfilingStatus* status,
                                            const Module** module) {
  // Validate the handle itself. (This check previously tested the |module|
  // out-parameter by mistake, leaving |module_handle| unchecked.)
  DCHECK(module_handle != INVALID_HANDLE_VALUE);
  DCHECK_LE(2u, log2_bucket_size);
  DCHECK_GE(31u, log2_bucket_size);
  DCHECK(status != NULL);
  DCHECK(module != NULL);

  *module = NULL;

  ModuleMap::iterator mod_it = modules_.find(module_handle);
  if (mod_it != modules_.end()) {
    // The module is already being profiled. Simply mark it as being alive.
    mod_it->second->MarkAlive();

    // And mark ourselves as being alive while we're at it.
    MarkAlive();

    *status = kProfilingContinued;
    *module = mod_it->second;
    return true;
  }

  // Create a new module object. We don't actually insert it into the map until
  // everything has succeeded, saving us the cleanup on failure: the
  // unique_ptr destroys the module on any early return.
  std::unique_ptr<Module> mod(
      new Module(this, module_handle, log2_bucket_size));
  if (!mod->Init())
    return false;
  if (!mod->Start())
    return false;

  // Initialization was successful so we can safely insert the initialized
  // (and currently profiling) module into the map, which now owns it.
  mod_it = modules_.insert(std::make_pair(module_handle, mod.release())).first;

  MarkAlive();
  *status = kProfilingStarted;
  *module = mod_it->second;
  return true;
}
// Flags this process, and then each of the modules it tracks, as dead.
void SampledModuleCache::Process::MarkDead() {
  alive_ = false;
  for (auto& handle_and_module : modules_)
    handle_and_module.second->MarkDead();
}
// Removes every module that is not alive: profiling is stopped, the module is
// handed to |callback| (when one has been provided) and it is then deleted and
// dropped from the map. Modules that are alive are left untouched.
void SampledModuleCache::Process::RemoveDeadModules(
    DeadModuleCallback callback) {
  ModuleMap::iterator cursor = modules_.begin();
  while (cursor != modules_.end()) {
    // Move the cursor along before the current entry is possibly erased.
    ModuleMap::iterator mod_it = cursor++;
    DCHECK(mod_it->second != NULL);

    Module* mod = mod_it->second;
    if (mod->alive())
      continue;

    // Profiling stops before the results are handed out.
    mod->Stop();

    if (!callback.is_null())
      callback.Run(mod);

    // Finally dispose of the module and its map entry.
    delete mod;
    modules_.erase(mod_it);
  }
}
// Records the owning |process|, the |module| handle and the bucket size
// exponent. Everything else is zeroed here and filled in later by Init, Start
// and Stop. The module starts out alive.
SampledModuleCache::Module::Module(Process* process,
                                   HMODULE module,
                                   size_t log2_bucket_size)
    : process_(process),
      module_(module),
      // Image properties, read from the module's headers by Init.
      module_size_(0),
      module_checksum_(0),
      module_time_date_stamp_(0),
      // Profiled address range, computed by Init.
      buckets_begin_(NULL),
      buckets_end_(NULL),
      log2_bucket_size_(log2_bucket_size),
      // Timestamps, taken by Start and Stop.
      profiling_start_time_(0),
      profiling_stop_time_(0),
      alive_(true) {
  DCHECK(process != NULL);
  DCHECK(module_ != INVALID_HANDLE_VALUE);
  DCHECK_LE(2u, log2_bucket_size);
  DCHECK_GE(31u, log2_bucket_size);
}
bool SampledModuleCache::Module::Init() {
if (!GetModulePath(process_->process(), module_, &module_path_))
return false;
// Read the headers.
char headers[4096] = {};
size_t net_bytes_read = 0;
size_t empty_reads = 0;
while (net_bytes_read < sizeof(headers)) {
SIZE_T bytes_read = 0;
if (::ReadProcessMemory(process_->process(),
module_,
headers + net_bytes_read,
sizeof(headers) - net_bytes_read,
&bytes_read) == FALSE) {
DWORD error = ::GetLastError();
LOG(ERROR) << "ReadProcessMemory failed for module at address "
<< base::StringPrintf("0x%08X", module_)
<< " of process " << process_->pid() << ": "
<< common::LogWe(error);
return false;
}
if (bytes_read == 0) {
if (++empty_reads == 3) {
LOG(ERROR) << "ReadProcessMemory unable to read headers for module at "
<< "address " << base::StringPrintf("0x%08X", module_)
<< " of process " << process_->pid() << ".";
return false;
}
} else {
net_bytes_read += bytes_read;
empty_reads = 0;
}
}
const IMAGE_DOS_HEADER* dos_header =
reinterpret_cast<const IMAGE_DOS_HEADER*>(headers);
static_assert(sizeof(IMAGE_DOS_HEADER) <= sizeof(headers),
"headers must be big enough for DOS headers.");
// Get the NT headers and make sure they're fully contained in the block we
// read.
if (dos_header->e_lfanew > sizeof(headers))
return false;
const IMAGE_NT_HEADERS* nt_headers =
reinterpret_cast<const IMAGE_NT_HEADERS*>(headers + dos_header->e_lfanew);
if (reinterpret_cast<const char*>(nt_headers + 1) - headers > sizeof(headers))
return false;
// Get the section headers and make sure they're fully contained in the
// block we read.
size_t section_count = nt_headers->FileHeader.NumberOfSections;
const IMAGE_SECTION_HEADER* section_headers =
reinterpret_cast<const IMAGE_SECTION_HEADER*>(nt_headers + 1);
if (reinterpret_cast<const char*>(section_headers + section_count) - headers >
sizeof(headers)) {
return false;
}
module_size_ = nt_headers->OptionalHeader.SizeOfImage;
module_checksum_ = nt_headers->OptionalHeader.CheckSum;
module_time_date_stamp_ = nt_headers->FileHeader.TimeDateStamp;
// Find the RVA range associated with any text segments in the module.
DWORD text_begin = SIZE_MAX;
DWORD text_end = 0;
for (size_t i = 0; i < section_count; ++i) {
const IMAGE_SECTION_HEADER& sh = section_headers[i];
static const DWORD kExecFlags = IMAGE_SCN_MEM_EXECUTE | IMAGE_SCN_CNT_CODE;
if ((sh.Characteristics & kExecFlags) == 0)
continue;
DWORD sec_begin = sh.VirtualAddress;
DWORD sec_end = sec_begin + sh.Misc.VirtualSize;
if (sec_begin < text_begin)
text_begin = sec_begin;
if (sec_end > text_end)
text_end = sec_end;
}
// Adjust the address range for the bucket size.
DWORD bucket_size = 1 << log2_bucket_size_;
text_begin = (text_begin / bucket_size) * bucket_size;
text_end = ((text_end + bucket_size - 1) / bucket_size) * bucket_size;
// Calculate the number of buckets.
DCHECK_EQ(0u, (text_end - text_begin) % bucket_size);
DWORD bucket_count = (text_end - text_begin) / bucket_size;
// Calculate the bucket range in the remote address space.
buckets_begin_ = reinterpret_cast<const void*>(
reinterpret_cast<const char*>(module_) + text_begin);
buckets_end_ = reinterpret_cast<const void*>(
reinterpret_cast<const char*>(module_) + text_end);
// Initialize the profiler.
if (!profiler_.Initialize(process_->process(),
const_cast<void*>(buckets_begin_),
text_end - text_begin,
log2_bucket_size_)) {
LOG(ERROR) << "Failed to initialize profiler for address range "
<< base::StringPrintf("0x%08X - 0x%08X",
buckets_begin_,
buckets_end_)
<< " of process " << process_->pid() << ".";
return false;
}
DCHECK_EQ(bucket_count, profiler_.buckets().size());
return true;
}
// Starts the profiler and, if that worked, records the current TSC value as
// the profiling start time. Returns whether the profiler was started.
bool SampledModuleCache::Module::Start() {
  const bool started = profiler_.Start();
  if (started)
    profiling_start_time_ = trace::common::GetTsc();
  return started;
}
// Stops the profiler and, if that worked, records the current TSC value as
// the profiling stop time. Returns whether the profiler was stopped.
bool SampledModuleCache::Module::Stop() {
  const bool stopped = profiler_.Stop();
  if (stopped)
    profiling_stop_time_ = trace::common::GetTsc();
  return stopped;
}
} // namespace sampler