blob: 97e545a1778becb12bbb1f406406b72e6d4d32a1 [file] [log] [blame]
##########################################################################
#
# Copyright 2008-2010 VMware, Inc.
# All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
##########################################################################
"""Common trace code generation."""
# Adjust path
import os.path
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import itertools
import specs.stdapi as stdapi
def getWrapperInterfaceName(interface):
    """Return the C++ class name of the tracing wrapper for *interface*.

    The name is the interface's C expression with a ``Wrap`` prefix, e.g. an
    interface whose ``expr`` is ``IUnknown`` yields ``WrapIUnknown``.
    """
    return "".join(("Wrap", interface.expr))
# Module-wide diagnostics switch.  When True, the code generators below emit
# additional os::log() calls into the generated wrappers (for instance when an
# interface wrapper is created or fetched from g_WrappedObjects).
debug = False
class ComplexValueSerializer(stdapi.OnceVisitor):
    '''Type visitor that prints the static C++ tables and helper functions
    needed to serialize complex types.

    * structures get a member-name array plus a ``trace::StructSig``;
    * enums and bitmasks get a ``{"NAME", NAME}`` table plus a
      ``trace::EnumSig`` / ``trace::BitmaskSig``;
    * context-less polymorphic types get a ``_write__<tag>()`` function whose
      cases are serialized through ``serializer``.

    Wrapper types (const, arrays, pointers, handles, references, aliases) are
    only traversed so the types they contain are reached as well.  Simple
    types are serialized inline elsewhere and need nothing here.
    '''

    def __init__(self, serializer):
        stdapi.OnceVisitor.__init__(self)
        # Value visitor used to print the body of each polymorphic case.
        self.serializer = serializer

    # -- Types that need no out-of-line tables -----------------------------

    def visitVoid(self, literal):
        return None

    def visitLiteral(self, literal):
        return None

    def visitString(self, string):
        return None

    def visitAttribArray(self, array):
        return None

    def visitBlob(self, array):
        return None

    def visitIntPointer(self, pointer):
        return None

    def visitOpaque(self, opaque):
        return None

    def visitInterface(self, interface):
        return None

    # -- Wrapper types: descend into the wrapped type ----------------------

    def visitConst(self, const):
        self.visit(const.type)

    def visitArray(self, array):
        self.visit(array.type)

    def visitPointer(self, pointer):
        self.visit(pointer.type)

    def visitObjPointer(self, pointer):
        self.visit(pointer.type)

    def visitLinearPointer(self, pointer):
        self.visit(pointer.type)

    def visitHandle(self, handle):
        self.visit(handle.type)

    def visitReference(self, reference):
        self.visit(reference.type)

    def visitAlias(self, alias):
        self.visit(alias.type)

    # -- Types with generated tables / helpers -----------------------------

    def visitStruct(self, struct):
        memberCount = len(struct.members)
        if not memberCount:
            sys.stderr.write('warning: %s has no members\n' % struct.name)
            membersArray = 'nullptr'
        else:
            # Only emit the names array when it is non-empty: a zero-length
            # array would trigger MSVC error C2466.
            membersArray = '_struct%s_members' % (struct.tag,)
            print('static const char * %s[%u] = {' % (membersArray, memberCount))
            for _memberType, memberName in struct.members:
                # Anonymous members get an empty name.
                print(' "",' if memberName is None else ' "%s",' % (memberName,))
            print('};')

        # Signature referencing the member-name array printed above.
        quotedName = '""' if struct.name is None else '"%s"' % struct.name
        print('static const trace::StructSig _struct%s_sig = {' % (struct.tag,))
        print(' %u, %s, %u, %s' % (struct.id, quotedName, memberCount, membersArray))
        print('};')
        print()

    def _printValueTable(self, entryType, tableName, sigType, sigName, typeId, values):
        '''Print a ``{"NAME", NAME}`` table and the signature pointing at it.'''
        print('static const trace::%s %s[] = {' % (entryType, tableName))
        for value in values:
            print(' {"%s", %s},' % (value, value))
        print('};')
        print()
        print('static const trace::%s %s = {' % (sigType, sigName))
        print(' %u, %u, %s' % (typeId, len(values), tableName))
        print('};')
        print()

    def visitEnum(self, enum):
        self._printValueTable('EnumValue', '_enum%s_values' % enum.tag,
                              'EnumSig', '_enum%s_sig' % enum.tag,
                              enum.id, enum.values)

    def visitBitmask(self, bitmask):
        self._printValueTable('BitmaskFlag', '_bitmask%s_flags' % bitmask.tag,
                              'BitmaskSig', '_bitmask%s_sig' % bitmask.tag,
                              bitmask.id, bitmask.values)

    def visitPolymorphic(self, polymorphic):
        # Only context-less polymorphic types get an out-of-line writer; the
        # others are switched on inline at the point of serialization.
        if polymorphic.contextLess:
            print('static void _write__%s(int selector, %s const & value) {' % (polymorphic.tag, polymorphic.expr))
            print(' switch (selector) {')
            for labels, caseType in polymorphic.iterSwitch():
                for label in labels:
                    print(' %s:' % label)
                self.serializer.visit(caseType, '(%s)(value)' % (caseType,))
                print(' break;')
            print(' }')
            print('}')
            print()
