#!/usr/bin/env python3
##
## This file is part of the libsigrokdecode project.
##
## Copyright (C) 2013 Bert Vermeulen <bert@biot.com>
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.
##

import os
import sys
from getopt import getopt
from tempfile import mkstemp
from subprocess import Popen, PIPE
from difflib import Differ
from hashlib import md5
from shutil import copy

DEBUG = 0
VERBOSE = False


class E_syntax(Exception):
    """Raised while parsing a test.conf line whose arguments are malformed."""
class E_badline(Exception):
    """Raised while parsing a test.conf line that is not a known directive."""

def INFO(msg, end='\n'):
    """Print a progress message on stdout, but only in verbose mode.

    The output is flushed immediately so that partial lines (end='')
    show up before the rest of the line is printed.
    """
    if not VERBOSE:
        return
    print(msg, end=end, flush=True)


def DBG(msg):
    """Print a debug message on stdout when the DEBUG level is non-zero."""
    if not DEBUG:
        return
    print(msg)


def ERR(msg):
    """Write an error message, followed by a newline, to stderr."""
    sys.stderr.write("%s\n" % (msg,))


def usage(msg=None):
    if msg:
        print(msg.strip() + '\n')
    print("""Usage: testpd [-dvarslR] [test, ...]
  -d  Turn on debugging
  -v  Verbose
  -a  All tests
  -l  List all tests
  -s  Show test(s)
  -r  Run test(s)
  -f  Fix failed test(s)
  -c  Report decoder code coverage
  -R <directory>  Save test reports to <directory>
  <test>  Protocol decoder name ("i2c") and optionally test name ("i2c/icc")""")
    sys.exit()


def check_tclist(tc):
    """Validate a single parsed testcase dict.

    Returns a short description of the first problem found, or None when
    the testcase has at least one protocol decoder, an input, at least
    one output, and every output names a match file.
    """
    required = (
        ('pdlist', "No protocol decoders"),
        ('input', "No input"),
        ('output', "No output"),
    )
    for field, complaint in required:
        # A missing key and an empty value are both rejected.
        if not tc.get(field):
            return complaint
    if any('match' not in op for op in tc['output']):
        return "No match in output"

    return None


def parse_testfile(path, pd, tc, op_type, op_class):
    """Parse a test.conf file into a list of testcase dicts.

    path:     location of the test.conf file.
    pd:       protocol decoder name stored in each testcase as 'pd'.
    tc:       if not None, keep only the testcase with this name.
    op_type:  if not None (and tc is given), keep only outputs of this type.
    op_class: if not None (and op_type is given), keep only outputs whose
              'class' equals this value.

    Recognized directives are 'test', 'protocol-decoder', 'stack', 'input'
    and 'output'; blank lines and lines starting with '#' are skipped.

    Returns the list of testcase dicts (keys 'pd', 'name', 'pdlist',
    'output', plus 'input' and 'stack' when given). On any parse or
    validation problem an error is reported via ERR() and an empty list
    is returned. Errors from opening or reading the file propagate.
    """
    DBG("Opening '%s'" % path)
    tclist = []
    # Read everything up front so the file handle is closed before parsing.
    with open(path) as conf:
        lines = conf.read().split('\n')
    for line in lines:
        try:
            line = line.strip()
            if not line or line[0] == "#":
                continue
            f = line.split()
            if not tclist and f[0] != "test":
                # Every directive has to belong to a preceding 'test' line.
                raise E_badline
            key = f.pop(0)
            if key == 'test':
                if len(f) != 1:
                    raise E_syntax
                # new testcase
                tclist.append({
                    'pd': pd,
                    'name': f[0],
                    'pdlist': [],
                    'output': [],
                })
            elif key == 'protocol-decoder':
                if len(f) < 1:
                    raise E_syntax
                pd_spec = {
                    'name': f.pop(0),
                    'channels': [],
                    'options': [],
                }
                while f:
                    if len(f) == 1:
                        # Always needs <key> <value>
                        raise E_syntax
                    a, b = f[:2]
                    f = f[2:]
                    if '=' not in b:
                        raise E_syntax
                    # Split on the first '=' only, so a value may itself
                    # contain '=' without aborting the parser.
                    opt, val = b.split('=', 1)
                    if a == 'channel':
                        try:
                            val = int(val)
                        except ValueError:
                            raise E_syntax
                        pd_spec['channels'].append([opt, val])
                    elif a == 'option':
                        pd_spec['options'].append([opt, val])
                    else:
                        raise E_syntax
                tclist[-1]['pdlist'].append(pd_spec)
            elif key == 'stack':
                if len(f) < 2:
                    raise E_syntax
                tclist[-1]['stack'] = f
            elif key == 'input':
                if len(f) != 1:
                    raise E_syntax
                tclist[-1]['input'] = f[0]
            elif key == 'output':
                if len(f) < 2:
                    # Needs at least <pd> <type>
                    raise E_syntax
                op_spec = {
                    'pd': f.pop(0),
                    'type': f.pop(0),
                }
                while f:
                    if len(f) == 1:
                        # Always needs <key> <value>
                        raise E_syntax
                    a, b = f[:2]
                    f = f[2:]
                    if a == 'class':
                        op_spec['class'] = b
                    elif a == 'match':
                        op_spec['match'] = b
                    else:
                        raise E_syntax
                tclist[-1]['output'].append(op_spec)
            else:
                raise E_badline
        except E_badline:
            # Unknown directive, or a directive before any 'test' line.
            ERR("Unable to parse %s: unknown line '%s'" % (path, line))
            return []
        except E_syntax:
            # Known directive with malformed arguments.
            ERR("Invalid syntax in %s: line '%s'" % (path, line))
            return []

    # If a specific testcase was requested, keep only that one.
    if tc is not None:
        target_tc = None
        for t in tclist:
            if t['name'] == tc:
                target_tc = t
                break
        if target_tc is None:
            # Requested testcase does not exist in this file.
            return []
        # ...and a specific output type
        if op_type is not None:
            target_oplist = []
            for op in target_tc['output']:
                if op['type'] != op_type:
                    continue
                # ...and a specific output class
                if op_class is None or op.get('class') == op_class:
                    target_oplist.append(op)
                    DBG("match on [%s]" % str(op))
            target_tc['output'] = target_oplist
        tclist = [target_tc]
    for t in tclist:
        error = check_tclist(t)
        if error:
            ERR("Error in %s: %s" % (path, error))
            return []

    return tclist


