# blob: 356b81a93a8d52aa7041e6bd8abc7ae9e0eb8f69 [file] [log] [blame]
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Codegen for calling C++ methods from Java."""
from codegen import convert_type
from codegen import header_common
import common
def _return_type_cpp(java_type):
  """Returns the C++ return type of the function implementing a native.

  A converted type is used verbatim, a primitive uses its to_cpp() spelling,
  and any other type is wrapped in jni_zero::ScopedJavaLocalRef<>.
  """
  converted = java_type.converted_type
  if converted:
    return converted
  if not java_type.is_primitive():
    return f'jni_zero::ScopedJavaLocalRef<{java_type.to_cpp()}>'
  return java_type.to_cpp()
def _param_type_cpp(java_type, use_param_type):
  """Returns the C++ type used to declare a parameter of |java_type|.

  Args:
    java_type: The type being declared. Only converted_type, is_primitive()
      and to_cpp() are consulted.
    use_param_type: When true, object types are spelled with JavaParamRef<>
      instead of JavaRef<>. Ignored for converted and primitive types.
  """
  converted = java_type.converted_type
  if converted:
    # Converted non-primitives are taken by reference, unless the type is
    # obviously a pointer: that would yield e.g. "const char *&".
    by_reference = not (java_type.is_primitive() or converted.endswith('*'))
    return f'{converted}&' if by_reference else converted

  cpp_type = java_type.to_cpp()
  if java_type.is_primitive():
    return cpp_type
  if use_param_type:
    return f'const jni_zero::JavaParamRef<{cpp_type}>&'
  return f'const jni_zero::JavaRef<{cpp_type}>&'
def _impl_forward_declaration(sb, native, params):
  """Writes the forward declaration of the function implementing |native|.

  The declared function is `static`, named JNI_<JavaClass>_<Method>, and takes
  `JNIEnv* env`, then `jcaller` for non-static natives, then one parameter per
  entry of |params| (object types spelled as JavaParamRef<>).
  """
  sb('// Forward declaration. To be implemented by the including .cc file.\n')
  return_type = _return_type_cpp(native.return_type)
  function_name = f'JNI_{native.java_class.name}_{native.capitalized_name}'
  with sb.statement():
    sb(f'static {return_type} {function_name}')
    with sb.param_list() as plist:
      plist.append('JNIEnv* env')
      if not native.static:
        plist.append('const jni_zero::JavaParamRef<jobject>& jcaller')
      for param in params:
        cpp_type = _param_type_cpp(param.java_type, True)
        plist.append(f'{cpp_type} {param.cpp_name()}')
def _entry_point_example(sb, native):
  """Writes an example signature for the function implementing |native|.

  When the native is bound to a C++ class (native.first_param_cpp_type), the
  example is a `<Class>::<Method>` member and the first Java parameter is
  omitted. Otherwise it is a `static JNI_<JavaClass>_<Method>` function.
  Object parameters are spelled with JavaRef<>.
  """
  cpp_class = native.first_param_cpp_type
  if cpp_class:
    prefix = f'{cpp_class}::'
    params = native.params[1:]
  else:
    prefix = f'JNI_{native.java_class.name}_'
    params = native.params
  qualified_name = prefix + native.capitalized_name

  with sb.statement():
    if not cpp_class:
      sb('static ')
    sb(f'{_return_type_cpp(native.return_type)} {qualified_name}')
    with sb.param_list() as plist:
      plist.append('JNIEnv* env')
      if not native.static:
        plist.append('const jni_zero::JavaRef<jobject>& jcaller')
      for param in params:
        cpp_type = _param_type_cpp(param.java_type, False)
        plist.append(f'{cpp_type} {param.cpp_name()}')
def _prep_param(sb, is_proxy, param):
  """Returns the C++ expression to pass for |param| to the implementation.

  For converted types, a statement declaring a `<name>_converted` local is
  written to |sb| first, and the name of that local is returned. Primitives are
  passed through by name; all other types are wrapped in a JavaParamRef<>.
  """
  java_type = param.java_type
  arg_expr = param.cpp_name()

  converted_type = java_type.converted_type
  if converted_type:
    local_name = f'{param.name}_converted'
    with sb.statement():
      sb(f'{converted_type} {local_name} = ')
      convert_type.from_jni_expression(sb, arg_expr, java_type)
    return local_name

  if java_type.is_primitive():
    return arg_expr

  cpp_type = java_type.to_cpp()
  if is_proxy and cpp_type != java_type.to_proxy().to_cpp():
    # The proxy form of the type differs (e.g. jobject -> jstring), so cast
    # the incoming value to the real type.
    arg_expr = f'static_cast<{cpp_type}>({arg_expr})'
  return f'jni_zero::JavaParamRef<{cpp_type}>(env, {arg_expr})'
def entry_point_declaration(sb, jni_mode, jni_obj, native, gen_jni_class):
  """Writes the signature of the function called by JNI or by muxing methods.

  Only the export macro, return type, name and parameter list are written; no
  body or trailing semicolon is added by this function itself.

  Args:
    sb: String builder that receives the generated C++.
    jni_mode: Provides is_muxing and is passed to the name / parameter lookups.
    jni_obj: Unused here.
    native: The native method being declared.
    gen_jni_class: Forwarded to native.boundary_name_cpp().
  """
  if jni_mode.is_muxing and native.is_proxy:
    # In this case, it's not the symbol that JNI resolves, but the one the
    # switch table jumps to.
    export_macro = 'JNI_ZERO_MUXED_ENTRYPOINT'
    symbol = native.muxed_entry_point_name
  else:
    export_macro = 'JNI_ZERO_BOUNDARY_EXPORT'
    symbol = native.boundary_name_cpp(jni_mode, gen_jni_class=gen_jni_class)

  return_type_cpp = native.entry_point_return_type.to_cpp()
  params = native.entry_point_params(jni_mode)
  sb(f'{export_macro} {return_type_cpp} {symbol}')
  with sb.param_list() as plist:
    plist.append('JNIEnv* env')
    if not jni_mode.is_muxing:
      # When muxing, the caller param is left out: it is never used, and muxed
      # entry points are not boundary (JNI) methods.
      caller_type = 'jclass' if native.static else 'jobject'
      plist.append(f'{caller_type} jcaller')
    for param in params:
      plist.append(f'{param.java_type.to_cpp()} {param.cpp_name()}')
def entry_point_method(sb,
                       jni_mode,
                       jni_obj,
                       native,
                       gen_jni_class,
                       include_forward_declaration=False):
  """Writes the definition of the function called by JNI or by muxing methods.

  The generated body prepares each parameter, calls the implementation (a
  JNI_<JavaClass>_<Method> function, or a method on the C++ class named by
  native.first_param_cpp_type), and returns the result in the form the entry
  point's return type requires.

  Args:
    sb: String builder that receives the generated C++.
    jni_mode: Forwarded to entry_point_declaration().
    jni_obj: Forwarded to entry_point_declaration().
    native: The native method to generate an entry point for.
    gen_jni_class: Forwarded to entry_point_declaration().
    include_forward_declaration: Whether to precede the definition with a
      forward declaration of the implementation function.
  """
  cpp_class = native.first_param_cpp_type
  # For class-bound natives the first parameter is the native pointer; it is
  # consumed by the reinterpret_cast below instead of being forwarded.
  impl_params = native.params[1:] if cpp_class else native.params

  # Only non-class methods need to be forward-declared.
  if include_forward_declaration and not cpp_class:
    _impl_forward_declaration(sb, native, impl_params)
    sb('\n')

  entry_point_declaration(sb, jni_mode, jni_obj, native, gen_jni_class)
  boundary_return_type = native.entry_point_return_type
  return_type = native.return_type
  with sb.block(after='\n'):
    # _prep_param() may write conversion statements of its own, so every
    # argument is prepared before the call statement is opened.
    call_args = ['env']
    if not native.static:
      call_args.append('jni_zero::JavaParamRef<jobject>(env, jcaller)')
    call_args += [_prep_param(sb, native.is_proxy, p) for p in impl_params]

    returns_value = not return_type.is_void()
    with sb.statement():
      if returns_value:
        sb('auto _ret = ')
      if cpp_class:
        receiver = native.params[0].cpp_name()
        sb(f'reinterpret_cast<{cpp_class}*>({receiver})'
           f'->{native.capitalized_name}')
      else:
        sb(f'JNI_{native.java_class.name}_{native.capitalized_name}')
      with sb.param_list() as plist:
        plist.extend(call_args)

    if not returns_value:
      return

    if not return_type.converted_type:
      if return_type.is_primitive():
        sb('return _ret;\n')
      else:
        # Use ReleaseLocal() to ensure we are not calling .Release() on a
        # global ref. https://crbug.com/40944912
        sb('return _ret.ReleaseLocal();\n')
      return

    # Converted return types: convert back to a Java object first, then cast
    # to the entry point's declared type if that is not plain jobject.
    clazz_snippet = None
    if native.needs_implicit_array_element_class_param:
      clazz_snippet = f'static_cast<jclass>({native.proxy_params[-1].name})'
    with sb.statement():
      sb('jobject converted_ret = ')
      convert_type.to_jni_expression(sb,
                                     '_ret',
                                     return_type,
                                     clazz_snippet=clazz_snippet)
      sb('.Release()')

    boundary_cpp_type = boundary_return_type.to_cpp()
    with sb.statement():
      sb('return ')
      if boundary_cpp_type == 'jobject':
        sb('converted_ret')
      else:
        sb(f'static_cast<{boundary_cpp_type}>(converted_ret)')
def natives_macro_definition(sb, jni_mode, jni_obj, gen_jni_class, *,
                             enable_definition_macros):
  """Writes the DEFINE_JNI_FOR_<JavaClass> macro for |jni_obj|.

  With enable_definition_macros set and at least one native, this writes a
  commented section of example signatures followed by a macro whose body holds
  the entry point method for every native. Otherwise an empty macro is
  defined, preceded by a comment stating why it is empty.

  Args:
    sb: String builder that receives the generated C++.
    jni_mode: Forwarded to entry_point_method().
    jni_obj: Provides java_class, natives and jni_namespace.
    gen_jni_class: Forwarded to entry_point_method().
    enable_definition_macros: Whether entry points are emitted inside the
      macro; when false only an empty macro is defined.
  """
  macro_name = f'DEFINE_JNI_FOR_{jni_obj.java_class.name}'

  if not enable_definition_macros:
    sb('// Macro for transition to enable_definition_macros=true.\n')
    sb(f'#define {macro_name}()\n')
    return

  if not jni_obj.natives:
    sb('// There are no Java->Native methods.\n')
    sb(f'#define {macro_name}()\n')
    return

  with sb.section(
      'Example signatures (to be implemented by #including file).'):
    with sb.commented_section():
      for native in jni_obj.natives:
        _entry_point_example(sb, native)
        # NOTE(review): nesting of this newline was reconstructed from a
        # flattened copy of the file; confirm it belongs inside the loop.
        sb('\n')

  with sb.section('Java to native functions'):
    with sb.cpp_macro(macro_name):
      # Anonymous namespace to scope the "using namespace" declaration.
      # All symbols use extern "C", so it doesn't actually hide symbols.
      with sb.namespace(''):
        if jni_obj.jni_namespace:
          sb(f'using namespace {jni_obj.jni_namespace};\n')
        for native in jni_obj.natives:
          entry_point_method(sb, jni_mode, jni_obj, native, gen_jni_class)
def multiplexing_boundary_method(sb, muxed_aliases, gen_jni_class):
  """Writes the function called by JNI when multiplexing is enabled.

  The boundary name and signature are read from the first entry of
  |muxed_aliases|. If that entry's muxed_switch_num is -1, the generated
  function forwards straight to its muxed entry point. Otherwise it takes an
  extra `jint switch_num` parameter and dispatches through a switch statement
  with one case per alias.
  """
  first = muxed_aliases[0]
  signature = first.muxed_signature
  uses_switch = first.muxed_switch_num != -1
  symbol = first.boundary_name_cpp(common.JniMode.MUXING,
                                   gen_jni_class=gen_jni_class)
  sb(f'JNI_ZERO_BOUNDARY_EXPORT {signature.return_type.to_cpp()} {symbol}')

  # Arguments forwarded to the muxed entry points: env plus every signature
  # parameter, but neither jcaller nor switch_num.
  forwarded = ['env']
  with sb.param_list() as plist:
    plist.extend(['JNIEnv* env', 'jclass jcaller'])
    if uses_switch:
      plist.append('jint switch_num')
    for index, param in enumerate(signature.param_list):
      arg_name = f'p{index}'
      plist.append(f'{param.java_type.to_cpp()} {arg_name}')
      forwarded.append(arg_name)
  call_args = ', '.join(forwarded)

  with sb.block():
    if not uses_switch:
      sb(f'return {first.muxed_entry_point_name}({call_args});\n')
      return

    alias_count = len(muxed_aliases)
    sb(f'JNI_ZERO_DCHECK(switch_num >= 0 && switch_num < {alias_count});\n')
    sb('switch (switch_num)')
    with sb.block():
      for alias in muxed_aliases:
        sb(f'case {alias.muxed_switch_num}:\n')
        sb(f' return {alias.muxed_entry_point_name}({call_args});\n')
      sb('default:\n')
      sb(' __builtin_unreachable();\n')