#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Script that filters a bag file using a sequence of message filters."""

from __future__ import absolute_import, division, print_function

import argparse
import copy

import os
import re
import sys

import yaml
from argparse import ArgumentParser
from glob import glob
from shutil import copyfile

import genpy
import rosbag
from cras import pretty_file_size
from cras.string_utils import STRING_TYPE

from cras_bag_tools.bag_filter import filter_bag
from cras_bag_tools.message_filter import BAG_NAME_PATTERN, FilterChain, MessageFilter, Passthrough, get_filters
from cras_bag_tools.time_range import TimeRange, TimeRanges
from cras_bag_tools.tqdm_bag import TqdmMultiBag


def resolve_path(path_format, reference_file):
    """Resolve a path template relative to a reference file.

    The following variables are available in path_format:

    - dirname (absolute path of the reference file's directory)
    - basename (base name of the reference file, with extension)
    - name (base name of the reference file, without extension)
    - ext (extension of the reference file (including the dot); empty if it has no extension)
    - ext_no_dot (extension of the reference file (without the dot); empty if it has no extension)
    - bag_prefix (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `PREFIX`; otherwise, it is `name`).
    - bag_stamp (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `STAMP`; otherwise, it is `name`).
    - bag_suffix (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `SUFFIX`; otherwise, it is `name`).
    - bag_base (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `PREFIX_STAMP`; otherwise, it is `name`).

    :param str path_format: The path template in the format :meth:`str.format()` accepts.
    :param str reference_file: Path of the current bag file or another reference file.
    :return: The absolute resolved path (with `~` expanded).
    :rtype: str
    :raises KeyError: If `path_format` references a variable that is not listed above.
    """
    dirname, basename = os.path.split(os.path.abspath(os.path.expanduser(reference_file)))
    name, ext = os.path.splitext(basename)

    format_vars = {
        'dirname': dirname,
        'basename': basename,
        'name': name,
        'ext': ext,
        # os.path.splitext() returns the extension including its leading dot (or an empty string).
        'ext_no_dot': ext[1:] if ext.startswith('.') else ext,
        'bag_prefix': name,
        'bag_stamp': name,
        'bag_suffix': name,
        'bag_base': name,
    }

    match = BAG_NAME_PATTERN.match(name)
    if match is not None:
        format_vars['bag_prefix'] = match.group(1)
        format_vars['bag_stamp'] = match.group(2)
        format_vars['bag_suffix'] = match.group(3)
        format_vars['bag_base'] = match.group(1) + '_' + match.group(2)

    resolved = path_format.format(**format_vars)
    return os.path.abspath(os.path.expanduser(resolved))


def copy_params_if_any(param_file, out_bag_path):
    """Copy the given ROS parameters file next to the output bag as `<out_bag_path>.params`.

    Nothing is done if `param_file` is None. A missing parameters file is reported on stdout. Any other failure to
    copy (e.g. permissions, `param_file` being a directory) is reported on stderr. In neither case is an exception
    raised.

    :param str param_file: Path to the file with parameters (can be None).
    :param str out_bag_path: Output bag path.
    """
    if param_file is None:
        return
    out_params_path = out_bag_path + '.params'
    if not os.path.exists(param_file):
        print('Params: %s not found.' % (param_file,))
        return
    try:
        copyfile(param_file, out_params_path)
    except (OSError, IOError) as ex:
        # The file exists, so this is a genuine copy error; report its actual cause instead of "not found".
        print('Params: failed copying %s to %s: %s' % (param_file, out_params_path, str(ex)), file=sys.stderr)
        return
    print('Params: %s copied to %s.' % (param_file, out_params_path))