def get_tests(testnames):
    """Collect the testcases selected by a list of test specifications.

    testnames: iterable of specs of the form pd[/testcase[/type[/class]]].

    Returns a dict mapping each protocol decoder name to a list with one
    entry per spec that named it; each entry is the list of testcases
    returned by parse_testfile(). A decoder without a test.conf maps to
    an empty list.

    Raises Exception if a named decoder directory does not exist.
    """
    tests = {}
    for testspec in testnames:
        # Optional testspec in the form pd/testcase/type/class
        tc = op_type = op_class = None
        ts = testspec.strip("/").split("/")
        pd = ts.pop(0)
        # Do not reset the entry: the same PD may be named by several
        # specs (e.g. "i2c/a i2c/b") and all of them must be kept.
        tests.setdefault(pd, [])
        if ts:
            tc = ts.pop(0)
        if ts:
            op_type = ts.pop(0)
        if ts:
            op_class = ts.pop(0)
        path = os.path.join(decoders_dir, pd)
        if not os.path.isdir(path):
            # User specified non-existent PD
            raise Exception("%s not found." % path)
        path = os.path.join(decoders_dir, pd, "test", "test.conf")
        if not os.path.exists(path):
            # PD doesn't have any tests yet
            continue
        tests[pd].append(parse_testfile(path, pd, tc, op_type, op_class))

    return tests


def diff_text(f1, f2):
    """Return the differing lines between two text files.

    f1, f2: paths of the files to compare.

    Returns a list of stripped difflib.Differ lines that start with
    '- ' (only in f1) or '+ ' (only in f2); the list is empty when the
    files have identical content.
    """
    # Context managers make sure both handles are closed, even if the
    # second open() fails.
    with open(f1) as fh1, open(f2) as fh2:
        t1 = fh1.readlines()
        t2 = fh2.readlines()

    return [
        line.strip()
        for line in Differ().compare(t1, t2)
        if line[:2] in ('- ', '+ ')
    ]


def compare_binary(f1, f2):
    """Compare two files byte for byte.

    f1, f2: paths of the files to compare.

    Returns None when the contents are identical, otherwise a
    one-element list holding a mismatch message.
    """
    # Both files are read completely anyway, so compare the bytes
    # directly instead of hashing them first; the handles are closed by
    # the context manager.
    with open(f1, 'rb') as fh1, open(f2, 'rb') as fh2:
        identical = fh1.read() == fh2.read()

    if identical:
        return None
    return ["Binary output does not match."]