class ValueSerializer(stdapi.Visitor, stdapi.ExpanderMixin):
    '''Visitor which generates code to serialize any type.

    Simple types are serialized inline here, whereas the serialization of
    complex types is dispatched to the serialization functions generated by
    ComplexValueSerializer visitor above.

    Every visit method takes ``instance``, a C/C++ expression (as a string)
    denoting the value, and prints to stdout the C++ statements that write it
    through ``trace::localWriter``.  Struct, enum and bitmask values refer to
    the ``_struct<tag>_sig`` / ``_enum<tag>_sig`` / ``_bitmask<tag>_sig``
    tables, and context-less polymorphic values call ``_write__<tag>()``;
    those must already have been emitted by ComplexValueSerializer.
    '''

    def visitLiteral(self, literal, instance):
        # Scalar values: trace::localWriter.write<kind>(instance), where
        # <kind> comes from literal.kind.
        print(' trace::localWriter.write%s(%s);' % (literal.kind, instance))

    def visitString(self, string, instance):
        # Narrow strings use writeString, wide strings writeWString.
        if not string.wide:
            cast = 'const char *'
            suffix = 'String'
        else:
            cast = 'const wchar_t *'
            suffix = 'WString'
        if cast != string.expr:
            # reinterpret_cast is necessary for GLubyte * <=> char *
            instance = 'reinterpret_cast<%s>(%s)' % (cast, instance)
        if string.length is not None:
            # An explicit length expression is expanded and passed as a
            # second argument to the writer.
            length = ', %s' % self.expand(string.length)
        else:
            length = ''
        print(' trace::localWriter.write%s(%s%s);' % (suffix, instance, length))

    def visitConst(self, const, instance):
        # const-qualification does not change how the value is written.
        self.visit(const.type, instance)

    def visitStruct(self, struct, instance):
        # Members are written in declaration order between beginStruct and
        # endStruct; visitMember gives each member its own sub-expression.
        print(' trace::localWriter.beginStruct(&_struct%s_sig);' % (struct.tag,))
        for member in struct.members:
            self.visitMember(member, instance)
        print(' trace::localWriter.endStruct();')

    def visitArray(self, array, instance):
        # A null array is written as null.  Otherwise the (expanded) length
        # expression is clamped to be non-negative and each element is
        # written; visitElement binds the element index for nested
        # expansions.
        length = '_c' + array.type.tag
        index = '_i' + array.type.tag
        array_length = self.expand(array.length)
        print(' if (%s) {' % instance)
        print(' size_t %s = %s > 0 ? %s : 0;' % (length, array_length, array_length))
        print(' trace::localWriter.beginArray(%s);' % length)
        print(' for (size_t %s = 0; %s < %s; ++%s) {' % (index, index, length, index))
        print(' trace::localWriter.beginElement();')
        self.visitElement(index, array.type, '(%s)[%s]' % (instance, index))
        print(' trace::localWriter.endElement();')
        print(' }')
        print(' trace::localWriter.endArray();')
        print(' } else {')
        print(' trace::localWriter.writeNull();')
        print(' }')

    def visitAttribArray(self, array, instance):
        # For each element, decide if it is a key or a value (which depends on the previous key).
        # If it is a value, store it as the right type - usually int, some bitfield, or some enum.
        # It is currently assumed that an unknown key means that it is followed by an int value.

        # determine the array length which must be passed to writeArray() up front
        # First generated loop: step over key/value pairs until the terminator,
        # stepping back one slot after keys that carry no value.  Afterwards the
        # count is bumped by one so the terminator itself is written too
        # (a null array yields a count of 0).
        count = '_c' + array.baseType.tag
        print(' {')
        print(' int %s;' % count)
        print(' for (%(c)s = 0; %(array)s && %(array)s[%(c)s] != %(terminator)s; %(c)s += 2) {' \
              % {'c': count, 'array': instance, 'terminator': array.terminator})
        if array.hasKeysWithoutValues:
            print(' switch (int(%(array)s[%(c)s])) {' % {'array': instance, 'c': count})
            for key, valueType in array.valueTypes:
                if valueType is None:
                    print(' case %s:' % key)
            print(' %s--;' % count) # the next value is a key again and checked if it's the terminator
            print(' break;')
            print(' }')
        print(' }')
        print(' %(c)s += %(array)s ? 1 : 0;' % {'c': count, 'array': instance})
        print(' trace::localWriter.beginArray(%s);' % count)

        # for each key / key-value pair write the key and the value, if the key requires one
        # Second generated loop: write the key, stop if it was the last
        # element (the terminator), otherwise switch on the key -- the
        # post-increment moves the index onto the value slot.
        index = '_i' + array.baseType.tag
        print(' for (int %(i)s = 0; %(i)s < %(count)s; %(i)s++) {' % {'i': index, 'count': count})
        print(' trace::localWriter.beginElement();')
        self.visit(array.baseType, "%(array)s[%(i)s]" % {'array': instance, 'i': index})
        print(' trace::localWriter.endElement();')
        print(' if (%(i)s + 1 >= %(count)s) {' % {'i': index, 'count': count})
        print(' break;')
        print(' }')
        print(' switch (int(%(array)s[%(i)s++])) {' % {'array': instance, 'i': index})
        # write generic value the usual way
        for key, valueType in array.valueTypes:
            if valueType is not None:
                print(' case %s:' % key)
                print(' trace::localWriter.beginElement();')
                self.visitElement(index, valueType, '(%(array)s)[%(i)s]' % {'array': instance, 'i': index})
                print(' trace::localWriter.endElement();')
                print(' break;')
        # known key with no value, just decrease the index so we treat the next value as a key
        if array.hasKeysWithoutValues:
            for key, valueType in array.valueTypes:
                if valueType is None:
                    print(' case %s:' % key)
            print(' %s--;' % index)
            print(' break;')
        # unknown key, write an int value
        print(' default:')
        print(' trace::localWriter.beginElement();')
        # NOTE: `%` binds tighter than `+`, so only the second literal below is
        # %-formatted; the first keeps its C printf placeholders verbatim.
        print(' os::log("apitrace: warning: %s: unknown key 0x%04X, interpreting value as int\\n", ' + \
              '__FUNCTION__, int(%(array)s[%(i)s - 1]));' % {'array': instance, 'i': index})
        print(' trace::localWriter.writeSInt(%(array)s[%(i)s]);' % {'array': instance, 'i': index})
        print(' trace::localWriter.endElement();')
        print(' break;')
        print(' }')
        print(' }')
        print(' trace::localWriter.endArray();')
        print(' }')

    def visitBlob(self, blob, instance):
        # Raw bytes; the size comes from the (expanded) blob.size expression.
        print(' trace::localWriter.writeBlob(%s, %s);' % (instance, self.expand(blob.size)))

    def visitEnum(self, enum, instance):
        print(' trace::localWriter.writeEnum(&_enum%s_sig, %s);' % (enum.tag, instance))

    def visitBitmask(self, bitmask, instance):
        print(' trace::localWriter.writeBitmask(&_bitmask%s_sig, %s);' % (bitmask.tag, instance))

    def visitPointer(self, pointer, instance):
        # A non-null pointer is written as a one-element array holding the
        # pointee; a null pointer is written as null.
        print(' if (%s) {' % instance)
        print(' trace::localWriter.beginArray(1);')
        print(' trace::localWriter.beginElement();')
        self.visit(pointer.type, "*" + instance)
        print(' trace::localWriter.endElement();')
        print(' trace::localWriter.endArray();')
        print(' } else {')
        print(' trace::localWriter.writeNull();')
        print(' }')

    def visitIntPointer(self, pointer, instance):
        # Written as a raw address value.
        print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)

    def visitObjPointer(self, pointer, instance):
        # Written as a raw address value; the pointee is not serialized.
        print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)

    def visitLinearPointer(self, pointer, instance):
        # Written as a raw address value; the pointee is not serialized.
        print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)

    def visitReference(self, reference, instance):
        self.visit(reference.type, instance)

    def visitHandle(self, handle, instance):
        self.visit(handle.type, instance)

    def visitAlias(self, alias, instance):
        self.visit(alias.type, instance)

    def visitOpaque(self, opaque, instance):
        # Written as a raw address value.
        print(' trace::localWriter.writePointer((uintptr_t)%s);' % instance)

    def visitInterface(self, interface, instance):
        # Interface values are not serialized by value; reaching this is
        # treated as a generator error.
        assert False

    def visitPolymorphic(self, polymorphic, instance):
        # Context-less types call the _write__<tag>() helper; others get an
        # inline switch on the (expanded) selector expression, with a
        # logging/null default when the spec provides no default type.
        if polymorphic.contextLess:
            print(' _write__%s(%s, %s);' % (polymorphic.tag, polymorphic.switchExpr, instance))
        else:
            switchExpr = self.expand(polymorphic.switchExpr)
            print(' switch (%s) {' % switchExpr)
            for cases, type in polymorphic.iterSwitch():
                for case in cases:
                    print(' %s:' % case)
                caseInstance = instance
                if type.expr is not None:
                    caseInstance = 'static_cast<%s>(%s)' % (type, caseInstance)
                self.visit(type, caseInstance)
                print(' break;')
            if polymorphic.defaultType is None:
                print(r' default:')
                print(r' os::log("apitrace: warning: %%s: unexpected polymorphic case %%i\n", __FUNCTION__, (int)%s);' % (switchExpr,))
                print(r' trace::localWriter.writeNull();')
                print(r' break;')
            print(' }')
