/*
 * Copyright 2012, Google Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *    * Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *    * Redistributions in binary form must reproduce the above
 *      copyright notice, this list of conditions and the following
 *      disclaimer in the documentation and/or other materials provided
 *      with the distribution.
 *    * Neither the name of Google Inc. nor the names of its
 *      contributors may be used to endorse or promote products derived
 *      from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#include <limits.h>
#include <setjmp.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

// cmocka.h relies on standard headers it does not include itself (e.g.
// <setjmp.h>), so those must be included before it. Keep clang-format from
// sorting the includes so cmocka.h stays after the ones above.
// clang-format off
#include <cmocka.h>
// clang-format on

#include "mosys/file_backed_range.h"
#include "mosys/platform.h"

#include "mosys/globals.h"

#include "intf/io.h"

// Storage for the file_backed_range names (see fn. comment for
// build_test_intf).
static char file_backed_range_path_buf[PATH_MAX];

// Build a struct platform_intf for use in tests. The first file_backed_range
// can be configured with params
// - <file_backed_range_size>: The size of the range (it is assumed the range
// starts at 0).
// - <file_backed_range_file_name>: The name of the file in the
// "unittests/testdata/io_unittest" directory. E.g. if the arg is "/dev/port",
// the "unittests/testdata/io_unittest/dev/port" file will be used.
static struct platform_intf
build_test_intf(uint64_t file_backed_range_size,
		const char *file_backed_range_file_name)
{
	const char *test_ids[] = {
	    "TEST",
	    NULL,
	};
	struct platform_op test_ops = {
	    .io = &io_intf,
	};
	struct platform_cb test_cbs;

	struct platform_intf intf = {
	    .type = PLATFORM_DEFAULT,
	    .name = "UNITTEST",
	    .id_list = test_ids,
	    .op = &test_ops,
	    .cb = &test_cbs,
	};

	assert_int_equal(0, intf.op->io->setup(&intf));

	char *src = getenv("SRC");
	assert_non_null(src);

	// Note that the leading / needs to be removed for the file_backed_range
	// file_name; thus, cwd[1] is used.
	snprintf(file_backed_range_path_buf, PATH_MAX,
		 "%s/unittests/testdata/io_unittest%s", &src[1],
		 file_backed_range_file_name);

	struct file_backed_range *first_range = &intf.op->io->ranges[0];
	first_range->range.end = file_backed_range_size;
	first_range->file_name = file_backed_range_path_buf;

	return intf;
}

/*
 * Accesses must stay inside the configured range: with a range of size 0x10,
 * offset 0x01 can be read and written, while offset 0x20 is refused (-1) in
 * both directions.
 */
static void bad_address(void **state)
{
	uint8_t value;
	struct platform_intf intf =
	    build_test_intf(/*file_backed_range_size=*/0x10,
			    /*file_backed_range_file_name=*/"/dev/port");
	int ret;

	(void)state;

	/* Inside the range. The write stores back the byte just read. */
	ret = intf.op->io->read(&intf, 0x01, IO_ACCESS_8, &value);
	assert_int_equal(0, ret);
	ret = intf.op->io->write(&intf, 0x01, IO_ACCESS_8, &value);
	assert_int_equal(0, ret);

	/* Outside the range. */
	ret = intf.op->io->read(&intf, 0x20, IO_ACCESS_8, &value);
	assert_int_equal(-1, ret);
	ret = intf.op->io->write(&intf, 0x20, IO_ACCESS_8, &value);
	assert_int_equal(-1, ret);
}

/* A read from a range backed by a file that does not exist returns -1. */
static void open_nonexistent_file(void **state)
{
	uint8_t byte;
	struct platform_intf intf = build_test_intf(
	    /*file_backed_range_size=*/0x10000,
	    /*file_backed_range_file_name=*/"/dev/nonexistent");
	int ret;

	(void)state;

	ret = intf.op->io->read(&intf, 0, IO_ACCESS_8, &byte);
	assert_int_equal(-1, ret);
}

/*
 * Reads past the end of the backing file must fail, even when the offset is
 * still inside the configured range (here the range is much larger than the
 * file).
 */
static void read_eof(void **state)
{
	struct platform_intf intf =
	    build_test_intf(/*file_backed_range_size=*/0x10000,
			    /*file_backed_range_file_name=*/"/dev/port");
	uint8_t data;

	(void)state;

	// The test file is 4 bytes long (offsets 0-3). Reading the last byte
	// works; reading the first byte past the end fails, as does reading
	// any offset further out.
	assert_int_equal(0, intf.op->io->read(&intf, 0x03, IO_ACCESS_8, &data));

	assert_int_equal(-1,
			 intf.op->io->read(&intf, 0x04, IO_ACCESS_8, &data));

	assert_int_equal(-1,
			 intf.op->io->read(&intf, 0xffff, IO_ACCESS_8, &data));
}

/*
 * Read offset 0 of the test file at each access width, through the io op
 * table and through the inline wrappers. Every destination is preset to
 * all-ones first, so a read that reports success without storing anything
 * is caught.
 *
 * The expected values match a file starting with bytes 00 01 02 03.
 * NOTE(review): the multi-byte expectations assume little-endian assembly
 * (e.g. a little-endian host) — confirm.
 */