# runtc's stdout can have lines like:
# coverage: lines=161 missed=2 coverage=99%
def parse_stats(text):
    """Parse statistics lines of the form '<key>: <k>=<v> <k>=<v> ...'.

    text: the decoded text to parse, one record per line.

    Returns a dict mapping each line's key (with ':' stripped) to a list
    of dicts, one per line with that key, holding the k=v pairs as
    strings. Blank lines are ignored.

    Raises ValueError if a field after the key contains no '='.
    """
    stats = {}
    for line in text.split('\n'):
        fields = line.split()
        if not fields:
            # Blank line (or empty text): nothing to record.
            continue
        key = fields.pop(0).strip(':')
        record = {}
        for field in fields:
            # Split on the first '=' only, so values may contain '='.
            k, v = field.split('=', 1)
            record[k] = v
        stats.setdefault(key, []).append(record)

    return stats


# take result set of all tests in a PD, and summarize which lines
# were not covered by any of the tests.
def coverage_sum(cvglist):
    """Summarize which line specs were missed by every coverage record.

    cvglist: list of coverage records (dicts) with a 'lines' count and an
             optional comma-separated 'missed_lines' string.

    Returns (lines, final_missed): 'lines' is the line count taken from
    the last record (0 for an empty list; presumably all records of one
    decoder report the same count), and 'final_missed' lists the line
    specs that appear in the 'missed_lines' of every record, i.e. those
    that no record covered.
    """
    lines = 0
    miss_count = {}
    for record in cvglist:
        lines = int(record['lines'])
        if 'missed_lines' not in record:
            continue
        # Ignore empty specs (e.g. from an empty 'missed_lines' value) and
        # count each spec at most once per record, keeping first-seen order.
        specs = [spec for spec in record['missed_lines'].split(',') if spec]
        for linespec in dict.fromkeys(specs):
            miss_count[linespec] = miss_count.get(linespec, 0) + 1

    # Keep only the specs that were missed in every single record: a spec
    # absent from any record was covered by that record's test.
    final_missed = [
        linespec
        for linespec, count in miss_count.items()
        if count == len(cvglist)
    ]

    return lines, final_missed