class WrapDecider(stdapi.Traverser):
    """Type visitor that decides whether values of a type need wrapping.

    Complex types (arrays, structures) must know this beforehand.  After
    ``visit(type)``, ``needsWrapping`` is True once an object pointer has
    been reached during the traversal; linear pointers are overridden as a
    no-op, so the traversal stops there.
    """

    def __init__(self):
        # Result flag; only ever switched on by visitObjPointer.
        self.needsWrapping = False

    def visitLinearPointer(self, void):
        # Intentionally do nothing (do not look beneath linear pointers).
        return None

    def visitObjPointer(self, interface):
        self.needsWrapping = True
class ValueWrapper(stdapi.Traverser, stdapi.ExpanderMixin):
    '''Type visitor which generates the code to wrap an instance.

    Wrapping is needed mostly for interfaces, but interface pointers can be
    nested anywhere inside complex types, so structures, arrays and pointers
    are walked too.  Each visit method takes ``instance``, a C++ expression
    string, and prints statements that replace every interface pointer
    reachable from it with its tracing wrapper (``Wrap<Iface>::_wrap``).
    Subclasses may override visitInterfacePointer to change the leaf action.
    '''

    def visitStruct(self, struct, instance):
        # Walk each member with the struct expression as context.
        for memberSpec in struct.members:
            self.visitMember(memberSpec, instance)

    def visitArray(self, array, instance):
        # Guard against a null array, then process every element in place.
        lengthExpr = self.expand(array.length)
        print(" if (%s) {" % instance)
        print(" for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % lengthExpr)
        self.visitElement('_i', array.type, instance + "[_i]")
        print(" }")
        print(" }")

    def visitPointer(self, pointer, instance):
        # Only dereference non-null pointers.
        print(" if (%s) {" % instance)
        self.visit(pointer.type, "*" + instance)
        print(" }")

    def visitObjPointer(self, pointer, instance):
        # Resolve the pointee to a known interface, looking through one alias.
        pointee = pointer.type.mutable()
        if isinstance(pointee, stdapi.Interface):
            interface = pointee
        elif isinstance(pointee, stdapi.Alias) and isinstance(pointee.type, stdapi.Interface):
            interface = pointee.type
        else:
            interface = None

        if interface is None:
            # Unknown pointee: all interfaces should at least implement IUnknown.
            print(" WrapIUnknown::_wrap(__FUNCTION__, (IUnknown **) &%s);" % (instance,))
        else:
            self.visitInterfacePointer(interface, instance)

    def visitInterface(self, interface, instance):
        # Interfaces are only handled through pointers.
        raise NotImplementedError

    def visitInterfacePointer(self, interface, instance):
        print(" Wrap%s::_wrap(__FUNCTION__, &%s);" % (interface.name, instance))

    def visitPolymorphic(self, type, instance):
        # XXX: There might be polymorphic values that need wrapping in the future
        raise NotImplementedError
class ValueUnwrapper(ValueWrapper):
    '''Reverse of ValueWrapper.

    Prints C++ code that replaces wrapper interface pointers reachable from
    an argument with the real objects (``Wrap<Iface>::_unwrap``) before the
    call is forwarded.  Since the argument may point to const data, the
    outermost struct or array is first copied into ``alloca()`` storage and
    the argument is redirected to that writable copy.
    '''

    # Becomes True once a writable alloca() copy has been emitted; anything
    # nested inside it is then unwrapped in place without copying again.
    allocated = False

    def visitStruct(self, struct, instance):
        if not self.allocated:
            # Validate before printing anything, so a failure does not leave
            # half-emitted C++ behind.  The instance must be a dereference
            # ("*ptr") so the pointer itself can be redirected to the copy.
            assert instance.startswith('*')
            # Argument is constant; make a non-const copy to unwrap into.
            print(' {')
            print(" %s * _t = static_cast<%s *>(alloca(sizeof *_t));" % (struct, struct))
            print(' *_t = %s;' % (instance,))
            print(' %s = _t;' % (instance[1:],))
            instance = '*_t'
            self.allocated = True
            try:
                return ValueWrapper.visitStruct(self, struct, instance)
            finally:
                print(' }')
        else:
            return ValueWrapper.visitStruct(self, struct, instance)

    def visitArray(self, array, instance):
        # NOTE(review): `instance` is a C expression string at every call site
        # visible here, so the isinstance() test looks like it can never be
        # true -- confirm the intended condition before relying on it.
        if self.allocated or isinstance(instance, stdapi.Interface):
            return ValueWrapper.visitArray(self, array, instance)
        array_length = self.expand(array.length)
        elem_type = array.type.mutable()
        # Copy the (possibly const) array element by element into alloca()
        # storage, unwrap each copied element, then point the argument at
        # the copy.
        print(" if (%s && %s) {" % (instance, array_length))
        print(" %s * _t = static_cast<%s *>(alloca(%s * sizeof *_t));" % (elem_type, elem_type, array_length))
        print(" for (size_t _i = 0, _s = %s; _i < _s; ++_i) {" % array_length)
        print(" _t[_i] = %s[_i];" % instance)
        self.allocated = True
        # Use visitElement (as ValueWrapper.visitArray does) so index-relative
        # expressions inside the element type expand against `_i`.
        self.visitElement('_i', array.type, "_t[_i]")
        print(" }")
        print(" %s = _t;" % instance)
        print(" }")

    def visitInterfacePointer(self, interface, instance):
        print(r' Wrap%s::_unwrap(__FUNCTION__, &%s);' % (interface.name, instance))
def _getInterfaceHierarchy(allIfaces, baseIface, result):
    """Append to *result* every interface in *allIfaces* derived from *baseIface*.

    Traversal is depth-first and post-order: each interface is appended after
    all of its own descendants, siblings following the order of *allIfaces*.
    """
    for candidate in allIfaces:
        if candidate.base is not baseIface:
            continue
        _getInterfaceHierarchy(allIfaces, candidate, result)
        result.append(candidate)


def getInterfaceHierarchy(allIfaces, baseIface):
    """Return all (transitive) descendants of *baseIface*, deepest first."""
    descendants = []
    _getInterfaceHierarchy(allIfaces, baseIface, descendants)
    return descendants
class Tracer:
    '''Base class to orchestrate the code generation of API tracing.

    traceApi() is the entry point: it prints to stdout the C++ source of a
    tracing layer for an API -- serialization tables, interface wrapper
    classes and one tracing wrapper per API function.  Derived classes
    customise the output by overriding the individual hooks.
    '''

    # 0-3 are reserved to memcpy, malloc, free, and realloc
    # Class-level counter, shared by all instances, for trace::FunctionSig ids.
    __id = 4

    def __init__(self):
        self.api = None

    def serializerFactory(self):
        '''Create a serializer.

        Can be overridden by derived classes to inject their own serializer.
        '''
        return ValueSerializer()

    def traceApi(self, api):
        '''Print the complete tracing layer for *api* to stdout.'''
        self.api = api

        self.header(api)

        # Includes
        for module in api.modules:
            for header in module.headers:
                print(header)
        print()

        # Generate the serializer functions
        types = api.getAllTypes()
        visitor = ComplexValueSerializer(self.serializerFactory())
        for tp in types:
            visitor.visit(tp)
        print()

        # Interfaces wrappers
        self.traceInterfaces(api)

        # Function wrappers
        self.interface = None
        self.base = None
        for function in api.getAllFunctions():
            self.traceFunctionDecl(function)
        for function in api.getAllFunctions():
            try:
                self.traceFunctionImpl(function)
            except Exception:
                # Report which function's wrapper failed, then propagate.
                sys.stderr.write("error: %s: exception\n" % function.name)
                raise
        print()

        self.footer(api)

    def header(self, api):
        '''Print the alloca() portability shim and the wrapper registry.

        g_WrappedObjects maps real objects to their wrappers; per the emitted
        comment it relies on trace::LocalWriter's mutex for protection.
        '''
        print('#ifdef _WIN32')
        print('# include <malloc.h> // alloca')
        print('# ifndef alloca')
        print('# define alloca _alloca')
        print('# endif')
        print('#else')
        print('# include <alloca.h> // alloca')
        print('#endif')
        print()
        print()
        print(r'/*')
        print(r' * g_WrappedObjects is already protected by trace::LocalWriter::mutex')
        print(r' * This lock is hold during the beginEnter/endEnter and beginLeave/endLeave sections')
        print(r' */')
        print('static std::map<void *, void *> g_WrappedObjects;')

    def footer(self, api):
        '''Hook for derived classes; prints nothing by default.'''
        pass

    def traceFunctionDecl(self, function):
        '''Print the static argument-name array and trace::FunctionSig for *function*.

        Internal functions are not traced and get no declarations.
        '''
        if function.internal:
            return
        if function.args:
            argNames = ', '.join(['"%s"' % arg.name for arg in function.args])
            print('static const char * _%s_args[%u] = {%s};' % (function.name, len(function.args), argNames))
        else:
            print('static const char ** _%s_args = NULL;' % (function.name,))
        print('static const trace::FunctionSig _%s_sig = {%u, "%s", %u, _%s_args};' % (function.name, self.getFunctionSigId(), function.sigName(), len(function.args), function.name))
        print()

    def getFunctionSigId(self):
        '''Return the next unused signature id (ids start at 4).'''
        sigId = Tracer.__id
        Tracer.__id += 1
        return sigId

    def isFunctionPublic(self, function):
        '''Whether the wrapper is exported (PUBLIC) rather than PRIVATE.'''
        return True

    def traceFunctionImpl(self, function):
        '''Print the extern "C" wrapper definition for *function*.'''
        if self.isFunctionPublic(function):
            print('extern "C" PUBLIC')
        else:
            print('extern "C" PRIVATE')
        print(function.prototype() + ' {')
        if function.type is not stdapi.Void:
            print(' %s _result;' % function.type)

        # Input arguments may carry our wrappers; hand the real objects on.
        for arg in function.args:
            if not arg.output:
                self.unwrapArg(function, arg)

        self.traceFunctionImplBody(function)

        # XXX: wrapping should go here, but before we can do that we'll need to protect g_WrappedObjects with its own mutex

        if function.type is not stdapi.Void:
            print(' return _result;')

        print('}')
        print()

    def traceFunctionImplBody(self, function):
        '''Print the record-enter / invoke / record-leave sequence.

        Output arguments are serialized and wrapped only when the call
        succeeded (see wasFunctionSuccessful).  Internal functions are just
        invoked, without tracing.
        '''
        if not function.internal:
            print(' unsigned _call = trace::localWriter.beginEnter(&_%s_sig);' % (function.name,))
            for arg in function.args:
                if not arg.output:
                    self.serializeArg(function, arg)
            print(' trace::localWriter.endEnter();')
        self.invokeFunction(function)
        if not function.internal:
            print(' trace::localWriter.beginLeave(_call);')
            print(' if (%s) {' % self.wasFunctionSuccessful(function))
            for arg in function.args:
                if arg.output:
                    self.serializeArg(function, arg)
                    self.wrapArg(function, arg)
            print(' }')
            if function.type is not stdapi.Void:
                self.serializeRet(function, "_result")
                self.wrapRet(function, "_result")
            print(' trace::localWriter.endLeave();')

    def invokeFunction(self, function):
        self.doInvokeFunction(function)

    def doInvokeFunction(self, function, prefix='_', suffix=''):
        '''Print the call of the real entry point ``prefix + name + suffix``.

        Same as invokeFunction() but called both when trace is enabled or
        disabled.  Non-void results are stored in ``_result``.
        '''
        if function.type is stdapi.Void:
            result = ''
        else:
            result = '_result = '
        dispatch = prefix + function.name + suffix
        print(' %s%s(%s);' % (result, dispatch, ', '.join([str(arg.name) for arg in function.args])))

    def wasFunctionSuccessful(self, function):
        '''Return the C condition under which output arguments are recorded.

        HRESULT-returning functions succeed when SUCCEEDED(_result); all
        others are always treated as successful.
        '''
        if function.type is not stdapi.Void and str(function.type) == 'HRESULT':
            return 'SUCCEEDED(_result)'
        return 'true'

    def serializeArg(self, function, arg):
        print(' trace::localWriter.beginArg(%u);' % (arg.index,))
        self.serializeArgValue(function, arg)
        print(' trace::localWriter.endArg();')

    def serializeArgValue(self, function, arg):
        self.serializeValue(arg.type, arg.name)

    def wrapArg(self, function, arg):
        '''Print code wrapping the objects returned through output *arg*.

        When the call also takes an input REFIID argument and *arg* is a
        pointer to an object pointer, the interface is chosen at run time
        from that IID (see wrapIid); otherwise wrapping follows the static
        type.
        '''
        assert not isinstance(arg.type, stdapi.ObjPointer)

        from specs.winapi import REFIID

        # The last input REFIID argument, if any, selects the interface.
        riid = None
        for other_arg in function.args:
            if not other_arg.output and other_arg.type is REFIID:
                riid = other_arg
        if riid is not None \
                and riid.name != 'EmulatedInterface' \
                and isinstance(arg.type, stdapi.Pointer) \
                and isinstance(arg.type.type, stdapi.ObjPointer):
            self.wrapIid(function, riid, arg)
            return

        self.wrapValue(arg.type, arg.name)

    def unwrapArg(self, function, arg):
        self.unwrapValue(arg.type, arg.name)

    def serializeRet(self, function, instance):
        print(' trace::localWriter.beginReturn();')
        self.serializeValue(function.type, instance)
        print(' trace::localWriter.endReturn();')

    def serializeValue(self, type, instance):
        serializer = self.serializerFactory()
        serializer.visit(type, instance)

    def wrapRet(self, function, instance):
        self.wrapValue(function.type, instance)

    def needsWrapping(self, type):
        visitor = WrapDecider()
        visitor.visit(type)
        return visitor.needsWrapping

    def wrapValue(self, type, instance):
        if self.needsWrapping(type):
            visitor = ValueWrapper()
            visitor.visit(type, instance)

    def unwrapValue(self, type, instance):
        if self.needsWrapping(type):
            visitor = ValueUnwrapper()
            visitor.visit(type, instance)

    def traceInterfaces(self, api):
        '''Print the COM helper functions and all interface wrapper classes.

        Prints nothing when the API declares no interfaces.
        '''
        interfaces = api.getAllInterfaces()
        if not interfaces:
            return

        print(r'#include "guids.hpp"')
        print()

        # Helper functions to wrap/unwrap interface pointers
        print(r'static inline bool')
        print(r'hasChildInterface(REFIID riid, IUnknown *pUnknown) {')
        print(r' IUnknown *pObj = NULL;')
        print(r' HRESULT hr = pUnknown->QueryInterface(riid, (VOID **)&pObj);')
        print(r' if (FAILED(hr)) {')
        print(r' return false;')
        print(r' }')
        print(r' assert(pObj);')
        print(r' pObj->Release();')
        print(r' return pUnknown == pObj;')
        print(r'}')
        print()
        print(r'static inline const void *')
        print(r'getVtbl(const void *pvObj) {')
        print(r' return pvObj ? *(const void **)pvObj : NULL;')
        print(r'}')
        print()
        print(r'static void')
        print(r'warnVtbl(const void *pVtbl) {')
        print(r' HMODULE hModule = 0;')
        print(r' BOOL bRet = GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |')
        print(r' GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,')
        print(r' (LPCTSTR)pVtbl,')
        print(r' &hModule);')
        print(r' if (bRet) {')
        print(r' char szModule[MAX_PATH];')
        print(r' DWORD dwRet = GetModuleFileNameA(hModule, szModule, sizeof szModule);')
        print(r' assert(dwRet);')
        print(r' if (dwRet) {')
        print(r' DWORD dwOffset = (UINT_PTR)pVtbl - (UINT_PTR)hModule;')
        print(r' os::log("apitrace: warning: pVtbl = %p (%s!+0x%0lx)\n", pVtbl, szModule, dwOffset);')
        print(r' } else {')
        print(r' os::log("apitrace: warning: pVtbl = %p\n", pVtbl);')
        print(r' }')
        print(r' }')
        print(r'}')
        print()

        for iface in interfaces:
            self.declareWrapperInterface(iface)

        self.implementIidWrapper(api)

        for iface in interfaces:
            self.implementWrapperInterface(iface)
            print()

    def declareWrapperInterface(self, interface):
        '''Print the C++ declaration of the wrapper class for *interface*.

        The wrapper derives from the interface, overrides every method, and
        adds ``_dummy<N>`` virtual methods numbered up to 63; each of them
        logs, flushes the trace and aborts -- presumably to catch calls
        through vtable slots beyond the known methods.
        '''
        wrapperInterfaceName = getWrapperInterfaceName(interface)
        print("class %s : public %s " % (wrapperInterfaceName, interface.name))
        print("{")
        print("private:")
        print(" %s(%s * pInstance);" % (wrapperInterfaceName, interface.name))
        print(" ~%s(); // Not implemented" % wrapperInterfaceName)
        print("public:")
        print(" static %s* _create(const char *entryName, %s * pInstance);" % (wrapperInterfaceName, interface.name))
        print(" static void _wrap(const char *entryName, %s ** ppInstance);" % (interface.name,))
        print(" static void _unwrap(const char *entryName, %s ** pInstance);" % (interface.name,))
        print()

        methods = list(interface.iterMethods())
        for method in methods:
            print(" " + method.prototype() + " override;")
        print()

        for type, name, value in self.enumWrapperInterfaceVariables(interface):
            print(' %s %s;' % (type, name))
        print()

        print(r'private:')
        print(r' void _dummy(unsigned i) const {')
        print(r' os::log("error: %%s: unexpected virtual method %%i of instance pvObj=%%p pWrapper=%%p pVtbl=%%p\n", "%s", i, m_pInstance, this, m_pVtbl);' % interface.name)
        print(r' warnVtbl(m_pVtbl);')
        print(r' warnVtbl(getVtbl(m_pInstance));')
        print(r' trace::localWriter.flush();')
        print(r' os::abort();')
        print(r' }')
        print()
        for i in range(len(methods), 64):
            print(r' virtual void _dummy%i(void) const { _dummy(%i); }' % (i, i))
        print()

        print("};")
        print()

    def enumWrapperInterfaceVariables(self, interface):
        '''Return ``(C type, member name, initializer)`` for each wrapper member.

        m_dwMagic (0xd8365d6c) is what _unwrap checks to recognise wrappers;
        m_pVtbl and m_NumMethods let _wrap detect stale cached wrappers.
        '''
        return [
            ("DWORD", "m_dwMagic", "0xd8365d6c"),
            ("%s *" % interface.name, "m_pInstance", "pInstance"),
            ("const void *", "m_pVtbl", "getVtbl(pInstance)"),
            ("UINT", "m_NumMethods", len(list(interface.iterBaseMethods()))),
        ]

    def implementWrapperInterface(self, iface):
        '''Print the definitions of the wrapper class for *iface*.

        This covers the constructor, ``_create`` (which registers the wrapper
        in g_WrappedObjects), every method wrapper, ``_wrap`` and ``_unwrap``.
        ``_wrap`` reuses a registered wrapper while its vtable and method
        count still match, otherwise creates one for the first derived
        interface (in getInterfaceHierarchy order) that the object supports.
        '''
        self.interface = iface

        wrapperInterfaceName = getWrapperInterfaceName(iface)

        # Private constructor
        print('%s::%s(%s * pInstance) {' % (wrapperInterfaceName, wrapperInterfaceName, iface.name))
        for type, name, value in self.enumWrapperInterfaceVariables(iface):
            if value is not None:
                print(' %s = %s;' % (name, value))
        print('}')
        print()

        # Public constructor
        print('%s *%s::_create(const char *entryName, %s * pInstance) {' % (wrapperInterfaceName, wrapperInterfaceName, iface.name))
        print(r' Wrap%s *pWrapper = new Wrap%s(pInstance);' % (iface.name, iface.name))
        if debug:
            print(r' os::log("%%s: created %s pvObj=%%p pWrapper=%%p pVtbl=%%p\n", entryName, pInstance, pWrapper, pWrapper->m_pVtbl);' % iface.name)
        print(r' g_WrappedObjects[pInstance] = pWrapper;')
        print(r' return pWrapper;')
        print('}')
        print()

        baseMethods = list(iface.iterBaseMethods())
        for base, method in baseMethods:
            self.base = base
            self.implementWrapperInterfaceMethod(iface, base, method)

        print()

        # Wrap pointer
        ifaces = self.api.getAllInterfaces()
        print(r'void')
        print(r'%s::_wrap(const char *entryName, %s **ppObj) {' % (wrapperInterfaceName, iface.name))
        print(r' if (!ppObj) {')
        print(r' return;')
        print(r' }')
        print(r' %s *pObj = *ppObj;' % (iface.name,))
        print(r' if (!pObj) {')
        print(r' return;')
        print(r' }')
        print(r' assert(hasChildInterface(IID_%s, pObj));' % iface.name)
        print(r' std::map<void *, void *>::const_iterator it = g_WrappedObjects.find(pObj);')
        print(r' if (it != g_WrappedObjects.end()) {')
        print(r' Wrap%s *pWrapper = (Wrap%s *)it->second;' % (iface.name, iface.name))
        print(r' assert(pWrapper);')
        print(r' assert(pWrapper->m_dwMagic == 0xd8365d6c);')
        print(r' assert(pWrapper->m_pInstance == pObj);')
        print(r' if (pWrapper->m_pVtbl == getVtbl(pObj) &&')
        print(r' pWrapper->m_NumMethods >= %s) {' % len(baseMethods))
        if debug:
            print(r' os::log("%s: fetched pvObj=%p pWrapper=%p pVtbl=%p\n", entryName, pObj, pWrapper, pWrapper->m_pVtbl);')
        print(r' assert(hasChildInterface(IID_%s, pWrapper->m_pInstance));' % iface.name)
        print(r' *ppObj = pWrapper;')
        print(r' return;')
        print(r' } else {')
        if debug:
            print(r' os::log("%s::Release: deleted pvObj=%%p pWrapper=%%p pVtbl=%%p\n", pWrapper->m_pInstance, pWrapper, pWrapper->m_pVtbl);' % iface.name)
        print(r' g_WrappedObjects.erase(pObj);')
        print(r' }')
        print(r' }')
        for childIface in getInterfaceHierarchy(ifaces, iface):
            print(r' if (hasChildInterface(IID_%s, pObj)) {' % (childIface.name,))
            print(r' *ppObj = Wrap%s::_create(entryName, static_cast<%s *>(pObj));' % (childIface.name, childIface.name))
            print(r' return;')
            print(r' }')
        print(r' *ppObj = Wrap%s::_create(entryName, pObj);' % iface.name)
        print(r'}')
        print()

        # Unwrap pointer
        print(r'void')
        print(r'%s::_unwrap(const char *entryName, %s **ppObj) {' % (wrapperInterfaceName, iface.name))
        print(r' if (!ppObj || !*ppObj) {')
        print(r' return;')
        print(r' }')
        print(r' const %s *pWrapper = static_cast<const %s*>(*ppObj);' % (wrapperInterfaceName, wrapperInterfaceName))
        print(r' if (pWrapper && pWrapper->m_dwMagic == 0xd8365d6c) {')
        print(r' *ppObj = pWrapper->m_pInstance;')
        print(r' } else {')
        print(r' os::log("apitrace: warning: %%s: unexpected %%s pointer %%p\n", entryName, "%s", *ppObj);' % iface.name)
        print(r' trace::localWriter.flush();')
        print(r' }')
        print(r'}')
        print()

    def implementWrapperInterfaceMethod(self, interface, base, method):
        '''Print the wrapper's override of *method* (declared on *base*).'''
        wrapperInterfaceName = getWrapperInterfaceName(interface)

        print(method.prototype(wrapperInterfaceName + '::' + method.name) + ' {')

        # Optional per-call diagnostic, gated like the other debug output.
        if debug:
            print(r' os::log("%%s(%%p -> %%p)\n", "%s", this, m_pInstance);' % (wrapperInterfaceName + '::' + method.name))

        if method.type is not stdapi.Void:
            print(' %s _result;' % method.type)

        print(' %s *_this = static_cast<%s *>(m_pInstance);' % (base, base))
        for arg in method.args:
            if not arg.output:
                self.unwrapArg(method, arg)

        self.implementWrapperInterfaceMethodBody(interface, base, method)

        # XXX: wrapping should go here, but before we can do that we'll need to protect g_WrappedObjects with its own mutex

        if method.type is not stdapi.Void:
            print(' return _result;')

        print('}')
        print()

    def implementWrapperInterfaceMethodBody(self, interface, base, method):
        '''Print the tracing sequence for an interface method call.

        ``this`` is recorded as argument 0 (the real object pointer).  For
        Release, the wrapper is deliberately kept alive even when the count
        drops to zero (see the linked apitrace issue).
        '''
        assert not method.internal

        sigName = interface.name + '::' + method.sigName()
        if method.overloaded:
            # Once the method signature name goes into a trace, we'll need to
            # support it indefinitely, so log them so one can make sure nothing
            # weird gets baked in
            sys.stderr.write('note: overloaded method %s\n' % (sigName,))

        numArgs = len(method.args) + 1
        print(' static const char * _args[%u] = {%s};' % (numArgs, ', '.join(['"this"'] + ['"%s"' % arg.name for arg in method.args])))
        print(' static const trace::FunctionSig _sig = {%u, "%s", %u, _args};' % (self.getFunctionSigId(), sigName, numArgs))

        print(' unsigned _call = trace::localWriter.beginEnter(&_sig);')
        print(' trace::localWriter.beginArg(0);')
        print(' trace::localWriter.writePointer((uintptr_t)m_pInstance);')
        print(' trace::localWriter.endArg();')
        for arg in method.args:
            if not arg.output:
                self.serializeArg(method, arg)
        print(' trace::localWriter.endEnter();')

        self.invokeMethod(interface, base, method)

        print(' trace::localWriter.beginLeave(_call);')

        print(' if (%s) {' % self.wasFunctionSuccessful(method))
        for arg in method.args:
            if arg.output:
                self.serializeArg(method, arg)
                self.wrapArg(method, arg)
        print(' }')

        if method.type is not stdapi.Void:
            self.serializeRet(method, '_result')
            self.wrapRet(method, '_result')

        if method.name == 'Release':
            assert method.type is not stdapi.Void
            print(r' if (!_result) {')
            print(r' // NOTE: Must not delete the wrapper here. See')
            print(r' // https://github.com/apitrace/apitrace/issues/462')
            print(r' }')

        print(' trace::localWriter.endLeave();')

    def implementIidWrapper(self, api):
        '''Print warnIID() and wrapIID(), which wraps a void** by its IID.'''
        ifaces = api.getAllInterfaces()

        print(r'static void')
        print(r'warnIID(const char *entryName, REFIID riid, void *pvObj, const char *reason) {')
        print(r' os::log("apitrace: warning: %s: %s IID %s\n",')
        print(r' entryName, reason,')
        print(r' getGuidName(riid));')
        print(r' const void * pVtbl = getVtbl(pvObj);')
        print(r' warnVtbl(pVtbl);')
        print(r'}')
        print()

        print(r'static void')
        print(r'wrapIID(const char *entryName, REFIID riid, void * * ppvObj) {')
        print(r' if (!ppvObj || !*ppvObj) {')
        print(r' return;')
        print(r' }')
        for iface in ifaces:
            print(r' if (riid == IID_%s) {' % (iface.name,))
            print(r' Wrap%s::_wrap(entryName, (%s **) ppvObj);' % (iface.name, iface.name))
            print(r' return;')
            print(r' }')
        print(r' warnIID(entryName, riid, *ppvObj, "unsupported");')
        print(r'}')
        print()

    def wrapIid(self, function, riid, out):
        '''Print code wrapping output *out*, whose interface is given by *riid*.

        Inside an interface wrapper method (``self.interface`` set), when the
        returned object is the wrapped instance itself and *riid* names one of
        the interface's bases, ``this`` is handed back instead of re-wrapping.
        '''
        # Cast output arg to `void **` if necessary
        out_name = out.name
        obj_type = out.type.type.type
        if obj_type is not stdapi.Void:
            assert isinstance(obj_type, stdapi.Interface)
            out_name = 'reinterpret_cast<void * *>(%s)' % out_name

        print(r' if (%s && *%s) {' % (out.name, out.name))
        functionName = function.name
        else_ = ''
        if self.interface is not None:
            functionName = self.interface.name + '::' + functionName
            print(r' if (*%s == m_pInstance &&' % (out_name,))
            print(r' (%s)) {' % ' || '.join('%s == IID_%s' % (riid.name, iface.name) for iface in self.interface.iterBases()))
            print(r' *%s = this;' % (out_name,))
            print(r' }')
            else_ = 'else '
        print(r' %s{' % else_)
        print(r' wrapIID("%s", %s, %s);' % (functionName, riid.name, out_name))
        print(r' }')
        print(r' }')

    def invokeMethod(self, interface, base, method):
        '''Print the forwarded call on the real object ``_this``.'''
        if method.type is stdapi.Void:
            result = ''
        else:
            result = '_result = '
        print(' %s_this->%s(%s);' % (result, method.name, ', '.join([str(arg.name) for arg in method.args])))

    def emit_memcpy(self, ptr, size):
        print(' trace::fakeMemcpy(%s, %s);' % (ptr, size))

    def fake_call(self, function, args):
        '''Print a synthetic (fake) trace record of *function* called with *args*.

        *args* are C expressions paired positionally with the (input-only)
        arguments of *function*; the real function is not invoked.
        '''
        print(' {')
        print(' unsigned _fake_call = trace::localWriter.beginEnter(&_%s_sig, true);' % (function.name,))
        for arg, instance in zip(function.args, args):
            assert not arg.output
            print(' trace::localWriter.beginArg(%u);' % (arg.index,))
            self.serializeValue(arg.type, instance)
            print(' trace::localWriter.endArg();')
        print(' trace::localWriter.endEnter();')
        print(' trace::localWriter.beginLeave(_fake_call);')
        print(' trace::localWriter.endLeave();')
        print(' }')