def _split_bag_spec(bags_path, extra_bags):
    """Expand one (multi)bag specification into absolute paths of all files to open.

    :param str bags_path: A bag path, or several paths separated by :data:`os.path.pathsep`. All items except the
                          first one are templates resolved by :meth:`resolve_path()` relative to the first item.
    :param extra_bags: Templates of extra bags (list, pathsep-separated string, or None).
    :return: Tuple `(bag_path, bags_paths, required_bags, extra_bags_paths)`. `bag_path` is the absolute path of the
             first item. `bags_paths` lists all files to open. `required_bags` is the set of files from `bags_path`.
             `extra_bags_paths` is the set of resolved extra bags.
    :rtype: tuple
    """
    parts = bags_path.split(os.path.pathsep)
    bag_path = os.path.abspath(os.path.expanduser(parts[0]))
    # The first item is expanded too, so that it is checked and opened under its real (absolute) path.
    bags_paths = [bag_path] + [resolve_path(part, bag_path) for part in parts[1:]]
    required_bags = set(bags_paths)

    extra_bags_paths = set()
    if extra_bags is not None:
        if isinstance(extra_bags, STRING_TYPE):
            extra_templates = extra_bags.split(os.path.pathsep)
        else:
            extra_templates = list(extra_bags)
        for template in extra_templates:
            extra_bag_path = resolve_path(template, bag_path)
            bags_paths.append(extra_bag_path)
            extra_bags_paths.add(extra_bag_path)

    return bag_path, bags_paths, required_bags, extra_bags_paths


def _select_params_file(bags_paths, default_params_file, bag_path):
    """Choose the ROS parameters file for one (multi)bag.

    Precedence is as follows. First, a `.params`/`.yaml`/`.yml` file listed among `bags_paths`; it is always removed
    from `bags_paths` so that it is not opened as a bag. Second, `default_params_file`, resolved relative to
    `bag_path`. Third, the first existing sidecar `<bag>.params` of any of the files in `bags_paths`.

    :param list bags_paths: All files of the multibag (modified in place).
    :param str default_params_file: Template of the default params file (can be None).
    :param str bag_path: The main bag of the multibag (reference for template resolution).
    :return: Path to the params file, or None if there is none.
    :rtype: str
    """
    listed = next((b for b in bags_paths if os.path.splitext(b)[-1] in ('.params', '.yaml', '.yml')), None)
    if listed is not None:
        bags_paths.remove(listed)
        return listed
    if default_params_file is not None:
        return resolve_path(default_params_file, bag_path)
    for b in bags_paths:
        sidecar = b + '.params'
        if os.path.exists(sidecar):
            return sidecar
    return None


def _drop_missing_bags(bags_paths, required_bags):
    """Remove non-existent files from `bags_paths` (in place), reporting missing required ones.

    :param list bags_paths: All files of the multibag.
    :param set required_bags: Files that must exist; missing optional (extra) bags are silently dropped.
    :return: False if any required bag is missing.
    :rtype: bool
    """
    missing = [b for b in bags_paths if not os.path.exists(b)]
    all_required_present = True
    for b in missing:
        if b in required_bags:
            print('Source bag %s does not exist' % (b,), file=sys.stderr)
            all_required_present = False
    for b in missing:
        bags_paths.remove(b)
    return all_required_present


def _load_params(params_file):
    """Load ROS parameters from a YAML file.

    :param str params_file: Path to the YAML file (can be None).
    :return: The loaded parameters. Returns None if the file is not given or does not exist. Also returns None if
             loading fails; errors are printed to stderr.
    """
    if params_file is None or not os.path.exists(params_file):
        return None
    try:
        with open(params_file, "r") as f:
            params = yaml.safe_load(f)
    except yaml.YAMLError as e:
        print("Error parsing ROS parameters from file %s: %s" % (params_file, str(e)), file=sys.stderr)
        return None
    except Exception as e:
        print("Error loading ROS parameters from file %s: %s" % (params_file, str(e)), file=sys.stderr)
        return None
    print("Loaded ROS parameters from file " + params_file)
    return params


def _save_params(params, params_path):
    """Write ROS parameters to a YAML file, printing (not raising) any error."""
    print('Saving ROS parameters to', params_path)
    try:
        with open(params_path, 'w+') as f:
            yaml.safe_dump(params, f)
    except Exception as ex:
        print('Error saving ROS params to %s: %s' % (params_path, str(ex)), file=sys.stderr)


