blob: 52c296e58461245983858ba783d46bea062a0460 [file] [log] [blame]
#! /usr/bin/env python
# Copyright 2016 WebAssembly Community Group participants
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
This fuzzes the relooper using the C API.
'''
from __future__ import print_function

import difflib
import os
import random
import subprocess
# Let the dynamic loader find the libraries under ./lib, which the generated
# `fuzz` program is linked against (-Llib/.). 'lib' is appended to a
# non-empty LD_LIBRARY_PATH; an unset or empty one is replaced by 'lib'.
# filter(None, ...) drops a missing/empty existing value before joining.
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(
    filter(None, [os.environ.get('LD_LIBRARY_PATH'), 'lib']))
def random_cfg(num, density):
    """Choose a random control-flow graph over ``num`` relooper blocks.

    Args:
      num: number of blocks; must be at least 2 so that targets 1..num-1
        exist.
      density: scales how many outgoing edges each block gets (the caller
        passes random() * random(), i.e. a value in [0, 1)).

    Returns:
      (branches, defaults), two lists of length ``num``. ``branches[i]`` holds
      the conditional (deduplicated) targets of block i, and ``defaults[i]``
      is its fall-back target. All targets lie in 1..num-1, so the entry
      block 0 is never branched to, and ``defaults[i]`` never appears in
      ``branches[i]``.
    """
    branches = []
    defaults = []
    for _ in range(num):
        # int(): Python 2's round() returns a float.
        max_picks = max(1, int(round(density * random.random() * (num - 1))))
        num_picks = random.randint(1, max_picks)
        targets = set()
        for _pick in range(num_picks):
            targets.add(random.randint(1, num - 1))
        targets = list(targets)
        default = random.choice(targets)
        targets.remove(default)
        branches.append(targets)
        defaults.append(default)
    return branches, defaults


def build_slow_js(decisions, branches, defaults):
    """Return a JavaScript program that interprets the CFG directly.

    The program walks the blocks with a ``while(1) switch(label)`` loop. Each
    block logs ``(i32.const <block index>)`` and then calls ``check()``,
    which logs the negated next decision and returns it. Block i jumps to
    ``branches[i][j]`` when ``decision % (len(branches[i]) + 1) == j`` and to
    ``defaults[i]`` otherwise. ``check()`` throws 'HALT' once every decision
    has been consumed; that is how the program ends.
    """
    entry = '''
var label = 0;
var state;
var decisions = %s;
var index = 0;
function check() {
  if (index == decisions.length) throw 'HALT';
  console.log('(i32.const ' + (-decisions[index]) + ')');
  return decisions[index++];
}
''' % str(decisions)
    parts = [entry, '\n', 'label = 0;\n', '\nwhile(1) switch(label) {\n']
    for i, targets in enumerate(branches):
        parts.append(
            '  case %d: console.log("(i32.const %d)"); state = check();\n'
            % (i, i))
        for j, target in enumerate(targets):
            # TODO: split range 1-n into these options
            parts.append('    if (state %% %d == %d) { label = %d; break }\n'
                         % (len(targets) + 1, j, target))
        parts.append('    label = %d; break\n' % defaults[i])
    parts.append('}')
    return ''.join(parts)


def build_fast_c(decisions, branches, defaults, use_switch, optimize,
                 max_decision):
    """Return C source that builds the same CFG with the Binaryen relooper.

    The generated program stores ``decisions`` in linear memory (the running
    index at address 4, the decisions from address 8), defines a ``check``
    function mirroring the JS one, renders the CFG with the relooper into a
    start function, and prints the module as text.

    Args:
      decisions: the decision values, shared with the JS program.
      branches, defaults: the CFG, as returned by ``random_cfg``.
      use_switch: per-block flags. True renders the block with
        RelooperAddBlockWithSwitch; False uses conditional branches.
      optimize: whether the generated program calls BinaryenModuleOptimize.
      max_decision: largest decision value. It bounds the extra case values
        added to switch branches.
    """
    num = len(branches)
    parts = []

    # Prelude: memory layout, the check() helper, and relooper creation.
    parts.append('''
#include <assert.h>
#include <stdio.h>
#include "binaryen-c.h"

// globals: address 4 is index
// decisions are at address 8+

