#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""Script that filters a bag file using a sequence of message filters."""

from __future__ import absolute_import, division, print_function

import argparse
import copy

import os
import sys

import yaml
from argparse import ArgumentParser
from glob import glob
from shutil import copyfile

import genpy
import rosbag
from cras import pretty_file_size

from cras_bag_tools.bag_filter import filter_bag
from cras_bag_tools.message_filter import FilterChain, MessageFilter, Passthrough, get_filters
from cras_bag_tools.time_range import TimeRange, TimeRanges
from cras_bag_tools.tqdm_bag import TqdmMultiBag


def out_path(path, fmt):
    """Expand the output path template for one input bag and make sure its directory exists.

    The template may use the placeholders ``{dirname}`` (directory of ``path``), ``{name}`` (file name of ``path``
    without extension) and ``{ext}`` (extension of ``path`` including the dot).

    :param str path: Path of the current bag file.
    :param str fmt: The output name template.
    :return: The expanded output path. Its parent directory is created if possible; failures to create it
             (e.g. because it already exists) are ignored.
    :rtype: str
    """
    parent, filename = os.path.split(path)
    stem, suffix = os.path.splitext(filename)
    resolved = fmt.format(dirname=parent, name=stem, ext=suffix)

    target_dir = os.path.dirname(resolved)
    if target_dir:
        try:
            os.makedirs(target_dir)
        except OSError:
            # Deliberately best-effort, e.g. the directory may already exist.
            pass
    return resolved


def copy_params_if_any(param_file, out_bag_path):
    """Copy a ROS parameters file next to the output bag as ``<out_bag_path>.params``.

    :param str param_file: Path to the file with parameters (can be None, in which case nothing is done).
    :param str out_bag_path: Output bag path.
    """
    if param_file is None:
        return
    out_params_path = out_bag_path + '.params'
    try:
        copyfile(param_file, out_params_path)
    except (OSError, IOError) as ex:
        # Distinguish a missing source file (a normal situation) from a real copy failure,
        # e.g. a missing/unwritable destination directory.
        if not os.path.exists(param_file):
            print('Params: %s not found.' % (param_file,))
        else:
            print('Params: error copying %s to %s: %s' % (param_file, out_params_path, str(ex)), file=sys.stderr)
        return
    print('Params: %s copied to %s.' % (param_file, out_params_path))


def _split_multibag_path(bags_path, default_params_file):
    """Split one (multi)bag specification into its bag files and the ROS parameters file to use.

    :param str bags_path: A bag path, or several paths separated by ``os.path.pathsep``. One of them may be a ROS
                          parameters file, recognized by its extension (``.params``, ``.yaml`` or ``.yml``).
    :param str default_params_file: Parameters file used when ``bags_path`` does not name one explicitly.
    :return: Tuple ``(bag_path, bags_paths, params_file)``: the first bag file (None if there is no bag file), the
             list of all bag files, and the resolved parameters file (None if none was found). Precedence of the
             parameters file: file listed in ``bags_path`` > ``default_params_file`` > sidecar ``<bag>.params``.
    :rtype: tuple
    """
    explicit_params = None
    bags_paths = []
    for path in bags_path.split(os.path.pathsep):
        # Only the first parameters-like file is taken as parameters; any further ones are kept as bags.
        if explicit_params is None and os.path.splitext(path)[-1] in ('.params', '.yaml', '.yml'):
            explicit_params = path
        else:
            bags_paths.append(path)

    if explicit_params is not None:
        params_file = explicit_params
    elif default_params_file:
        params_file = default_params_file
    else:
        params_file = None
        for b in bags_paths:
            sidecar = b + '.params'
            if os.path.exists(sidecar):
                params_file = sidecar
                break

    bag_path = bags_paths[0] if bags_paths else None
    return bag_path, bags_paths, params_file