def filter_bags(bags, out_format, compression, copy_params, filter, default_params_file=None,
                start_time=None, end_time=None, time_ranges=None, limit_to_first_bag=False, extra_bags=None):
    """Filter all given bags using the given filter.

    :param list bags: The bags to filter. To read multiple bags at once, add them as a single list item with paths
                      separated by :data:`os.path.pathsep` (a colon on POSIX). One of the files can also be a YAML
                      file with ROS parameters (extension `.params`, `.yaml` or `.yml`).
                      All sub-bags except the first one can have their name formatted similar to `out_format`.
    :param str out_format: Output path template in the format :meth:`str.format()` accepts. Available variables are:
                           `dirname`, `basename`, `name`, `ext`, `ext_no_dot`, `bag_prefix`, `bag_stamp`, `bag_suffix`,
                           `bag_base`. See :meth:`resolve_path()` for details.
    :param str compression: Output bag compression. One of 'rosbag.Compression.*' constants.
    :param bool copy_params: If True, copy the parameters file along with the bag file if it exists. Parameters
                             modified by the filters are saved next to the output bag.
    :param MessageFilter filter: The filter to apply.
    :param str default_params_file: If nonempty, specifies the YAML file with ROS parameters that is used if no param
                                    file is specified for the particular bag. The path will be resolved similar to
                                    out_format with the processed bag as reference file.
    :param genpy.Time start_time: Time from which the bag filtering should be started.
    :param genpy.Time end_time: Time to which the bag filtering should be stopped.
    :param TimeRanges time_ranges: Time ranges of the bag files to process. If start_time and end_time are specified,
                                   they are merged with these ranges. Relative time ranges will be evaluated relative
                                   to each individual bag.
    :param bool limit_to_first_bag: If True, each multibag will report its start and end to be equal to the
                                    first open bag. If False, the start and end correspond to the earliest and latest
                                    stamp in all bags.
    :param extra_bags: List of bag files that will be added to each of the processed bags. Names of these bags
                       will be resolved similar to out_format with the processed bag as reference file. Extra bags
                       that do not exist are skipped.
    :type extra_bags: list or str
    :return: The number of bags that failed to be processed.
    :rtype: int
    """
    num_failed = 0
    for i, bags_path in enumerate(bags, start=1):
        bag_path, bags_paths, required_bags, extra_bags_paths = _split_bag_spec(bags_path, extra_bags)
        params_file = _select_params_file(bags_paths, default_params_file, bag_path)

        print()
        print("[{}/{}] Bag {}".format(i, len(bags), bag_path))

        if not _drop_missing_bags(bags_paths, required_bags):
            num_failed += 1
            continue

        print('Source:      %s' % ("\n             ".join([os.path.abspath(b) for b in bags_paths]),))

        try:
            with TqdmMultiBag(bags_paths, skip_index=True, limit_to_first_bag=limit_to_first_bag) as bag:
                print('- Size: %s' % (pretty_file_size(bag.size),))

                # Tag every extra bag with "extra_bag" and with "extra_bag_<index>" (index in opening order).
                extra_index = 0
                for b in bag.bags:
                    if b.filename in extra_bags_paths:
                        bag.add_bag_tag(b, "extra_bag")
                        bag.add_bag_tag(b, "extra_bag_" + str(extra_index))
                        extra_index += 1

                out_bag_path = resolve_path(out_format, bag_path)
                out_bag_real_path = os.path.realpath(out_bag_path)
                clashing_bag = next((b for b in bags_paths if os.path.realpath(b) == out_bag_real_path), None)
                if clashing_bag is not None:
                    print("Out bag path %s is the same file as input bag %s!" % (out_bag_path, clashing_bag),
                          file=sys.stderr)
                    num_failed += 1
                    continue

                try:
                    os.makedirs(os.path.dirname(out_bag_path))
                except OSError:
                    pass  # Presumably the directory already exists; a real problem surfaces when writing the bag.
                print('Destination: %s' % (out_bag_path,))

                print('- Compression: %s' % (compression,))
                if copy_params:
                    copy_params_if_any(params_file, out_bag_path)
                print()

                params = _load_params(params_file)
                # Keep a pristine copy to detect whether filtering modified the parameters.
                orig_params = copy.deepcopy(params) if params is not None else None
                with rosbag.Bag(out_bag_path, 'w', compression=compression) as out:
                    bag.read_index()
                    filter_bag(bag, out, filter, params, start_time, end_time, time_ranges, print_stats=True)

                if copy_params and orig_params is not None and params != orig_params:
                    _save_params(params, out_bag_path + '.params')

        except Exception as e:
            print('Error processing bag file %s: %s' % (bag_path, str(e)), file=sys.stderr)
            num_failed += 1
            import traceback
            traceback.print_exc()

    return num_failed