def run_tests(tests, fix=False):
    """Run every output check of every testcase through the runtc helper.

    tests: dict as returned by get_tests(): protocol decoder name ->
           list of testcase lists.
    fix:   when True, a mismatching (or uncomparable) match file is
           overwritten with the actual output instead of being reported.

    For each output of each testcase, runtc is executed with the output
    redirected to a temporary file, which is then compared against the
    expected match file. A report is generated per check via gen_report().

    Returns (results, errors): 'results' is a list with one dict per
    check ('testcase', plus 'error', 'diff', 'coverage_report' and parsed
    runtc statistics where applicable); 'errors' is the number of checks
    that ended with an error or an output mismatch.
    """
    errors = 0
    results = []
    cmd = [os.path.join(tests_dir, 'runtc')]
    if opt_coverage:
        fd, coverage = mkstemp()
        os.close(fd)
        cmd.extend(['-c', coverage])
    else:
        coverage = None
    for pd in sorted(tests.keys()):
        pd_cvg = []
        for tclist in tests[pd]:
            for tc in tclist:
                args = cmd.copy()
                if DEBUG > 1:
                    args.append('-d')
                # Set up PD stack for this test.
                for spd in tc['pdlist']:
                    args.extend(['-P', spd['name']])
                    for label, channel in spd['channels']:
                        args.extend(['-p', "%s=%d" % (label, channel)])
                    for option, value in spd['options']:
                        args.extend(['-o', "%s=%s" % (option, value)])
                args.extend(['-i', os.path.join(dumps_dir, tc['input'])])
                for op in tc['output']:
                    name = "%s/%s/%s" % (pd, tc['name'], op['type'])
                    opargs = ['-O', "%s:%s" % (op['pd'], op['type'])]
                    if 'class' in op:
                        opargs[-1] += ":%s" % op['class']
                        name += "/%s" % op['class']
                    if VERBOSE:
                        dots = '.' * (60 - len(name) - 2)
                        INFO("%s %s " % (name, dots), end='')
                    result = {'testcase': name}
                    results.append(result)
                    # Stays None if mkstemp() itself fails, so the cleanup
                    # below never touches a stale or undefined path.
                    outfile = None
                    try:
                        fd, outfile = mkstemp()
                        os.close(fd)
                        opargs.extend(['-f', outfile])
                        DBG("Running %s" % (' '.join(args + opargs)))
                        p = Popen(args + opargs, stdout=PIPE, stderr=PIPE)
                        stdout, stderr = p.communicate()
                        if stdout:
                            # statistics and coverage data on stdout
                            result.update(parse_stats(stdout.decode('utf-8')))
                        if stderr:
                            result['error'] = stderr.decode('utf-8').strip()
                        elif p.returncode != 0:
                            # runtc indicated an error, but didn't output a
                            # message on stderr about it
                            result['error'] = "Unknown error: runtc %d" % p.returncode
                        if 'error' not in result:
                            matchfile = os.path.join(decoders_dir, op['pd'], 'test', op['match'])
                            DBG("Comparing with %s" % matchfile)
                            diff = diff_error = None
                            try:
                                if op['type'] in ('annotation', 'python'):
                                    diff = diff_text(matchfile, outfile)
                                elif op['type'] == 'binary':
                                    diff = compare_binary(matchfile, outfile)
                                else:
                                    diff = ["Unsupported output type '%s'." % op['type']]
                            except Exception as e:
                                # Deferred: in fix mode a failed comparison
                                # (e.g. missing match file) is repaired
                                # rather than reported.
                                diff_error = e
                            if fix:
                                if diff or diff_error:
                                    copy(outfile, matchfile)
                                    DBG("Wrote %s" % matchfile)
                            elif diff:
                                result['diff'] = diff
                            elif diff_error is not None:
                                raise diff_error
                    except Exception as e:
                        result['error'] = str(e)
                    finally:
                        # Only advertise a coverage report that actually
                        # exists; runtc may have failed before writing it.
                        if coverage and os.path.exists(coverage):
                            result['coverage_report'] = coverage
                        if outfile is not None:
                            os.unlink(outfile)
                    # Every failed check counts, not only those where
                    # runtc wrote to stderr.
                    if 'error' in result or 'diff' in result:
                        errors += 1
                    if VERBOSE:
                        if 'diff' in result:
                            INFO("Output mismatch")
                        elif 'error' in result:
                            error = result['error']
                            if len(error) > 20:
                                error = error[:17] + '...'
                            INFO(error)
                        elif 'coverage' in result:
                            # report coverage of this PD
                            for record in result['coverage']:
                                # but not others used in the stack
                                # as part of the test.
                                if record.get('scope') == pd:
                                    INFO(record['coverage'])
                                    break
                            else:
                                # No record for this PD: still terminate
                                # the progress line.
                                INFO("OK")
                        else:
                            INFO("OK")
                    gen_report(result)
                    if coverage:
                        if os.path.exists(coverage):
                            os.unlink(coverage)
                        # only keep track of coverage records for this PD,
                        # not others in the stack just used for testing.
                        # A failed run may not have produced any records.
                        for cvg in result.get('coverage', []):
                            if cvg.get('scope') == pd:
                                pd_cvg.append(cvg)
        if VERBOSE and opt_coverage and len(pd_cvg) > 1:
            # report total coverage of this PD, across all the tests
            # that were done on it.
            total_lines, missed_lines = coverage_sum(pd_cvg)
            if total_lines > 0:
                pd_coverage = 100 - (float(len(missed_lines)) / total_lines * 100)
                dots = '.' * (54 - len(pd) - 2)
                INFO("%s total %s %d%%" % (pd, dots, pd_coverage))

    # The coverage temp file is still around if no check was run at all.
    if coverage and os.path.exists(coverage):
        os.unlink(coverage)

    return results, errors


def gen_report(result):
    """Write or print the report for a single test result.

    result: dict with 'testcase' and optionally 'error' (message),
            'diff' (list of lines) and 'coverage_report' (path of a text
            file whose content is included).

    Nothing is emitted when the result has none of the optional keys.
    Otherwise the report goes to a file named after the testcase (with
    '/' replaced by '_') inside report_dir when that is set, or to
    stdout.
    """
    out = []
    if 'error' in result:
        out.append("Error:")
        out.append(result['error'])
        out.append('')
    if 'diff' in result:
        out.append("Test output mismatch:")
        out.extend(result['diff'])
        out.append('')
    if 'coverage_report' in result:
        try:
            with open(result['coverage_report'], 'r') as report:
                out.append(report.read())
        except OSError as e:
            # A missing or unreadable coverage file must not abort the
            # whole test run; note it in the report instead.
            out.append("Coverage report unavailable: %s" % e)
        out.append('')

    if not out:
        return

    text = "Testcase: %s\n" % result['testcase']
    text += '\n'.join(out)

    if report_dir:
        filename = result['testcase'].replace('/', '_')
        with open(os.path.join(report_dir, filename), 'w') as report:
            report.write(text)
    else:
        print(text)