int main() {
  BinaryenModuleRef module = BinaryenModuleCreate();

  // check()

  // if the end, halt
  BinaryenExpressionRef halter = BinaryenIf(module,
    BinaryenBinary(module,
      BinaryenEqInt32(),
      BinaryenLoad(module, 4, 0, 0, 0, BinaryenInt32(),
                   BinaryenConst(module, BinaryenLiteralInt32(4))),
      BinaryenConst(module, BinaryenLiteralInt32(4 * %d)) // jumps of 4 bytes
    ),
    BinaryenUnreachable(module),
    NULL
  );
  // increment index
  BinaryenExpressionRef incer = BinaryenStore(module,
    4, 0, 0,
    BinaryenConst(module, BinaryenLiteralInt32(4)),
    BinaryenBinary(module,
      BinaryenAddInt32(),
      BinaryenLoad(module, 4, 0, 0, 0, BinaryenInt32(),
                   BinaryenConst(module, BinaryenLiteralInt32(4))),
      BinaryenConst(module, BinaryenLiteralInt32(4))
    ),
    BinaryenInt32()
  );

  // optionally, print the return value
  BinaryenExpressionRef args[] = {
    BinaryenBinary(module,
      BinaryenSubInt32(),
      BinaryenConst(module, BinaryenLiteralInt32(0)),
      BinaryenLoad(module,
        4, 0, 4, 0, BinaryenInt32(),
        BinaryenLoad(module, 4, 0, 0, 0, BinaryenInt32(),
                     BinaryenConst(module, BinaryenLiteralInt32(4)))
      )
    )
  };
  BinaryenExpressionRef debugger;
  if (1) debugger = BinaryenCallImport(module, "print", args, 1,
                                       BinaryenNone());
  else debugger = BinaryenNop(module);

  // return the decision. need to subtract 4 that we just added,
  // and add 8 since that's where we start, so overall offset 4
  BinaryenExpressionRef returner = BinaryenLoad(module,
    4, 0, 4, 0, BinaryenInt32(),
    BinaryenLoad(module, 4, 0, 0, 0, BinaryenInt32(),
                 BinaryenConst(module, BinaryenLiteralInt32(4)))
  );
  BinaryenExpressionRef checkBodyList[] = { halter, incer, debugger,
                                            returner };
  BinaryenExpressionRef checkBody = BinaryenBlock(module,
    NULL, checkBodyList, sizeof(checkBodyList) / sizeof(BinaryenExpressionRef)
  );
  BinaryenFunctionTypeRef i = BinaryenAddFunctionType(module, "i",
                                                      BinaryenInt32(),
                                                      NULL, 0);
  BinaryenAddFunction(module, "check", i, NULL, 0, checkBody);

  // contents of main() begin here

  RelooperRef relooper = RelooperCreate();
''' % len(decisions))

    # One relooper block per CFG node: print the block index, then store
    # check()'s result in local 0 (the "state" local).
    for i in range(num):
        parts.append('''
  RelooperBlockRef b%d;
  {
    BinaryenExpressionRef args[] = {
      BinaryenConst(module, BinaryenLiteralInt32(%d))
    };
    BinaryenExpressionRef list[] = {
      BinaryenCallImport(module, "print", args, 1, BinaryenNone()),
      BinaryenSetLocal(module, 0, BinaryenCall(module, "check", NULL, 0,
                                               BinaryenInt32()))
    };
''' % (i, i))
        if use_switch[i]:
            parts.append('''
    b%d = RelooperAddBlockWithSwitch(relooper,
      BinaryenBlock(module, NULL, list, 2),
      BinaryenBinary(module,
        BinaryenRemUInt32(),
        BinaryenGetLocal(module, 0, BinaryenInt32()),
        BinaryenConst(module, BinaryenLiteralInt32(%d))
      )
    );
''' % (i, len(branches[i]) + 1))
        else:
            parts.append('''
    b%d = RelooperAddBlock(relooper, BinaryenBlock(module, NULL, list, 2));
''' % i)
        parts.append('''
  }
''')

    # Edges: conditional branches first, then the default branch.
    for i in range(num):
        targets = branches[i]
        total = len(targets) + 1
        for j, target in enumerate(targets):
            if use_switch[i]:
                # The switch value is state %u total, so only j itself is
                # reachable. The larger values congruent to j are presumably
                # there to exercise bigger switch tables. The range always
                # covers 0..total-1, so the list is never empty.
                upper = random.randint(total, max_decision + 2)
                values = ','.join(str(x) for x in range(upper)
                                  if x % total == j)
                parts.append('''
  {
    BinaryenIndex values[] = { %s };
    RelooperAddBranchForSwitch(b%d, b%d, values,
                               sizeof(values) / sizeof(BinaryenIndex), NULL);
  }
''' % (values, i, target))
            else:
                parts.append('''
  RelooperAddBranch(b%d, b%d, BinaryenBinary(module,
    BinaryenEqInt32(),
    BinaryenBinary(module,
      BinaryenRemUInt32(),
      BinaryenGetLocal(module, 0, BinaryenInt32()),
      BinaryenConst(module, BinaryenLiteralInt32(%d))
    ),
    BinaryenConst(module, BinaryenLiteralInt32(%d))
  ), NULL);
''' % (i, target, total, j))
        # default branch
        if use_switch[i]:
            parts.append('''
  RelooperAddBranchForSwitch(b%d, b%d, NULL, 0, NULL);
''' % (i, defaults[i]))
        else:
            parts.append('''
  RelooperAddBranch(b%d, b%d, NULL, NULL);
''' % (i, defaults[i]))

    # Epilogue: store the decisions, build main(), imports, memory, print.
    parts.append('''
  BinaryenExpressionRef body = RelooperRenderAndDispose(relooper, b0, 1,
                                                        module);

  int decisions[] = { %s };
  int numDecisions = sizeof(decisions)/sizeof(int);

  // write out all the decisions, then the body of the function
  BinaryenExpressionRef full[numDecisions + 1];

  {
    int i;
    for (i = 0; i < numDecisions; i++) {
      full[i] = BinaryenStore(module,
        4, 0, 0,
        BinaryenConst(module, BinaryenLiteralInt32(8 + 4 * i)),
        BinaryenConst(module, BinaryenLiteralInt32(decisions[i])),
        BinaryenInt32()
      );
    }
  }
  full[numDecisions] = body;
  BinaryenExpressionRef all = BinaryenBlock(module, NULL, full,
                                            numDecisions + 1);

  BinaryenFunctionTypeRef v = BinaryenAddFunctionType(module, "v",
                                                      BinaryenNone(),
                                                      NULL, 0);
  // locals: state, free-for-label
  BinaryenType localTypes[] = { BinaryenInt32(), BinaryenInt32() };
  BinaryenFunctionRef theMain = BinaryenAddFunction(module, "main", v,
                                                    localTypes, 2, all);
  BinaryenSetStart(module, theMain);

  // import

  BinaryenType iparams[] = { BinaryenInt32() };
  BinaryenFunctionTypeRef vi = BinaryenAddFunctionType(module, "vi",
                                                       BinaryenNone(),
                                                       iparams, 1);
  BinaryenAddImport(module, "print", "spectest", "print", vi);

  // memory
  BinaryenSetMemory(module, 1, 1, "mem", NULL, NULL, NULL, 0);

  // optionally, optimize
  if (%d) BinaryenModuleOptimize(module);

  assert(BinaryenModuleValidate(module));

  // write it out

  BinaryenModulePrint(module);

  BinaryenModuleDispose(module);

  return 0;
}
''' % (', '.join(map(str, decisions)), int(optimize)))
    return ''.join(parts)


def remove_stale_outputs():
    """Delete files that a previous iteration may have left behind."""
    for path in ('fuzz.wasm', 'fuzz.wast', 'fast.txt', 'fuzz.slow.js',
                 'fuzz.c', 'fuzz'):
        try:
            os.unlink(path)
        except OSError:
            # Normally just "no such file"; anything else surfaces later
            # when the file is rewritten.
            pass


def capture_stdout(cmd):
    """Run ``cmd`` and return its stdout as text. stderr is discarded.

    The exit status is not checked. Both programs stop by throwing or
    trapping once the decisions run out (JS ``throw 'HALT'``, wasm
    ``unreachable``), so a non-zero status is presumably expected.
    """
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE, universal_newlines=True)
    return proc.communicate()[0]


def check_outputs_match(slow_out, fast_out):
    """Print a unified diff and raise AssertionError if the outputs differ."""
    if slow_out == fast_out:
        return
    diff = difflib.unified_diff(slow_out.split('\n'), fast_out.split('\n'),
                                fromfile='slow', tofile='fast')
    print(''.join(line.rstrip() + '\n' for line in diff))
    # Raise explicitly: a bare `assert False` is stripped under `python -O`.
    raise AssertionError('relooper fuzz: slow (JS) and fast (wasm) '
                         'outputs differ')


def run_once(counter):
    """Run one fuzz iteration: generate, build, execute, and compare.

    The generated fuzz.c, fuzz.slow.js and fuzz.wast remain on disk, so a
    failing case can be reproduced.

    Raises:
      subprocess.CalledProcessError: compiling fuzz.c or running ./fuzz
        failed.
      OSError: a required tool could not be started.
      AssertionError: the JS and wasm outputs differ.
    """
    # Random decisions
    num = random.randint(2, 250)
    density = random.random() * random.random()
    max_decision = num * 20
    decisions = [random.randint(1, max_decision) for _ in range(num * 3)]
    branches, defaults = random_cfg(num, density)
    optimize = random.random() < 0.5
    print(counter, ':', num, density, optimize)

    remove_stale_outputs()
    slow = build_slow_js(decisions, branches, defaults)
    use_switch = [random.random() < 0.5 for _ in range(num)]
    fast = build_fast_c(decisions, branches, defaults, use_switch, optimize,
                        max_decision)

    with open('fuzz.slow.js', 'w') as slow_file:
        slow_file.write(slow)
    with open('fuzz.c', 'w') as fast_file:
        fast_file.write(fast)
    print('.')

    cmd = [os.environ.get('CC') or 'gcc', 'fuzz.c', '-Isrc',
           '-lbinaryen', '-lasmjs',
           '-lsupport', '-Llib/.', '-pthread', '-o', 'fuzz']
    subprocess.check_call(cmd)
    print('^')

    with open('fuzz.wast', 'w') as wast_file:
        subprocess.check_call(['./fuzz'], stdout=wast_file)
    print('*')

    fast_out = capture_stdout(['bin/wasm-shell', 'fuzz.wast'])
    print('-')
    # NODE overrides the JS engine binary, like CC does for the compiler.
    slow_out = capture_stdout([os.environ.get('NODE') or 'nodejs',
                               'fuzz.slow.js'])
    print('_')

    check_outputs_match(slow_out, fast_out)


def main():
    """Fuzz forever, stopping at the first failure."""
    counter = 0
    while True:
        run_once(counter)
        counter += 1


if __name__ == '__main__':
    main()