blob: 131c862299d2dcc28595d157d3a94d4bfc544e9b [file] [log] [blame]
#!/usr/bin/env python3
# Copyright 2020 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unittests for the compiler module."""
import os
import parser # pylint: disable=wrong-import-order
import random
import shutil
import tempfile
import unittest
from importlib import resources
import arch
import bpf
import compiler
# Raw JSON description of the 64-bit test architecture, read from the
# "arch_64.json" resource of the "testdata" package at import time.
_ARCH_64_JSON_BYTES = (
    resources.files("testdata").joinpath("arch_64.json").read_bytes()
)
# Architecture shared by every test case in this module.
ARCH_64 = arch.Arch.load_from_json_bytes(_ARCH_64_JSON_BYTES)
class CompileFilterStatementTests(unittest.TestCase):
"""Tests for PolicyCompiler.compile_filter_statement."""
def setUp(self):
self.arch = ARCH_64
self.compiler = compiler.PolicyCompiler(self.arch)
def _compile(self, line):
with tempfile.NamedTemporaryFile(mode="w") as policy_file:
policy_file.write(line)
policy_file.flush()
policy_parser = parser.PolicyParser(
self.arch, kill_action=bpf.KillProcess()
)
parsed_policy = policy_parser.parse_file(policy_file.name)
assert len(parsed_policy.filter_statements) == 1
return self.compiler.compile_filter_statement(
parsed_policy.filter_statements[0],
kill_action=bpf.KillProcess(),
)
def test_allow(self):
"""Accept lines where the syscall is accepted unconditionally."""
block = self._compile("read: allow")
self.assertEqual(block.filter, None)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 1)[1],
"ALLOW",
)
def test_arg0_eq_generated_code(self):
"""Accept lines with an argument filter with ==."""
block = self._compile("read: arg0 == 0x100")
# It might be a bit brittle to check the generated code in each test
# case instead of just the behavior, but there should be at least one
# test where this happens.
self.assertEqual(
block.filter.instructions,
[
bpf.SockFilter(
bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS,
0,
0,
bpf.arg_offset(0, True),
),
# Jump to KILL_PROCESS if the high word does not match.
bpf.SockFilter(bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 0, 2, 0),
bpf.SockFilter(
bpf.BPF_LD | bpf.BPF_W | bpf.BPF_ABS,
0,
0,
bpf.arg_offset(0, False),
),
# Jump to KILL_PROCESS if the low word does not match.
bpf.SockFilter(
bpf.BPF_JMP | bpf.BPF_JEQ | bpf.BPF_K, 1, 0, 0x100
),
bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_KILL_PROCESS),
bpf.SockFilter(bpf.BPF_RET, 0, 0, bpf.SECCOMP_RET_ALLOW),
],
)
def test_arg0_comparison_operators(self):
"""Accept lines with an argument filter with comparison operators."""
biases = (-1, 0, 1)
# For each operator, store the expectations of simulating the program
# against the constant plus each entry from the |biases| array.
cases = (
("==", ("KILL_PROCESS", "ALLOW", "KILL_PROCESS")),
("!=", ("ALLOW", "KILL_PROCESS", "ALLOW")),
("<", ("ALLOW", "KILL_PROCESS", "KILL_PROCESS")),
("<=", ("ALLOW", "ALLOW", "KILL_PROCESS")),
(">", ("KILL_PROCESS", "KILL_PROCESS", "ALLOW")),
(">=", ("KILL_PROCESS", "ALLOW", "ALLOW")),
)
for operator, expectations in cases:
block = self._compile("read: arg0 %s 0x100" % operator)
# Check the filter's behavior.
for bias, expectation in zip(biases, expectations):
self.assertEqual(
block.simulate(
self.arch.arch_nr,
self.arch.syscalls["read"],
0x100 + bias,
)[1],
expectation,
)
def test_arg0_mask_operator(self):
"""Accept lines with an argument filter with &."""
block = self._compile("read: arg0 & 0x3")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 1)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 2)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 3)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 4)[1],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 5)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 6)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 7)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 8)[1],
"KILL_PROCESS",
)
def test_arg0_in_operator(self):
"""Accept lines with an argument filter with in."""
block = self._compile("read: arg0 in 0x3")
# The 'in' operator only ensures that no bits outside the mask are set,
# which means that 0 is always allowed.
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 1)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 2)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 3)[1],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 4)[1],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 5)[1],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 6)[1],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 7)[1],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 8)[1],
"KILL_PROCESS",
)
def test_arg0_short_gt_ge_comparisons(self):
"""Ensure that the short comparison optimization kicks in."""
if self.arch.bits == 32:
return
short_constant_str = "0xdeadbeef"
short_constant = int(short_constant_str, base=0)
long_constant_str = "0xbadc0ffee0ddf00d"
long_constant = int(long_constant_str, base=0)
biases = (-1, 0, 1)
# For each operator, store the expectations of simulating the program
# against the constant plus each entry from the |biases| array.
cases = (
("<", ("ALLOW", "KILL_PROCESS", "KILL_PROCESS")),
("<=", ("ALLOW", "ALLOW", "KILL_PROCESS")),
(">", ("KILL_PROCESS", "KILL_PROCESS", "ALLOW")),
(">=", ("KILL_PROCESS", "ALLOW", "ALLOW")),
)
for operator, expectations in cases:
short_block = self._compile(
"read: arg0 %s %s" % (operator, short_constant_str)
)
long_block = self._compile(
"read: arg0 %s %s" % (operator, long_constant_str)
)
# Check that the emitted code is shorter when the high word of the
# constant is zero.
self.assertLess(
len(short_block.filter.instructions),
len(long_block.filter.instructions),
)
# Check the filter's behavior.
for bias, expectation in zip(biases, expectations):
self.assertEqual(
long_block.simulate(
self.arch.arch_nr,
self.arch.syscalls["read"],
long_constant + bias,
)[1],
expectation,
)
self.assertEqual(
short_block.simulate(
self.arch.arch_nr,
self.arch.syscalls["read"],
short_constant + bias,
)[1],
expectation,
)
def test_and_or(self):
"""Accept lines with a complex expression in DNF."""
block = self._compile("read: arg0 == 0 && arg1 == 0 || arg0 == 1")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0, 0)[
1
],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0, 1)[
1
],
"KILL_PROCESS",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 1, 0)[
1
],
"ALLOW",
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 1, 1)[
1
],
"ALLOW",
)
def test_trap(self):
"""Accept lines that trap unconditionally."""
block = self._compile("read: trap")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"TRAP",
)
def test_ret_errno(self):
"""Accept lines that return errno."""
block = self._compile("read : arg0 == 0 || arg0 == 1 ; return 1")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[
1:
],
("ERRNO", 1),
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 1)[
1:
],
("ERRNO", 1),
)
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 2)[1],
"KILL_PROCESS",
)
def test_ret_errno_unconditionally(self):
"""Accept lines that return errno unconditionally."""
block = self._compile("read: return 1")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[
1:
],
("ERRNO", 1),
)
def test_trace(self):
"""Accept lines that trace unconditionally."""
block = self._compile("read: trace")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"TRACE",
)
def test_user_notify(self):
"""Accept lines that notify unconditionally."""
block = self._compile("read: user-notify")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"USER_NOTIF",
)
def test_log(self):
"""Accept lines that log unconditionally."""
block = self._compile("read: log")
self.assertEqual(
block.simulate(self.arch.arch_nr, self.arch.syscalls["read"], 0)[1],
"LOG",
)
def test_mmap_write_xor_exec(self):
"""Accept the idiomatic filter for mmap."""
block = self._compile(
"read : arg0 in ~PROT_WRITE || arg0 in ~PROT_EXEC"
)
prot_exec_and_write = 6
for prot in range(0, 0xF):
if (prot & prot_exec_and_write) == prot_exec_and_write:
self.assertEqual(
block.simulate(
self.arch.arch_nr, self.arch.syscalls["read"], prot
)[1],
"KILL_PROCESS",
)
else:
self.assertEqual(
block.simulate(
self.arch.arch_nr, self.arch.syscalls["read"], prot
)[1],
"ALLOW",
)
class CompileFileTests(unittest.TestCase):
"""Tests for PolicyCompiler.compile_file."""
def setUp(self):
self.arch = ARCH_64
self.compiler = compiler.PolicyCompiler(self.arch)
self.tempdir = tempfile.mkdtemp()
def tearDown(self):
shutil.rmtree(self.tempdir)
def _write_file(self, filename, contents):
"""Helper to write out a file for testing."""
path = os.path.join(self.tempdir, filename)
with open(path, "w", encoding="utf-8") as outf:
outf.write(contents)
return path
def test_compile(self):
"""Ensure compilation works with all strategies."""
self._write_file(
"test.frequency",
"""
read: 1
close: 10
""",
)
path = self._write_file(
"test.policy",
"""
@frequency ./test.frequency
read: 1
close: 1
""",
)
program = self.compiler.compile_file(
path,
optimization_strategy=compiler.OptimizationStrategy.LINEAR,
kill_action=bpf.KillProcess(),
)
self.assertGreater(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["read"],
0,
)[0],
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["close"],
0,
)[0],
)
def test_compile_bst(self):
"""Ensure compilation with BST is cheaper than the linear model."""
self._write_file(
"test.frequency",
"""
read: 1
close: 10
""",
)
path = self._write_file(
"test.policy",
"""
@frequency ./test.frequency
read: 1
close: 1
""",
)
for strategy in list(compiler.OptimizationStrategy):
program = self.compiler.compile_file(
path,
optimization_strategy=strategy,
kill_action=bpf.KillProcess(),
)
self.assertGreater(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["read"],
0,
)[0],
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["close"],
0,
)[0],
)
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["read"],
0,
)[1],
"ALLOW",
)
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["close"],
0,
)[1],
"ALLOW",
)
def test_compile_empty_file(self):
"""Accept empty files."""
path = self._write_file(
"test.policy",
"""
@default kill-thread
""",
)
for strategy in list(compiler.OptimizationStrategy):
program = self.compiler.compile_file(
path,
optimization_strategy=strategy,
kill_action=bpf.KillProcess(),
)
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls["read"],
0,
)[1],
"KILL_THREAD",
)
def test_compile_simulate(self):
"""Ensure policy reflects script by testing some random scripts."""
iterations = 5
for i in range(iterations):
num_entries = 64 * (i + 1) // iterations
syscalls = dict(
zip(
random.sample(list(self.arch.syscalls.keys()), num_entries),
(random.randint(1, 1024) for _ in range(num_entries)),
)
)
frequency_contents = "\n".join(
"%s: %d" % s for s in syscalls.items()
)
policy_contents = "@frequency ./test.frequency\n" + "\n".join(
"%s: 1" % s[0] for s in syscalls.items()
)
self._write_file("test.frequency", frequency_contents)
path = self._write_file("test.policy", policy_contents)
for strategy in list(compiler.OptimizationStrategy):
program = self.compiler.compile_file(
path,
optimization_strategy=strategy,
kill_action=bpf.KillProcess(),
)
for name, number in self.arch.syscalls.items():
expected_result = (
"ALLOW" if name in syscalls else "KILL_PROCESS"
)
self.assertEqual(
bpf.simulate(
program.instructions, self.arch.arch_nr, number, 0
)[1],
expected_result,
(
"syscall name: %s, syscall number: %d, "
"strategy: %s, policy:\n%s"
)
% (name, number, strategy, policy_contents),
)
@unittest.skipIf(not int(os.getenv("SLOW_TESTS", "0")), "slow")
def test_compile_huge_policy(self):
"""Ensure jumps while compiling a huge policy are still valid."""
# Given that the BST strategy is O(n^3), don't choose a crazy large
# value, but it still needs to be around 128 so that we exercise the
# codegen paths that depend on the length of the jump.
#
# Immediate jump offsets in BPF comparison instructions are limited to
# 256 instructions, so given that every syscall filter consists of a
# load and jump instructions, with 128 syscalls there will be at least
# one jump that's further than 256 instructions.
num_entries = 128
syscalls = dict(random.sample(self.arch.syscalls.items(), num_entries))
# Here we force every single filter to be distinct. Otherwise the
# codegen layer will coalesce filters that compile to the same
# instructions.
policy_contents = "\n".join(
"%s: arg0 == %d" % s for s in syscalls.items()
)
path = self._write_file("test.policy", policy_contents)
program = self.compiler.compile_file(
path,
optimization_strategy=compiler.OptimizationStrategy.BST,
kill_action=bpf.KillProcess(),
)
for name, number in self.arch.syscalls.items():
expected_result = "ALLOW" if name in syscalls else "KILL_PROCESS"
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls[name],
number,
)[1],
expected_result,
)
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls[name],
number + 1,
)[1],
"KILL_PROCESS",
)
def test_compile_huge_filter(self):
"""Ensure jumps while compiling a huge policy are still valid."""
# This is intended to force cases where the AST visitation would result
# in a combinatorial explosion of calls to Block.accept(). An optimized
# implementation should be O(n).
num_entries = 128
syscalls = {}
# Here we force every single filter to be distinct. Otherwise the
# codegen layer will coalesce filters that compile to the same
# instructions.
policy_contents = []
for name in random.sample(list(self.arch.syscalls.keys()), num_entries):
values = random.sample(range(1024), num_entries)
syscalls[name] = values
policy_contents.append(
"%s: %s"
% (name, " || ".join("arg0 == %d" % value for value in values))
)
path = self._write_file("test.policy", "\n".join(policy_contents))
program = self.compiler.compile_file(
path,
optimization_strategy=compiler.OptimizationStrategy.LINEAR,
kill_action=bpf.KillProcess(),
)
for name, values in syscalls.items():
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls[name],
random.choice(values),
)[1],
"ALLOW",
)
self.assertEqual(
bpf.simulate(
program.instructions,
self.arch.arch_nr,
self.arch.syscalls[name],
1025,
)[1],
"KILL_PROCESS",
)
if __name__ == "__main__":
    # Discover and run every TestCase defined in this script.
    unittest.main(module="__main__")