def show_tests(tests):
    """Print a human-readable description of every testcase.

    tests: dict mapping protocol decoder name to a list of testcase
           lists, as returned by get_tests().

    A blank line is printed after each testcase list.
    """
    for pd in sorted(tests.keys()):
        for tclist in tests[pd]:
            for tc in tclist:
                print("Testcase: %s/%s" % (tc['pd'], tc['name']))
                # Use a separate name for stack entries so the outer
                # loop variable is not shadowed.
                for spd in tc['pdlist']:
                    print("  Protocol decoder: %s" % spd['name'])
                    for label, channel in spd['channels']:
                        print("    Channel %s=%d" % (label, channel))
                    for option, value in spd['options']:
                        # Option values are kept as strings by the parser,
                        # so they have to be formatted with %s.
                        print("    Option %s=%s" % (option, value))
                if 'stack' in tc:
                    print("  Stack: %s" % ' '.join(tc['stack']))
                print("  Input: %s" % tc['input'])
                for op in tc['output']:
                    print("  Output:\n    Protocol decoder: %s" % op['pd'])
                    print("    Type: %s" % op['type'])
                    if 'class' in op:
                        print("    Class: %s" % op['class'])
                    print("    Match: %s" % op['match'])
            print()


def list_tests(tests):
    """Print one 'pd/testcase/type[/class]' line per testcase output.

    tests: dict mapping protocol decoder name to a list of testcase
           lists; decoders are listed in sorted order.
    """
    for pd_name in sorted(tests):
        for testcases in tests[pd_name]:
            for testcase in testcases:
                for output in testcase['output']:
                    parts = [testcase['pd'], testcase['name'], output['type']]
                    if 'class' in output:
                        parts.append(output['class'])
                    print('/'.join(str(part) for part in parts))


#
# main
#

# project root
# GetoptError is needed to report invalid command-line options cleanly.
from getopt import GetoptError

# Directory layout, derived from the location of this script.
tests_dir = os.path.abspath(os.path.dirname(sys.argv[0]))
base_dir = os.path.abspath(os.path.join(os.curdir, tests_dir, os.path.pardir))
dumps_dir = os.path.abspath(os.path.join(base_dir, os.path.pardir, 'sigrok-dumps'))
decoders_dir = os.path.abspath(os.path.join(base_dir, 'decoders'))

if len(sys.argv) == 1:
    usage()

opt_all = opt_run = opt_show = opt_list = opt_fix = opt_coverage = False
report_dir = None
try:
    opts, args = getopt(sys.argv[1:], "dvarslfcR:S:")
except GetoptError as e:
    # Unknown option or missing argument: show the usage text instead of
    # a traceback. usage() does not return.
    usage(str(e))
for opt, arg in opts:
    if opt == '-d':
        # May be given several times to raise the debug level.
        DEBUG += 1
    elif opt == '-v':
        VERBOSE = True
    elif opt == '-a':
        opt_all = True
    elif opt == '-r':
        opt_run = True
    elif opt == '-s':
        opt_show = True
    elif opt == '-l':
        opt_list = True
    elif opt == '-f':
        opt_fix = True
    elif opt == '-c':
        opt_coverage = True
    elif opt == '-R':
        report_dir = arg
    elif opt == '-S':
        # Overrides the default location of the input dumps.
        dumps_dir = arg

if opt_run and opt_show:
    usage("Use either -s or -r, not both.")
if args and opt_all:
    usage("Specify either -a or tests, not both.")
if report_dir is not None and not os.path.isdir(report_dir):
    usage("%s is not a directory" % report_dir)

ret = 0
try:
    if args:
        testlist = get_tests(args)
    elif opt_all:
        testlist = get_tests(os.listdir(decoders_dir))
    else:
        usage("Specify either -a or tests.")

    if opt_run:
        if not os.path.isdir(dumps_dir):
            ERR("Could not find sigrok-dumps repository at %s" % dumps_dir)
            sys.exit(1)
        results, errors = run_tests(testlist, fix=opt_fix)
        # Report failure as a plain 0/1 status: a raw count would wrap
        # around modulo 256 and could end up as "success". Errors and
        # mismatches recorded in the results count as failures as well.
        if errors or any('error' in r or 'diff' in r for r in results):
            ret = 1
    elif opt_show:
        show_tests(testlist)
    elif opt_list:
        list_tests(testlist)
    elif opt_fix:
        run_tests(testlist, fix=True)
    else:
        usage()
except Exception as e:
    ERR("Error: %s" % str(e))
    # A caught exception must not be reported as success.
    ret = 1
    if DEBUG:
        raise

sys.exit(ret)