def filter_bags(bags, out_format, compression, copy_params, filter, default_params_file=None,
                start_time=None, end_time=None, time_ranges=None, limit_to_first_bag=False):
    """Filter all given bags using the given filter.

    :param list bags: The bags to filter. If multiple bags should be read at once, add them as a single list item with
                      paths separated by colon. One of the files can also be be a YAML file with ROS parameters.
    :param str out_format: Output path template.
    :param str compression: Output bag compression. One of 'rosbag.Compression.*' constants.
    :param bool copy_params: If True, copy parameters file along with the bag file if it exists.
    :param MessageFilter filter: The filter to apply.
    :param str default_params_file: If nonempty, specifies the YAML file with ROS parameters that is used if no param
                                    file is specified for the particular bag (a parameters file listed in the
                                    multibag spec takes precedence; the sidecar ``<bag>.params`` is only used when
                                    neither is given).
    :param genpy.Time start_time: Time from which the bag filtering should be started.
    :param genpy.Time end_time: Time to which the bag filtering should be stopped.
    :param TimeRanges time_ranges: Time ranges of the bag files to process. If start_time and end_time are specified,
                                   they are merged with these ranges. Relative time ranges will be evaluated relative
                                   to each individual bag.
    :param bool limit_to_first_bag: If True, each multibag will report its start and end to be equal to the
                                    first open bag. If False, the start and end correspond to the earliest and latest
                                    stamp in all bags.
    """
    for i, spec in enumerate(bags, start=1):
        bag_path, bags_paths, params_file = _split_multibag_path(spec, default_params_file)
        # The bag reader gets the bag files in the same pathsep-separated form, without the parameters file.
        bags_path = os.path.pathsep.join(bags_paths)

        print()
        print("[{}/{}] Bag {}".format(i, len(bags), bag_path if bag_path is not None else spec))

        if not bags_paths:
            print('No bag files given in %s' % (spec,), file=sys.stderr)
            continue

        print('Source:      %s' % (",".join([os.path.abspath(b) for b in bags_paths]),))
        missing = [b for b in bags_paths if not os.path.exists(b)]
        for b in missing:
            print('Source bag %s does not exist' % (b,), file=sys.stderr)
        if missing:
            continue

        try:
            with TqdmMultiBag(bags_path, skip_index=True, limit_to_first_bag=limit_to_first_bag) as bag:
                print('- Size: %s' % (pretty_file_size(bag.size),))

                # The output is named after the first real bag file (never after the parameters file).
                out_bag_path = out_path(bag_path, out_format)
                print('Destination: %s' % (os.path.abspath(out_bag_path),))
                print('- Compression: %s' % (compression,))
                if copy_params:
                    copy_params_if_any(params_file, out_bag_path)
                print()

                params = None
                if params_file is not None and os.path.exists(params_file):
                    try:
                        with open(params_file, "r") as f:
                            params = yaml.safe_load(f)
                        print("Loaded ROS parameters from file " + params_file)
                    except yaml.YAMLError as e:
                        print("Error parsing ROS parameters from file %s: %s" % (params_file, str(e)), file=sys.stderr)
                    except Exception as e:
                        print("Error loading ROS parameters from file %s: %s" % (params_file, str(e)), file=sys.stderr)

                # Snapshot the parameters so that modifications done by the filters can be detected and saved.
                orig_params = copy.deepcopy(params) if params is not None else None
                with rosbag.Bag(out_bag_path, 'w', compression=compression) as out:
                    bag.read_index()
                    filter_bag(bag, out, filter, params, start_time, end_time, time_ranges)

                if copy_params and params != orig_params:
                    print('Saving ROS parameters to', out_bag_path + '.params')
                    try:
                        with open(out_bag_path + '.params', 'w+') as f:
                            yaml.safe_dump(params, f)
                    except Exception as ex:
                        print('Error saving ROS params to %s: %s' % (out_bag_path + '.params', str(ex)),
                              file=sys.stderr)

        except Exception as e:
            # Report the failure and continue with the next (multi)bag.
            print('Error processing bag file %s: %s' % (bags_path, str(e)), file=sys.stderr)
            import traceback
            traceback.print_exc()