static void io_read_test(void **state)
{
	struct platform_intf intf =
	    build_test_intf(/*file_backed_range_size=*/0x10000,
			    /*file_backed_range_file_name=*/"/dev/port");
	uint8_t byte;
	uint16_t word;
	uint32_t dword;

	(void)state;

	/* 8 bit: op table, generic inline, width-specific inline. */
	byte = 0xff;
	assert_int_equal(0, intf.op->io->read(&intf, 0, IO_ACCESS_8, &byte));
	assert_int_equal(0x00, (int)byte);

	byte = 0xff;
	assert_int_equal(0, io_read(&intf, 0, IO_ACCESS_8, &byte));
	assert_int_equal(0x00, (int)byte);

	byte = 0xff;
	assert_int_equal(0, io_read8(&intf, 0, &byte));
	assert_int_equal(0x00, (int)byte);

	/* 16 bit: op table, then width-specific inline. */
	word = 0xffff;
	assert_int_equal(0, intf.op->io->read(&intf, 0, IO_ACCESS_16, &word));
	assert_int_equal(0x0100, word);

	word = 0xffff;
	assert_int_equal(0, io_read16(&intf, 0, &word));
	assert_int_equal(0x0100, word);

	/* 32 bit: op table, then width-specific inline. */
	dword = 0xffffffff;
	assert_int_equal(0, intf.op->io->read(&intf, 0, IO_ACCESS_32, &dword));
	assert_int_equal(0x03020100, dword);

	dword = 0xffffffff;
	assert_int_equal(0, io_read32(&intf, 0, &dword));
	assert_int_equal(0x03020100, dword);
}

/*
 * Write an all-ones pattern to offset 0 at each access width, through the io
 * op table and through the inline wrappers, and read it back through the op
 * table.
 *
 * The writes land in the same "/dev/port" testdata file that the other tests
 * read. Each write/read-back pair therefore restores the original value
 * BEFORE any result is asserted. A failing check aborts the test, and
 * asserting first would leave the file modified for later tests and later
 * runs.
 */
static void io_write_test(void **state)
{
	struct platform_intf intf =
	    build_test_intf(/*file_backed_range_size=*/0x10000,
			    /*file_backed_range_file_name=*/"/dev/port");

	int write_ret, read_ret, restore_ret;
	uint8_t data8, orig8;
	uint16_t data16, orig16;
	uint32_t data32, orig32;

	(void)state;

	/* 8 bit write */
	assert_int_equal(0, intf.op->io->read(&intf, 0, IO_ACCESS_8, &orig8));

	data8 = 0xff;
	write_ret = intf.op->io->write(&intf, 0, IO_ACCESS_8, &data8);
	data8 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_8, &data8);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_8, &orig8);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xff, (int)data8);
	assert_int_equal(0, restore_ret);

	/* now through inlines */
	data8 = 0xff;
	write_ret = io_write(&intf, 0, IO_ACCESS_8, &data8);
	data8 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_8, &data8);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_8, &orig8);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xff, (int)data8);
	assert_int_equal(0, restore_ret);

	data8 = 0xff;
	write_ret = io_write8(&intf, 0, data8);
	data8 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_8, &data8);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_8, &orig8);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xff, (int)data8);
	assert_int_equal(0, restore_ret);

	/* 16 bit write */
	assert_int_equal(0,
			 intf.op->io->read(&intf, 0, IO_ACCESS_16, &orig16));

	data16 = 0xffff;
	write_ret = intf.op->io->write(&intf, 0, IO_ACCESS_16, &data16);
	data16 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_16, &data16);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_16, &orig16);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xffff, data16);
	assert_int_equal(0, restore_ret);

	/* now through inlines */
	data16 = 0xffff;
	write_ret = io_write16(&intf, 0, data16);
	data16 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_16, &data16);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_16, &orig16);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xffff, data16);
	assert_int_equal(0, restore_ret);

	/* 32 bit write */
	assert_int_equal(0,
			 intf.op->io->read(&intf, 0, IO_ACCESS_32, &orig32));

	data32 = 0xffffffff;
	write_ret = intf.op->io->write(&intf, 0, IO_ACCESS_32, &data32);
	data32 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_32, &data32);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_32, &orig32);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xffffffff, data32);
	assert_int_equal(0, restore_ret);

	/* now through inlines */
	data32 = 0xffffffff;
	write_ret = io_write32(&intf, 0, data32);
	data32 = 0;
	read_ret = intf.op->io->read(&intf, 0, IO_ACCESS_32, &data32);
	restore_ret = intf.op->io->write(&intf, 0, IO_ACCESS_32, &orig32);
	assert_int_equal(0, write_ret);
	assert_int_equal(0, read_ret);
	assert_int_equal(0xffffffff, data32);
	assert_int_equal(0, restore_ret);
}

/*
 * Run every test in this file, in the order listed, as one cmocka group with
 * no group-level setup or teardown. Returns cmocka's result for the group.
 *
 * The array must stay named "tests": cmocka_run_group_tests() uses the
 * variable's name as the group name it prints.
 */
int main(void)
{
	const struct CMUnitTest tests[] = {
		cmocka_unit_test(bad_address),
		cmocka_unit_test(open_nonexistent_file),
		cmocka_unit_test(read_eof),
		cmocka_unit_test(io_read_test),
		cmocka_unit_test(io_write_test),
	};
	int result;

	result = cmocka_run_group_tests(tests, NULL, NULL);
	return result;
}