class TimeRangesAction(argparse.Action):
    """Argparse action that converts a flat list `START END [START END ...]` into :class:`TimeRanges`.

    Invalid input (an odd number of values or non-numeric values) is reported as :class:`argparse.ArgumentError`.
    This lets argparse print a proper usage error instead of a traceback.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if len(values) % 2 != 0:
            raise argparse.ArgumentError(self, "Argument requires an even number of values.")
        try:
            stamps = [float(value) for value in values]
        except ValueError as e:
            raise argparse.ArgumentError(self, "Time range values have to be numbers: %s" % (str(e),))
        time_ranges = [TimeRange(stamps[i], stamps[i + 1]) for i in range(0, len(stamps), 2)]
        setattr(namespace, self.dest, TimeRanges(time_ranges))


def main():
    """Command-line entry point: filter bag files with a chain of message filters.

    Parses the command line and the YAML configs given via `--config`, builds the filter chain and filters all given
    bags. Values from the command line take precedence over YAML configs, and earlier configs take precedence over
    later ones. The process exits with the number of bags that failed. The count is capped at 255 because POSIX
    truncates exit statuses to 8 bits, so e.g. 256 failures would otherwise be reported as success.
    """
    parser = ArgumentParser()
    parser.add_argument('bags', nargs='*', help="The (multi)bag files to filter. A multibag is a colon-separated list "
                                                "of bag files that will be processed together (basically merged). "
                                                "The multibag can also specify at most one YAML file with parameters. "
                                                "Except the first item, the names of all sub-bags in a multibag are "
                                                "treated as templates for str.format().")
    parser.add_argument('--extra-bags', dest='extra_bags', nargs='+', default=None,
                        help="Bag files added to every (multi)bag from `bags`. If an extra bag doesn't exist, "
                             "it will be ignored. The names of all extra bags are treated as templates for "
                             "str.format().")
    parser.add_argument('-c', '--config', nargs='+', help="YAML configs of filters")
    parser.add_argument('-o', '--out-format', default=argparse.SUPPRESS,
                        help='File name of the output bag. It is treated as a template for str.format(). '
                             'Defaults to "{name}.proc{ext}".')
    parser.add_argument('--lz4', dest='compression', action='store_const', const=rosbag.Compression.LZ4,
                        help="Compress the bag via LZ4")
    parser.add_argument('--bz2', dest='compression', action='store_const', const=rosbag.Compression.BZ2,
                        help="Compress the bag via BZ2 (space-efficient, but very slow)")
    parser.add_argument('-f', '--filters', nargs='+', help=argparse.SUPPRESS)
    parser.add_argument('--no-copy-params', dest='copy_params', action='store_false', default=None,
                        help="If set, no .params file will be copied.")
    parser.add_argument('--default-params-file', dest='default_params_file', type=str, default=None,
                        help="If nonempty, specifies the YAML file with ROS parameters that is used if no param file "
                             "is specified for the particular bag. The file name is treated as a template for "
                             "str.format().")
    parser.add_argument("--list-yaml-keys", dest="list_yaml_keys", action="store_true",
                        help="Print a list of all available YAML top-level keys provided by filters.")
    parser.add_argument("--list-filters", dest="list_filters", action="store_true",
                        help="Print a list of all available filters.")
    parser.add_argument("--start-time", dest="start_time", type=float, default=None,
                        help="Time from which the bag filtering should be started.")
    parser.add_argument("--end-time", dest="end_time", type=float, default=None,
                        help="Time to which the bag filtering should be stopped.")
    parser.add_argument("--time-ranges", dest="time_ranges", nargs='+', action=TimeRangesAction, default=None,
                        help="Time ranges of bags that should be processed.", metavar='START_TIME END_TIME')
    parser.add_argument("-l", "--limit-to-first-bag", dest="limit_to_first_bag", action="store_true",
                        help="Read duration only from the first bag of each multibag.")

    loaded_filters = get_filters()
    unique_filters = set(loaded_filters.values())
    for f in unique_filters:
        if hasattr(f, 'add_cli_args'):
            getattr(f, 'add_cli_args')(parser)

    default_yaml_keys = [
        'bags', 'extra_bags', 'out_format', 'compression', 'filters', 'copy_params', 'start_time', 'end_time',
        'time_ranges',
    ]

    def default_process_cli_args(filters, args):
        """Convert time ranges that are not yet :class:`TimeRanges` (i.e. raw values from a YAML config)."""
        if hasattr(args, 'time_ranges') and args.time_ranges is not None:
            if not isinstance(args.time_ranges, TimeRanges):
                args.time_ranges = MessageFilter._parse_time_ranges(args.time_ranges)

    # The listing options are checked before parse_args() so that they work without any other arguments.
    if "--list-filters" in sys.argv:
        for f in unique_filters:
            print("{}: {}".format(".".join([f.__module__, f.__name__]), f.__doc__))
            if f.__init__.__doc__ is not None:
                print(f.__init__.__doc__)
        sys.exit(0)

    if "--list-yaml-keys" in sys.argv:
        print("Global:")
        for arg in default_yaml_keys:
            print("  {}".format(arg))
        for f in unique_filters:
            if hasattr(f, 'yaml_config_args'):
                args = getattr(f, 'yaml_config_args')()
                if len(args) > 0:
                    print("{}: {}".format(".".join([f.__module__, f.__name__]), f.__doc__))
                    for arg in args:
                        print("  {}".format(arg))
        sys.exit(0)

    args = parser.parse_args()

    print()
    print('Command-line arguments:')
    for k, v in sorted(vars(args).items(), key=lambda kv: kv[0]):
        if v is not None:
            print('%s: %s' % (k, v))

    # Expand glob patterns in the config paths; patterns that match no file are silently skipped.
    args.config = [path for pattern in (args.config or []) for path in glob(pattern)]

    print()
    print('YAML arguments:')
    yaml_keys = list(default_yaml_keys)
    for f in unique_filters:
        if hasattr(f, 'yaml_config_args'):
            yaml_keys.extend(getattr(f, 'yaml_config_args')())
    for config in args.config:
        with open(config, 'r') as f:
            cfg = yaml.safe_load(f)
        if cfg is None:
            continue  # Empty config file.
        if not isinstance(cfg, dict):
            print("Config file %s does not contain a YAML dictionary, ignoring it." % (config,), file=sys.stderr)
            continue
        for key in yaml_keys:
            if key not in cfg:
                continue
            # A config value is only used when the key is unset, None or an empty list.
            current = getattr(args, key, None)
            if current is None or (isinstance(current, list) and len(current) == 0):
                setattr(args, key, cfg[key])
                if key != "filters":
                    print("{}: {}".format(key, cfg[key]))

    # Process command-line args
    filters = []
    default_process_cli_args(filters, args)
    for f in unique_filters:
        if hasattr(f, 'process_cli_args'):
            getattr(f, 'process_cli_args')(filters, args)

    filter_chain = FilterChain(filters) + MessageFilter.from_config(args.filters)
    if len(filter_chain.filters) == 0:
        filter_chain = FilterChain([Passthrough()])
        print("No filters defined, using Passthrough.")

    print()
    print('Filters:')
    for f in filter_chain.filters:
        print(' - ' + str(f))

    if args.compression is None:
        args.compression = rosbag.Compression.NONE
    if "out_format" not in args or args.out_format is None:
        args.out_format = "{name}.proc{ext}"
    if args.copy_params is None:
        args.copy_params = True

    num_failed = filter_bags(
        args.bags, args.out_format, args.compression, args.copy_params, filter_chain, args.default_params_file,
        genpy.Time(args.start_time) if args.start_time is not None else None,
        genpy.Time(args.end_time) if args.end_time is not None else None, args.time_ranges,
        args.limit_to_first_bag, args.extra_bags)

    # Exit statuses are truncated to 8 bits, so cap the count to keep any failure distinguishable from success.
    sys.exit(min(num_failed, 255))


if __name__ == '__main__':
    main()