class TimeRangesAction(argparse.Action):
    """Argparse action that turns ``START_TIME END_TIME`` value pairs into a :class:`TimeRanges` object."""

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert the option values to TimeRanges and store them in the namespace.

        :raises argparse.ArgumentError: If an odd number of values is given or a value is not a number. Raising this
                                        type makes argparse print a usage error instead of a traceback.
        """
        if len(values) % 2 != 0:
            raise argparse.ArgumentError(self, "Argument requires an even number of values.")
        stamps = []
        for value in values:
            try:
                stamps.append(float(value))
            except ValueError:
                raise argparse.ArgumentError(self, "Invalid time value: %r" % (value,))
        time_ranges = [TimeRange(start, end) for start, end in zip(stamps[::2], stamps[1::2])]
        setattr(namespace, self.dest, TimeRanges(time_ranges))


def main():
    """Entry point of the bag filtering script.

    Parses the command-line arguments (extended by the ``add_cli_args`` hooks of all known filters), fills unset
    arguments from the YAML configs given via ``--config``, builds the filter chain and runs :func:`filter_bags`.
    """
    parser = ArgumentParser()
    parser.add_argument('bags', nargs='*', help="The (multi)bag files to filter.")
    parser.add_argument('-c', '--config', nargs='+', help="YAML configs of filters")
    parser.add_argument('-o', '--out-format', default=argparse.SUPPRESS,
                        help='Template for naming the output bag. Defaults to "{name}.proc{ext}"')
    parser.add_argument('--lz4', dest='compression', action='store_const', const=rosbag.Compression.LZ4,
                        help="Compress the bag via LZ4")
    parser.add_argument('--bz2', dest='compression', action='store_const', const=rosbag.Compression.BZ2,
                        help="Compress the bag via BZ2 (space-efficient, but very slow)")
    parser.add_argument('-f', '--filters', nargs='+', help=argparse.SUPPRESS)
    parser.add_argument('--no-copy-params', dest='copy_params', action='store_false', default=None,
                        help="If set, no .params file will be copied.")
    parser.add_argument('--default-params-file', dest='default_params_file', type=str, default=None,
                        help="If nonempty, specifies the YAML file with ROS parameters that is used if no param file "
                             "is specified for the particular bag.")
    parser.add_argument("--list-yaml-keys", dest="list_yaml_keys", action="store_true",
                        help="Print a list of all available YAML top-level keys provided by filters.")
    parser.add_argument("--list-filters", dest="list_filters", action="store_true",
                        help="Print a list of all available filters.")
    parser.add_argument("--start-time", dest="start_time", type=float, default=None,
                        help="Time from which the bag filtering should be started.")
    parser.add_argument("--end-time", dest="end_time", type=float, default=None,
                        help="Time to which the bag filtering should be stopped.")
    parser.add_argument("--time-ranges", dest="time_ranges", nargs='+', action=TimeRangesAction, default=None,
                        help="Time ranges of bags that should be processed.", metavar='START_TIME END_TIME')
    parser.add_argument("-l", "--limit-to-first-bag", dest="limit_to_first_bag", action="store_true",
                        help="Read duration only from the first bag of each multibag.")

    loaded_filters = get_filters()
    # De-duplicate the filter classes, then iterate them in a fixed order. Plain set iteration order is hash-based,
    # and this order decides CLI option registration, the listings below and the order of the filters created by
    # the process_cli_args hooks.
    unique_filters = sorted(set(loaded_filters.values()),
                            key=lambda c: (str(getattr(c, '__module__', '')), str(getattr(c, '__name__', ''))))
    for f in unique_filters:
        if hasattr(f, 'add_cli_args'):
            getattr(f, 'add_cli_args')(parser)

    default_yaml_keys = [
        'bags', 'out_format', 'compression', 'filters', 'copy_params', 'start_time', 'end_time', 'time_ranges',
    ]

    def default_process_cli_args(filters, args):
        """Normalize the global arguments: make sure time_ranges (possibly coming from YAML) is a TimeRanges."""
        if hasattr(args, 'time_ranges') and args.time_ranges is not None:
            if not isinstance(args.time_ranges, TimeRanges):
                args.time_ranges = MessageFilter._parse_time_ranges(args.time_ranges)

    # The listing options are checked before parsing so that they work without any other (required) arguments.
    if "--list-filters" in sys.argv:
        for f in unique_filters:
            print("{}: {}".format(".".join([f.__module__, f.__name__]), f.__doc__))
            if f.__init__.__doc__ is not None:
                print(f.__init__.__doc__)
        sys.exit(0)

    if "--list-yaml-keys" in sys.argv:
        print("Global:")
        for arg in default_yaml_keys:
            print("  {}".format(arg))
        for f in unique_filters:
            if hasattr(f, 'yaml_config_args'):
                args = getattr(f, 'yaml_config_args')()
                if len(args) > 0:
                    print("{}: {}".format(".".join([f.__module__, f.__name__]), f.__doc__))
                    for arg in args:
                        print("  {}".format(arg))
        sys.exit(0)

    args = parser.parse_args()

    print()
    print('Command-line arguments:')
    for k, v in sorted(vars(args).items(), key=lambda kv: kv[0]):
        if v is not None:
            print('%s: %s' % (k, v))

    # Expand glob patterns in the config paths. Matches are sorted because configs processed earlier take precedence
    # (a config only fills values that are still unset).
    config_paths = []
    for pattern in args.config or []:
        matches = sorted(glob(pattern))
        if not matches:
            print('Config %s does not match any file.' % (pattern,), file=sys.stderr)
        config_paths.extend(matches)
    args.config = config_paths

    print()
    print('YAML arguments:')
    yaml_keys = list(default_yaml_keys)
    for f in unique_filters:
        if hasattr(f, 'yaml_config_args'):
            yaml_keys.extend(getattr(f, 'yaml_config_args')())
    for config in args.config:
        with open(config, 'r') as f:
            cfg = yaml.safe_load(f)
        if cfg is None:
            # Empty config file: nothing to merge.
            continue
        if not isinstance(cfg, dict):
            print('Config %s does not contain a YAML dictionary, ignoring it.' % (config,), file=sys.stderr)
            continue
        for key in yaml_keys:
            if key not in cfg:
                continue
            # Command-line values win; the config only fills arguments that are unset or empty lists.
            current = getattr(args, key, None)
            if current is None or (isinstance(current, list) and len(current) == 0):
                setattr(args, key, cfg[key])
                if key != "filters":
                    print("{}: {}".format(key, cfg[key]))

    # Process command-line args
    filters = []
    default_process_cli_args(filters, args)
    for f in unique_filters:
        if hasattr(f, 'process_cli_args'):
            getattr(f, 'process_cli_args')(filters, args)

    filter_chain = FilterChain(filters) + MessageFilter.from_config(args.filters)
    if len(filter_chain.filters) == 0:
        filter_chain = FilterChain([Passthrough()])
        print("No filters defined, using Passthrough.")

    print()
    print('Filters:')
    for f in filter_chain.filters:
        print(' - ' + str(f))

    # Apply the defaults only now, so that YAML configs could fill these values first.
    if args.compression is None:
        args.compression = rosbag.Compression.NONE
    if "out_format" not in args or args.out_format is None:
        args.out_format = "{name}.proc{ext}"
    if args.copy_params is None:
        args.copy_params = True

    filter_bags(args.bags, args.out_format, args.compression, args.copy_params, filter_chain, args.default_params_file,
                genpy.Time(args.start_time) if args.start_time is not None else None,
                genpy.Time(args.end_time) if args.end_time is not None else None, args.time_ranges,
                args.limit_to_first_bag)


if __name__ == "__main__":
    # Run the bag filtering only when executed as a script, not when this module is imported.
    main()
