# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""A message filter that can decide whether a message should be kept or not, and possibly alter it."""

from __future__ import absolute_import, division, print_function

import copy
import os.path
import re
import sys
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

import genpy
import rospkg
import rospy
import rospy.names
from cras.message_utils import raw_to_msg, msg_to_raw
from cras.plugin_utils import get_plugin_implementations
from cras.string_utils import to_str, STRING_TYPE
from std_msgs.msg import Header

from .bag_utils import BagWrapper, MultiBag, BAG_NAME_PATTERN
from .time_range import TimeRange, TimeRanges
from .topic_set import TopicSet


def is_sequence(o):
    """Tell whether `o` is a list or a tuple (strings and other iterables do not count).

    :param o: The object to test.
    :return: True if `o` is an instance of list or tuple.
    :rtype: bool
    """
    if isinstance(o, list):
        return True
    return isinstance(o, tuple)


# Cache of the discovered filters (filter name => filter class). None until get_filters() is first called.
loaded_filters = None


def get_filters():
    """Return all message filters provided by plugins, discovering them on the first call.

    Each filter class is registered under its plain name and also under "<module>.<name>". The result is cached in
    the module-level `loaded_filters` variable, so the plugin lookup runs only once.

    :return: The dictionary of (filter name => filter).
    :rtype: dict
    """
    global loaded_filters
    if loaded_filters is None:
        # Publish the (still empty) cache before iterating over plugins, exactly as the lookup always did.
        loaded_filters = {}
        for module_name, name, cls in get_plugin_implementations("cras_bag_tools", "filters", MessageFilter):
            qualified_name = "%s.%s" % (module_name, name)
            loaded_filters[name] = cls
            loaded_filters[qualified_name] = cls  # Fully qualified alias of the same filter
    return loaded_filters


# Type aliases describing the message tuples that are passed between filters.
# A raw message comes in a "long" form (all raw fields flattened into one 8-tuple) and a "short" form (the raw fields
# grouped into a nested 4-tuple, so that the outer tuple has 5 items like a deserialized message tuple).

# Connection header: a string => string dictionary.
ConnectionHeader = Dict[STRING_TYPE, STRING_TYPE]
# Message tags: a set of strings.
Tags = Set[STRING_TYPE]
RawMessage = Tuple[str, bytes, str, type]
"""datatype, data, md5sum, pytype"""
RawMessageData = Tuple[str, str, bytes, str, type, rospy.Time, ConnectionHeader, Tags]
"""topic, datatype, data, md5sum, pytype, stamp, connection_header, tags"""
RawMessageDataShort = Tuple[str, RawMessage, rospy.Time, ConnectionHeader, Tags]
"""topic, (datatype, data, md5sum, pytype), stamp, connection_header, tags"""
DeserializedMessageData = Tuple[str, genpy.Message, rospy.Time, ConnectionHeader, Tags]
"""topic, msg, stamp, connection_header, tags"""
# Any 5-tuple message: either a short raw message or a deserialized message.
AnyMessageData = Union[RawMessageDataShort, DeserializedMessageData]
# Results of filters: None, a single message tuple, or a list of message tuples.
RawFilterResult = Union[None, RawMessageData, List[Union[RawMessageData, DeserializedMessageData]]]
DeserializedFilterResult = Union[None, DeserializedMessageData, List[Union[DeserializedMessageData, RawMessageData]]]


def normalize_filter_result(result):
    # type: (RawFilterResult) -> List[Union[RawMessageData, DeserializedMessageData]]
    """Convert a filter result to a non-empty list.

    :param result: None, a single message tuple, or a list of message tuples.
    :return: `result` itself if it is a non-empty list; `[None]` for an empty list; `[result]` for anything that is
             not a list (including None).
    """
    items = result if isinstance(result, list) else [result]
    # An empty list means the same as "no message", which is represented by a single None item.
    return items if len(items) > 0 else [None]


def msg_long_to_short(msg_tuple):
    # type: (Union[RawMessageData, DeserializedMessageData]) -> AnyMessageData
    """Convert a long (8-item) raw message tuple to the short (5-item) form.

    :param msg_tuple: Either `(topic, datatype, data, md5sum, pytype, stamp, connection_header, tags)` or any other
                      message tuple.
    :return: `(topic, (datatype, data, md5sum, pytype), stamp, connection_header, tags)` if the input unpacks into
             exactly 8 items; otherwise the input is returned unchanged.
    """
    try:
        topic, datatype, data, md5sum, pytype, stamp, connection_header, tags = msg_tuple
    except ValueError:
        # Not an 8-tuple, so it is assumed to already be in the short/deserialized form.
        return msg_tuple
    raw_msg = (datatype, data, md5sum, pytype)
    return topic, raw_msg, stamp, connection_header, tags


def msg_short_to_long(msg_tuple):
    # type: (AnyMessageData) -> Union[RawMessageData, DeserializedMessageData]
    """Convert a short (5-item) raw message tuple to the long (8-item) form.

    :param msg_tuple: `(topic, msg, stamp, connection_header, tags)` where `msg` is either a deserialized message or
                      the raw 4-tuple `(datatype, data, md5sum, pytype)`.
    :return: The input itself if `msg` is a genpy.Message; otherwise the flattened tuple
             `(topic, datatype, data, md5sum, pytype, stamp, connection_header, tags)`.
    """
    topic, payload, stamp, connection_header, tags = msg_tuple
    if not isinstance(payload, genpy.Message):
        # Raw message: splice the four raw fields in between the topic and the stamp.
        datatype, data, md5sum, pytype = payload
        return topic, datatype, data, md5sum, pytype, stamp, connection_header, tags

    return msg_tuple


class MessageTags(str, Enum):
    """These tags can be added to messages and later queried/altered by filters.

    This is not a closed set. Filters can create and use any strings they like as tags. This enum is just a collection
    of the common tags used by the library itself. The members are also `str` instances, so they can be mixed with
    plain string tags.

    There is a group of tags that handle the requeueing mechanism. If a message is returned with REQUEUE tag,
    it should not continue its traversal through filters, but it should be instead put back on the processing queue.
    To prevent infinite loops, each REQUEUE tag should be accompanied by tag (REQUEUES_LEFT_PREFIX + number) and the
    queue processor is responsible for "decrementing" this numbered tag every time it requeues the message.
    """

    ORIGINAL = "original"
    """This is a message that was read from the input bag(s)."""

    GENERATED = "generated"
    """This is a message generated by a filter (i.e. the message is not in the input bag(s))."""

    EXTRA_TIME_RANGE = "extra_time_range"
    """This is a message that comes from the extra time range and should be handled specially."""

    CHANGED = "changed"
    """The message has been changed by at least one filter."""

    REQUEUE = "requeue"
    """The message should be requeued. Most filters will ignore it."""

    REQUEUES_LEFT_PREFIX = "requeues:"
    """Number of requeues left for the tagged message.

    This tag should never be used in this form. Instead, tags like 'requeues:10' should be created."""

    DROP_ON_REQUEUE_TIMEOUT = "requeue-drop"
    """If the message is out of requeue attempts, it should be dropped."""

    PASS_ON_REQUEUE_TIMEOUT = "requeue-pass"
    """If the message is out of requeue attempts, it should be passed further."""

    # Use the plain str implementations so that str(tag) and format(tag) give the tag value itself.
    __str__ = str.__str__
    __format__ = str.__format__


def tags_for_generated_msg(orig_tags, add_tags=None):
    # type: (Tags, Optional[Tags]) -> Tags
    """Create tags for a message that a filter derived (generated) from an input message.

    The result is a copy of the input tags without ORIGINAL and CHANGED, with GENERATED added, and united with
    `add_tags` (if given). The input set is not modified.

    :param orig_tags: Tags of the original message.
    :param add_tags: Additional tags to add.
    :return: Tags for the derived message.
    """

    derived_tags = copy.deepcopy(orig_tags)
    for stale_tag in (MessageTags.ORIGINAL, MessageTags.CHANGED):
        derived_tags.discard(stale_tag)
    derived_tags.add(MessageTags.GENERATED)
    if not add_tags:
        return derived_tags
    return derived_tags.union(add_tags)


def tags_for_changed_msg(orig_tags, add_tags=None):
    # type: (Tags, Optional[Tags]) -> Tags
    """Create tags for a message that was changed by a filter.

    The result is a copy of the input tags with CHANGED added and united with `add_tags` (if given). No tags are
    removed and the input set is not modified.

    :param orig_tags: Tags of the original message.
    :param add_tags: Additional tags to add.
    :return: Tags for the changed message.
    """

    changed_tags = copy.deepcopy(orig_tags)
    changed_tags.add(MessageTags.CHANGED)
    if not add_tags:
        return changed_tags
    return changed_tags.union(add_tags)


def tags_for_requeuing_msg(orig_tags, max_requeues=10, drop_on_timeout=False):
    # type: (Tags, int, bool) -> Tags
    """Create tags that mark a message for requeuing.

    :param orig_tags: Tags of the original message.
    :param max_requeues: Maximum number of requeues (stored in the REQUEUES_LEFT_PREFIX + number tag).
    :param drop_on_timeout: If True, DROP_ON_REQUEUE_TIMEOUT is added; otherwise PASS_ON_REQUEUE_TIMEOUT is added.
    :return: `orig_tags` itself if it already contains REQUEUE; otherwise a copy with the requeue tags added.
    """

    # A message that is already being requeued keeps its tags (and thus its remaining requeue count).
    if MessageTags.REQUEUE in orig_tags:
        return orig_tags

    if drop_on_timeout:
        timeout_tag = MessageTags.DROP_ON_REQUEUE_TIMEOUT
    else:
        timeout_tag = MessageTags.PASS_ON_REQUEUE_TIMEOUT
    counter_tag = MessageTags.REQUEUES_LEFT_PREFIX + str(max_requeues)

    requeue_tags = copy.deepcopy(orig_tags)
    requeue_tags.update((MessageTags.REQUEUE, counter_tag, timeout_tag))
    return requeue_tags


def decrement_requeue_tag(orig_tags):
    # type: (Tags) -> Tuple[bool, Tags]
    """Given message tags for an input message, decrements the requeue counter.

    Every tag starting with REQUEUES_LEFT_PREFIX is replaced by a tag with the number lowered by one. When a counter
    drops to zero (or below, e.g. when requeuing was started with a non-positive number of requeues), the counter tag
    is removed instead, no requeues are reported as left and the REQUEUE tag is removed, too. The input set is not
    modified.

    :param orig_tags: Tags of the original message.
    :return: Tuple of (some requeues left, original tags with the requeue tag decremented). If the tags do not
             contain REQUEUE, returns (True, orig_tags).
    :raises ValueError: If the text after the prefix of a counter tag is not an integer.
    """

    if MessageTags.REQUEUE not in orig_tags:
        return True, orig_tags

    new_tags = copy.deepcopy(orig_tags)
    requeues_left_tags = [t for t in orig_tags if t.startswith(MessageTags.REQUEUES_LEFT_PREFIX)]
    some_requeues_left = True
    for t in requeues_left_tags:
        new_tags.remove(t)
        _, requeues_left_str = t.split(':', 1)
        new_requeues_left = int(requeues_left_str) - 1
        # "<= 0" rather than "== 0": a counter that starts at zero or below would otherwise never terminate.
        if new_requeues_left <= 0:
            some_requeues_left = False
        else:
            new_tags.add(MessageTags.REQUEUES_LEFT_PREFIX + str(new_requeues_left))

    if not some_requeues_left:
        new_tags.remove(MessageTags.REQUEUE)

    return some_requeues_left, new_tags


def cleanup_requeue_tags(orig_tags):
    # type: (Tags) -> Tags
    """Remove all requeue-related tags from the tags of a previously requeued message.

    Removed are REQUEUE, DROP_ON_REQUEUE_TIMEOUT, PASS_ON_REQUEUE_TIMEOUT and every tag starting with
    REQUEUES_LEFT_PREFIX. The input set is not modified.

    :param orig_tags: Tags of the original message.
    :return: Copy of the original tags with the requeue tags removed.
    """

    cleaned_tags = copy.deepcopy(orig_tags)
    cleaned_tags.difference_update({
        MessageTags.REQUEUE,
        MessageTags.DROP_ON_REQUEUE_TIMEOUT,
        MessageTags.PASS_ON_REQUEUE_TIMEOUT,
    })
    # Collect the counter tags first so that the set is not modified while it is being iterated.
    counter_tags = [tag for tag in cleaned_tags if tag.startswith(MessageTags.REQUEUES_LEFT_PREFIX)]
    cleaned_tags.difference_update(counter_tags)

    return cleaned_tags


def num_used_requeues(tags, max_requeues):
    # type: (Tags, int) -> Optional[int]
    """Compute how many requeues a message with the given tags has already used.

    :param tags: The tags of the message.
    :param max_requeues: The number of max_requeues with which the requeuing was started.
    :return: `max_requeues` minus the number stored in the first found REQUEUES_LEFT_PREFIX tag. None if the tags
             contain no such tag.
    """
    counter_tags = [tag for tag in tags if tag.startswith(MessageTags.REQUEUES_LEFT_PREFIX)]
    if not counter_tags:
        return None
    # The remaining count is the text after the first colon of the tag, e.g. "requeues:3" -> 3.
    requeues_left = int(counter_tags[0].partition(':')[2])
    return max_requeues - requeues_left


class MessageFilter(object):
    """Base class for message filters. Do not implement this directly: instead implement either RawMessageFilter or
    DeserializedMessageFilter.

    The workflow of the filters is as follows:
    1. Get the topic filter and connection filter from the message filter and apply these to a bag reader (or other
       message stream provider, as this library is not bound to only processing bag messages).
    2. Read messages that satisfy the topic and connections filters.
    3. If the message does not satisfy consider_message(), it is passed further without a change. This is a kind of
       pre-filter that allows us to not deserialize messages just to tell to throw them away or pass them along.
    4. filter_message() is called. If it returns None, the message should be discarded. Otherwise, it either returns
       one message or a list of messages. The first (or only) message is considered to be the "direct followup" of the
       input message and continues going through the filter (or stops the filter if it is None). The other messages
       in the returned list should be fed into this filter again as new input messages (with the MessageTags.GENERATED
       tag).

    The filter can also override extra_time_ranges(). The "extra" time ranges are time ranges of the bagfile that
    should be read regardless of the normal start/end/min_stamp/max_stamp time ranges. Data from "extra" time ranges
    are tagged with MessageTags.EXTRA_TIME_RANGE and they are not considered by default, so the filter has to override
    consider_message() to actually receive them. This mechanism is meant to support e.g. reading latched static TFs from
    the start of the bag file even when working just on a part of the bag that does not start at the bag beginning.
    """

    def __init__(self, is_raw, include_topics=None, exclude_topics=None, include_types=None, exclude_types=None,
                 min_stamp=None, max_stamp=None, include_time_ranges=None, exclude_time_ranges=None,
                 include_tags=None, exclude_tags=None):
        """Constructor.

        :param bool is_raw: Whether the filter works on raw or deserialized messages.
        :param list include_topics: If nonempty, the filter will only work on these topics.
        :param list exclude_topics: If nonempty, the filter will skip these topics (but pass them further).
        :param list include_types: If nonempty, the filter will only work on these message types.
        :param list exclude_types: If nonempty, the filter will skip these message types (but pass them further).
        :param min_stamp: If set, the filter will only work on messages after this timestamp.
        :type min_stamp: rospy.Time or float
        :param max_stamp: If set, the filter will only work on messages before this timestamp.
        :type max_stamp: rospy.Time or float
        :param include_time_ranges: Time ranges that specify which regions of the bag should be processed.
                                    List of pairs (start, end_or_duration) or a TimeRanges object.
        :type include_time_ranges: list or TimeRanges
        :param exclude_time_ranges: Time ranges that specify which regions of the bag should be skipped.
                                    List of pairs (start, end_or_duration) or a TimeRanges object.
        :type exclude_time_ranges: list or TimeRanges
        :param list include_tags: If nonempty, the filter will only work on messages with these tags. Each element of
                                  this list is itself a set. For a message to be considered, at least one set has to be
                                  a subset of the tags list of the message.
        :param list exclude_tags: If nonempty, the filter will skip messages with these tags. Each element of
                                  this list is itself a set. For a message to be skipped, at least one set has to be
                                  a subset of the tags list of the message.
        """
        self.is_raw = is_raw
        """Whether the filter works on raw or deserialized messages."""
        self._include_topics = TopicSet(include_topics)
        """If nonempty, the filter will only work on these topics."""
        self._exclude_topics = TopicSet(exclude_topics)
        """If nonempty, the filter will skip these topics (but pass them further)."""
        self._include_types = TopicSet(include_types)
        """If nonempty, the filter will only work on these message types."""
        self._exclude_types = TopicSet(exclude_types)
        """If nonempty, the filter will skip these message types (but pass them further)."""
        self._min_stamp = self._parse_stamp(min_stamp)
        """If set, the filter will only work on messages after this timestamp."""
        self._max_stamp = self._parse_stamp(max_stamp)
        """If set, the filter will only work on messages before this timestamp."""
        self._bag = None
        """If this filter is working on a bag, it should be set here before the filter starts being used on the bag."""
        self._multibag = None
        """If this filter is working on a multibag, it should be set here before the filter starts being used."""
        self._params = None
        """If ROS parameters are recorded for the bag, they should be passed here."""
        self._include_time_ranges = self._parse_time_ranges(include_time_ranges)
        """Time ranges that specify which regions of the bag should be processed by this filter. If empty,
        the filter should work for all time regions except the excluded ones."""
        self._exclude_time_ranges = self._parse_time_ranges(exclude_time_ranges)
        """Time ranges that specify which regions of the bag should be skipped (but passed further)."""
        self._include_tags = \
            tuple([({t} if isinstance(t, STRING_TYPE) else set(t)) for t in include_tags]) if include_tags else tuple()
        """If nonempty, the filter will only work on messages with these tags. Each element of this tuple is itself a
        set. For a message to be considered, at least one set has to be a subset of the tags list of the message."""
        self._exclude_tags = \
            tuple([({t} if isinstance(t, STRING_TYPE) else set(t)) for t in exclude_tags]) if exclude_tags else tuple()
        """If nonempty, the filter will skip messages with these tags. Each element of this tuple is itself a
        set. For a message to be skipped, at least one set has to be a subset of the tags list of the message."""
        self._accept_requeued_messages = False
        """If True, requeued messages will be accepted by this filter."""

        self.__rospack = None

    @staticmethod
    def _parse_stamp(stamp):
        """Convert a constructor timestamp argument to a time object.

        :param stamp: None, a time object, or a number of seconds.
        :type stamp: rospy.Time or float or None
        :return: None for None; the passed object itself if it already is a time object (the constructor documents
                 that `rospy.Time` is accepted, so such objects are not sent through the numeric constructor);
                 `rospy.Time(stamp)` otherwise.
        """
        if stamp is None or isinstance(stamp, genpy.Time):
            return stamp
        return rospy.Time(stamp)

    @staticmethod
    def _parse_time_ranges(ranges):
        """Convert an iterable of `(start, end)` pairs to a TimeRanges object.

        :param ranges: The pairs to convert, or None.
        :return: None if `ranges` is None, otherwise TimeRanges built from one TimeRange per pair.
        """
        if ranges is None:
            return None
        return TimeRanges(TimeRange(start, end) for start, end in ranges)

    def set_bag(self, bag):
        """If this filter is working on a bag, it should be set here before the filter starts being used on the bag.

        The start time of the bag is also set as the base time of the include/exclude time ranges (if configured).

        :param rosbag.bag.Bag bag: The bag file open for reading.
        """
        self._bag = bag

        start_time = bag.get_start_time()
        if self._include_time_ranges is not None:
            self._include_time_ranges.set_base_time(start_time)
        if self._exclude_time_ranges is not None:
            self._exclude_time_ranges.set_base_time(start_time)

    def set_multibag(self, bag):
        # type: (Union[BagWrapper, MultiBag]) -> None
        """If this filter is working on a multibag, it should be set here before the filter starts being used.

        :param bag: The multibag (or bag wrapper) this filter works on.
        :type bag: BagWrapper or MultiBag
        """
        self._multibag = bag

    def set_params(self, params):
        """Set the ROS parameters recorded for the currently open bag file.

        :param params: The ROS parameters.
        :type params: dict
        """
        self._params = params

    def _get_param(self, param, default=None):
        """Get parameter `param` from the parameters set by :meth:`set_params`.

        :param str param: The parameter to get. Slashes separate levels of nested dictionaries; a leading slash is
                          ignored.
        :param default: The default value returned in case the parameter is not found.
        :return: The found parameter value or the default.
        """
        return self.__get_param(self._params, param, default)

    def __get_param(self, params, param, default=None):
        """Recursive implementation of :meth:`_get_param` working on the (sub)dictionary `params`."""
        if param.startswith("/"):
            param = param[1:]
        # A missing subtree, or a path that descends into a non-dict value (e.g. "a/b/c" when "a/b" is a number),
        # means the parameter does not exist.
        if not isinstance(params, dict):
            return default
        # Try the whole remaining name first so that keys which themselves contain slashes are found.
        if param in params:
            return params[param]
        if "/" not in param:
            return default
        key, rest = param.split("/", 1)
        return self.__get_param(params.get(key, None), rest, default)

    def _set_param(self, param, value):
        """Set parameter `param` to the parameters set by :meth:`set_params`.

        :param str param: Name of the parameter. Slashes separate levels of nested dictionaries, which are created
                          when they are missing; a leading slash is ignored.
        :param value: The value to set.
        """
        if self._params is None:
            self._params = {}
        self.__set_param(self._params, param, value)

    def __set_param(self, params, param, value):
        """Recursive implementation of :meth:`_set_param` working on the (sub)dictionary `params`."""
        if param.startswith("/"):
            param = param[1:]
        if "/" in param:
            key, rest = param.split("/", 1)
            if key not in params:
                params[key] = {}
            self.__set_param(params[key], rest, value)
        else:
            params[param] = value

    def __call__(self, *args, **kwargs):
        """Do the filtering.

        This method properly selects the raw/deserialized filter() method and calls it.
        :return: The filtered message or None if it should be discarded. A list of messages can also be returned. In
                 such case, the first message is considered to be the "direct" continuation of the input message and
                 it should be directly used. The remaining messages are considered as additional filter inputs and
                 should be fed back to the filter.
        """
        return self.filter(*args, **kwargs)

    def filter(self, *args, **kwargs):
        # type: (...) -> Union[RawFilterResult, DeserializedFilterResult]
        """Filter the message.

        :param args: The message can be either a RawMessageData or a DeserializedMessageData tuple.
        :param kwargs:
        :return: None if the message should be discarded. The possibly changed message otherwise. Multiple messages can
                 be returned, too. In that case, the additional messages should be passed through the filter as if
                 they are newly read messages.
        """
        raise NotImplementedError()

    def consider_message(self, topic, datatype, stamp, header, tags):
        # type: (STRING_TYPE, STRING_TYPE, rospy.Time, Dict[STRING_TYPE, STRING_TYPE], Set[STRING_TYPE]) -> bool
        """This function should be called before calling filter(). If it returns False, filter() should not be called
        and the original message should be used instead.

        :param topic:
        :param datatype:
        :param stamp:
        :param header:
        :param tags:
        :return: Whether filter() should be called. If False, the message is passed to the next filter (not discarded).
        """
        # Filters have to explicitly opt-in to reading messages from extra time ranges.
        if MessageTags.EXTRA_TIME_RANGE in tags:
            return False
        if self._min_stamp is not None and stamp < self._min_stamp:
            return False
        if self._max_stamp is not None and stamp > self._max_stamp:
            return False
        if self._include_topics and topic not in self._include_topics:
            return False
        if self._exclude_topics and topic in self._exclude_topics:
            return False
        if self._include_types and datatype not in self._include_types:
            return False
        if self._exclude_types and datatype in self._exclude_types:
            return False
        if self._include_time_ranges and stamp not in self._include_time_ranges:
            return False
        if self._exclude_time_ranges and stamp in self._exclude_time_ranges:
            return False
        if self._include_tags and not any(req_tags.issubset(tags) for req_tags in self._include_tags):
            return False
        if self._exclude_tags and any(req_tags.issubset(tags) for req_tags in self._exclude_tags):
            return False
        # Filters have to explicitly opt-in to reading requeued messages.
        if not self._accept_requeued_messages and MessageTags.REQUEUE in tags:
            return False

        return True

    def connection_filter(self, topic, datatype, md5sum, msg_def, header):
        """Connection filter passed to Bag.read_messages().

        :param topic:
        :param datatype:
        :param md5sum:
        :param msg_def:
        :param header:
        :return: If False, the topic will not be read from the input bag.
        :rtype: bool
        """
        return True

    def topic_filter(self, topic):
        """Filter of topics to be read from the bag file.

        :param topic:
        :return: If False, the topic will not be read from the input bag.
        :rtype: bool
        """
        return True

    def extra_initial_messages(self):
        # type: () -> Iterable[Union[AnyMessageData, RawMessageData]]
        """Get extra messages that should be passed to the filter before the iteration over bag messages starts.

        This can be used e.g. by filters that are more generators than actual filters (i.e. they do not operate on
        existing messages, but instead create new ones based on other information).

        :note: :py:meth:`set_bag` should be called before calling this method.
        :note: Do not generate very large data. All initial messages will be stored in RAM at once.
        :return: A list or iterator of the message 5-tuples `(topic, message, stamp, connection_header, tags)`.
        """
        return []

    def extra_final_messages(self):
        # type: () -> Iterable[Union[AnyMessageData, RawMessageData]]
        """Get extra messages that should be passed to the filter after the iteration over bag messages stops.

        This can be used e.g. by filters that are more generators than actual filters (i.e. they do not operate on
        existing messages, but instead create new ones based on other information).

        :note: :py:meth:`set_bag` should be called before calling this method.
        :note: Do not generate very large data. All final messages will be stored in RAM at once.
        :return: A list or iterator of the message 5-tuples `(topic, message, stamp, connection_header, tags)`.
        """
        return []

    def extra_time_ranges(self, bags):
        """If this filter requires that certain time ranges are read from the bagfiles in any case (like the static
        TFs at the beginning), it should return the requested time range here. Such parts of the bag will be always
        passed to the filter regardless of the include/exclude time ranges of this filter and the start/end times
        of the whole bag.

        :param bags: The bags that are going to be processed.
        :type bags: rosbag.Bag or MultiBag
        :return: The time ranges of the input bags that should always be read. If None, no extra ranges are required.
        :rtype: TimeRanges or None
        """
        return None

    def __get_rospack(self):
        """Get the lazily created (and then cached) `rospkg.RosPack` instance."""
        if self.__rospack is None:
            self.__rospack = rospkg.RosPack()
        return self.__rospack

    @staticmethod
    def _reference_file_vars(reference_file):
        """Compute the :meth:`str.format()` variables derived from the path of the reference file.

        The `bag_*` variables are all set to `name` here; :meth:`resolve_file` refines them when the name follows
        the bag naming pattern.

        :param str reference_file: Path of the reference file (`~` is expanded, relative paths are made absolute).
        :return: Dictionary with keys dirname, basename, name, ext, ext_no_dot, bag_prefix, bag_stamp, bag_suffix
                 and bag_base.
        :rtype: dict
        """
        dirname, basename = os.path.split(os.path.abspath(os.path.expanduser(reference_file)))
        name, ext = os.path.splitext(basename)
        return {
            'dirname': dirname,
            'basename': basename,
            'name': name,
            'ext': ext,
            # splitext() returns the extension including the leading dot (or an empty string).
            'ext_no_dot': ext[1:] if ext.startswith('.') else ext,
            'bag_prefix': name,
            'bag_stamp': name,
            'bag_suffix': name,
            'bag_base': name,
        }

    def resolve_file(self, filename):
        """Resolve `filename` relative to the bag set by :meth:`set_bag`. `filename` can also contain some variables in
        :meth:`str.format()` style.

        The following variables are available in `filename`:

        - dirname (absolute path of the reference file's directory)
        - basename (base name of the reference file, with extension)
        - name (base name of the reference file, without extension)
        - ext (extension of the reference file (including the dot); empty if it has no extension)
        - ext_no_dot (extension of the reference file (without the dot); empty if it has no extension)
        - bag_prefix (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `PREFIX`; otherwise, it is `name`).
        - bag_stamp (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `STAMP`; otherwise, it is `name`).
        - bag_suffix (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `SUFFIX`; otherwise, it is `name`).
        - bag_base (if `name` has the format `PREFIX_STAMP_SUFFIX`, this is `PREFIX_STAMP`; otherwise, it is `name`).

        `filename` can also start with the special syntax `$(find package_name)` which is replaced with the absolute
        path to the specified ROS package.

        If no bag is set (or the bag has an empty filename), no variables are substituted and `filename` is only
        made absolute.

        :note: This is ideally called from :meth:`on_filtering_start` or :meth:`filter` because earlier, the `_bag`
               member variable is not set.
        :param str filename: The file name to resolve.
        :return: The absolute resolved path.
        :rtype: str
        """
        match = re.match(r'\$\(find ([^)]+)\)', filename)
        if match is not None:
            package = match.group(1)
            package_path = self.__get_rospack().get_path(package)
            filename = filename.replace('$(find %s)' % (package,), package_path)

        if self._bag is None or len(self._bag.filename) == 0:
            return os.path.abspath(os.path.expanduser(filename))

        format_vars = self._reference_file_vars(self._bag.filename)

        match = BAG_NAME_PATTERN.match(format_vars['name'])
        if match is not None:
            format_vars['bag_prefix'] = match.group(1)
            format_vars['bag_stamp'] = match.group(2)
            format_vars['bag_suffix'] = match.group(3)
            format_vars['bag_base'] = match.group(1) + '_' + match.group(2)

        resolved = os.path.expanduser(filename.format(**format_vars))
        # Relative paths are interpreted relative to the directory of the bag file.
        if not os.path.isabs(resolved):
            resolved = os.path.join(format_vars['dirname'], resolved)
        return os.path.abspath(resolved)

    def on_filtering_start(self):
        """This function is called right before the first message is passed to filter().

        :note: Specifically, :meth:`set_params` and :meth:`set_bag` are already called at this stage.
        :note: :meth:`extra_initial_messages` will be called after calling this method.
        """
        pass

    def on_filtering_end(self):
        """This function is called right after the last message is processed by filter()."""
        pass

    def reset(self):
        """Reset the filter. This should be called e.g. before starting a new bag."""
        self._bag = None
        self._multibag = None
        self._params = None

    @staticmethod
    def add_cli_args(parser):
        """Subclasses may reimplement this static method to specify extra CLI args they provide.

        :param argparse.ArgumentParser parser: The argument parser to configure.
        """
        pass

    @staticmethod
    def process_cli_args(filters, args):
        """Subclasses may reimplement this static method to process the custom CLI args from add_cli_args().

        :param list filters: The list of loaded filters. This method can add to the list.
        :param argparse.Namespace args: The parsed args.
        """
        pass

    @staticmethod
    def yaml_config_args():
        """Subclasses may reimplement this static method to specify extra YAML keys they provide.

        These keys should correspond to the names of the CLI args and they will be read into CLI args when found in the
        YAML file.

        :return: The list of provided YAML keys.
        :rtype: list
        """
        return []

    @staticmethod
    def from_config(cfg):
        """Create a MessageFilter from a config dict.

        Other filters can be defined by 3rd-party packages via pluginlib. The package has to
        `<exec_depend>cras_bag_tools</exec_depend>` and it has to put this line in its `<export>` tag in package.xml:
        `<cras_bag_tools filters="$PACKAGE.$MODULE" />`. With this in place, `filter_bag` will search the specified
        module for all classes that subclass `cras_bag_tools.MessageFilter` and it will provide these as additional
        filters.

        :param cfg: The filter configuration. If a sequence is given, a FilterChain will be created.
        :type cfg: dict or list or tuple
        :return: The configured filter. None if `cfg` is None or if it names a filter that is not defined.
        :rtype: MessageFilter
        """
        if cfg is None:
            return None
        # Assume cfg is either a filter config or a sequence of such configs.
        # If it is a sequence, construct filter chain.
        if is_sequence(cfg):
            filters = []
            for d in cfg:
                f = MessageFilter.from_config(d)
                if f is not None:
                    filters.append(f)
            return FilterChain(filters)
        # Assume one of the following filter config structures:
        # {class: [args, kwargs]}.
        # {class: kwargs}.
        assert len(cfg) == 1
        k, v = list(cfg.items())[0]
        if is_sequence(v):
            args = v[0] if len(v) >= 1 else ()
            kwargs = v[1] if len(v) >= 2 else {}
        else:
            args = ()
            kwargs = v

        filters = get_filters()
        if k not in filters:
            print("Filter %s is not defined. Check that its Python module is properly exported in package.xml." % k,
                  file=sys.stderr)
            return None
        # Eval in the current environment.
        f = filters[k](*args, **kwargs)
        return f

    def __str__(self):
        return "%s(%s)" % (self.__class__.__name__, self._str_params())

    def _default_str_params(self, include_topics=True, exclude_topics=True, include_types=True, exclude_types=True,
                            min_stamp=True, max_stamp=True, include_time_ranges=True, exclude_time_ranges=True,
                            include_tags=True, exclude_tags=True):
        """Parameters to be printed when stringifying this instance. This is called by __str__().

        Each argument tells whether the constructor parameter of the same name should be printed (only parameters
        with truthy values are ever printed).

        :return: The parameters to print.
        :rtype: str
        """
        parts = []
        if self.is_raw:
            parts.append('raw')
        if self._include_topics and include_topics:
            parts.append('include_topics=%s' % str(self._include_topics))
        if self._exclude_topics and exclude_topics:
            parts.append('exclude_topics=%s' % str(self._exclude_topics))
        if self._include_types and include_types:
            parts.append('include_types=%s' % str(self._include_types))
        if self._exclude_types and exclude_types:
            parts.append('exclude_types=%s' % str(self._exclude_types))
        if self._min_stamp and min_stamp:
            parts.append('min_stamp=%s' % to_str(self._min_stamp))
        if self._max_stamp and max_stamp:
            parts.append('max_stamp=%s' % to_str(self._max_stamp))
        if self._include_time_ranges and include_time_ranges:
            parts.append('include_time_ranges=%s' % str(self._include_time_ranges))
        if self._exclude_time_ranges and exclude_time_ranges:
            parts.append('exclude_time_ranges=%s' % str(self._exclude_time_ranges))
        if self._include_tags and include_tags:
            parts.append('include_tags=%s' % str(self._include_tags))
        if self._exclude_tags and exclude_tags:
            parts.append('exclude_tags=%s' % str(self._exclude_tags))
        return ",".join(parts)

    def _str_params(self):
        """Parameters to be printed when stringifying this instance. This is called by __str__().

        :return: The parameters to print.
        :rtype: str
        """
        return self._default_str_params()


class RawMessageFilter(MessageFilter):
    """
    Base class of message filters that work with raw (not deserialized) messages.
    """

    def __init__(self, *args, **kwargs):
        # True goes to the first positional argument of the parent constructor; the rest is forwarded unchanged
        super(RawMessageFilter, self).__init__(True, *args, **kwargs)

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # type: (...) -> RawFilterResult
        """Process one raw message. This base implementation only raises; subclasses provide the logic.

        :param str topic: Topic of the message.
        :param str datatype: ROS datatype of the message (as string).
        :param bytes data: The raw data.
        :param str md5sum: MD5 sum of the datatype.
        :param type pytype: ROS datatype of the message (as Python type).
        :param rospy.Time stamp: Receive timestamp of the message.
        :param dict header: Connection header.
        :param set tags: Message tags.
        :return: None if the message should be discarded, or raw message(s).
        :raises NotImplementedError: Always in this base class.
        """
        raise NotImplementedError()


class DeserializedMessageFilter(MessageFilter):
    """
    Base class of message filters that work with deserialized (decoded) messages.
    """

    def __init__(self, *args, **kwargs):
        # False goes to the first positional argument of the parent constructor; the rest is forwarded unchanged
        super(DeserializedMessageFilter, self).__init__(False, *args, **kwargs)

    def filter(self, topic, msg, stamp, header, tags):
        # type: (...) -> DeserializedFilterResult
        """Process one decoded message. This base implementation only raises; subclasses provide the logic.

        :param str topic: Topic of the message.
        :param genpy.Message msg: The decoded message.
        :param rospy.Time stamp: Receive timestamp of the message.
        :param dict header: Connection header.
        :param set tags: Message tags.
        :return: None if the message should be discarded, or a deserialized message(s).
        :raises NotImplementedError: Always in this base class.
        """
        raise NotImplementedError()


class DeserializedMessageFilterWithTF(DeserializedMessageFilter):
    """Filter for deserialized messages that always correctly reads static TFs from the bag start."""

    def __init__(self, include_topics=None, include_types=None, tf_topics=("/tf",), tf_static_topics=("/tf_static",),
                 initial_bag_part_duration=genpy.Duration(2), *args, **kwargs):
        """Constructor.

        :param include_topics: Topics passed to the parent filter. If not None, all TF topics are added to them.
                               A list or tuple passed here is copied and never modified.
        :type include_topics: list or tuple or None
        :param include_types: Types passed to the parent filter. If not None, "tf2_msgs/TFMessage" is added to them.
                              A list or tuple passed here is copied and never modified.
        :type include_types: list or tuple or None
        :param tf_topics: Topics with (non-static) transforms.
        :type tf_topics: list or tuple
        :param tf_static_topics: Topics with static transforms. Messages on them are always considered.
        :type tf_static_topics: list or tuple
        :param genpy.Duration initial_bag_part_duration: End of the time range (starting at 0) requested by
                                                          extra_time_ranges().
        :param args: Further positional arguments of the parent constructor.
        :param kwargs: Further keyword arguments of the parent constructor.
        """
        self._tf_topics = TopicSet(tf_topics)
        self._tf_static_topics = TopicSet(tf_static_topics)
        self._initial_bag_part_duration = initial_bag_part_duration

        if include_topics is not None:
            # Work on a private copy: the caller's list must not be altered and tuples have no append()
            if isinstance(include_topics, (list, tuple)):
                include_topics = list(include_topics)
            # Convert both operands to lists; `tuple + list` (e.g. a default mixed with a user-given list)
            # would raise TypeError
            for t in list(tf_topics) + list(tf_static_topics):
                if t not in include_topics:
                    include_topics.append(t)
        if include_types is not None:
            if isinstance(include_types, (list, tuple)):
                include_types = list(include_types)
            if "tf2_msgs/TFMessage" not in include_types:
                include_types.append("tf2_msgs/TFMessage")
        super(DeserializedMessageFilterWithTF, self).__init__(
            include_topics=include_topics, include_types=include_types, *args, **kwargs)

    def consider_message(self, topic, datatype, stamp, header, tags):
        # /tf has standard rules for being considered, but tf_static needs to be always accepted
        if topic in self._tf_static_topics:
            return True
        return super(DeserializedMessageFilterWithTF, self).consider_message(topic, datatype, stamp, header, tags)

    def extra_time_ranges(self, bags):
        # Require the beginnings of all bag files where static TFs can be stored
        return TimeRanges([TimeRange(0, self._initial_bag_part_duration)])


class UniversalFilter(MessageFilter):
    """Base for message filters that can act both as raw and deserialized, based just on a configuration value.

    Subclasses implement filter_raw() and filter_deserialized(); filter() picks one of them according to is_raw.
    """

    def filter(self, *args, **kwargs):
        # type: (...) -> Union[RawFilterResult, DeserializedFilterResult]
        """Forward all arguments to filter_raw() if this filter is raw, otherwise to filter_deserialized()."""
        handler = self.filter_raw if self.is_raw else self.filter_deserialized
        return handler(*args, **kwargs)

    def filter_raw(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        # type: (...) -> RawFilterResult
        """Raw variant of the filter. Not implemented in this base class."""
        raise NotImplementedError()

    def filter_deserialized(self, topic, msg, stamp, header, tags):
        # type: (...) -> DeserializedFilterResult
        """Deserialized variant of the filter. Not implemented in this base class."""
        raise NotImplementedError()


class Passthrough(UniversalFilter):
    """
    Just pass all messages through.
    """

    def __init__(self, is_raw=True, *args, **kwargs):
        """Constructor.

        :param bool is_raw: Whether the filter acts as a raw filter (True) or as a deserialized one (False).
        :param args: Further positional arguments of the parent constructor.
        :param kwargs: Further keyword arguments of the parent constructor.
        """
        # is_raw is forwarded positionally (the same slot RawMessageFilter fills with True). Forwarding it as
        # a keyword together with *args raised "got multiple values for argument" whenever any extra
        # positional argument was given.
        super(Passthrough, self).__init__(is_raw, *args, **kwargs)

    def filter_raw(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        """Return the raw message exactly as it was received."""
        return topic, datatype, data, md5sum, pytype, stamp, header, tags

    def filter_deserialized(self, topic, msg, stamp, header, tags):
        """Return the deserialized message exactly as it was received."""
        return topic, msg, stamp, header, tags


class NoMessageFilter(RawMessageFilter):
    """
    A raw filter that considers no message at all. Good base for message generators or other helpers.
    """

    def consider_message(self, topic, datatype, stamp, header, tags):
        """Reject every message regardless of the arguments.

        :return: Always False.
        :rtype: bool
        """
        return False


class FilterChain(RawMessageFilter):
    """
    A chain of message filters.

    The chain itself is a raw filter. Its filters are applied one after another and the message is converted
    between the raw and the deserialized form only when two consecutive filters need a different form.
    """

    def __init__(self, filters):
        """Constructor.

        :param filters: The filters to add.
        :type filters: list of MessageFilter
        """
        super(FilterChain, self).__init__()
        # leave the decision about accepting requeues on individual filters
        self._accept_requeued_messages = True
        self.filters = filters

    def consider_message(self, topic, datatype, stamp, header, tags):
        return True  # actual considering is done in the filter() loop

    def filter(self, topic, datatype, data, md5sum, pytype, stamp, header, tags):
        """Pass the message through all filters of the chain that consider it.

        :param str topic: Topic of the message.
        :param str datatype: ROS datatype of the message (as string).
        :param bytes data: The raw data.
        :param str md5sum: MD5 sum of the datatype.
        :param type pytype: ROS datatype of the message (as Python type).
        :param rospy.Time stamp: Receive timestamp of the message.
        :param dict header: Connection header.
        :param set tags: Message tags.
        :return: None if a filter discarded the message and no filter produced additional messages. A single raw
                 message if it passed and there are no additional messages. Otherwise a list whose first item is
                 the passed message (or None if it was discarded) followed by all additional messages.
        """
        msg = None  # the deserialized form; valid only while last_was_raw is False
        last_was_raw = True
        additional_msgs = []  # messages produced by the filters on top of the processed one
        for f in self.filters:
            if not f.consider_message(topic, datatype, stamp, header, tags):
                continue
            if f.is_raw:
                # serialize only if the previous filter left us with a deserialized message
                if not last_was_raw:
                    datatype, data, md5sum, pytype = msg_to_raw(msg)
                ret = f(topic, datatype, data, md5sum, pytype, stamp, header, tags)
                ret = normalize_filter_result(ret)
                if ret[0] is None:
                    # the message was discarded, which ends the chain; keep what was generated so far
                    if len(ret) == 1 and not additional_msgs:
                        return None
                    return [None] + additional_msgs + ret[1:]
                topic, datatype, data, md5sum, pytype, stamp, header, tags = ret[0]
                additional_msgs.extend(ret[1:])
                last_was_raw = True
            else:
                # deserialize only if the previous filter left us with a raw message
                if last_was_raw:
                    msg = raw_to_msg(datatype, data, md5sum, pytype)
                ret = f(topic, msg, stamp, header, tags)
                ret = normalize_filter_result(ret)
                if ret[0] is None:
                    # the message was discarded, which ends the chain; keep what was generated so far
                    if len(ret) == 1 and not additional_msgs:
                        return None
                    return [None] + additional_msgs + ret[1:]
                topic, msg, stamp, header, tags = ret[0]
                datatype = msg.__class__._type  # needed in consider_message() above
                additional_msgs.extend(ret[1:])
                last_was_raw = False
        # the chain is a raw filter, so the result has to be raw again
        if not last_was_raw:
            datatype, data, md5sum, pytype = msg_to_raw(msg)
        ret = topic, datatype, data, md5sum, pytype, stamp, header, tags
        if not additional_msgs:
            return ret
        return [ret] + additional_msgs

    def connection_filter(self, topic, datatype, md5sum, msg_def, header):
        """Accept the connection only if every filter of the chain accepts it."""
        return all(f.connection_filter(topic, datatype, md5sum, msg_def, header) for f in self.filters)

    def topic_filter(self, topic):
        """Accept the topic only if every filter of the chain accepts it."""
        return all(f.topic_filter(topic) for f in self.filters)

    def set_bag(self, bag):
        for f in self.filters:
            f.set_bag(bag)
        super(FilterChain, self).set_bag(bag)

    def set_multibag(self, bag):
        for f in self.filters:
            f.set_multibag(bag)
        super(FilterChain, self).set_multibag(bag)

    def set_params(self, params):
        for f in self.filters:
            f.set_params(params)
        super(FilterChain, self).set_params(params)

    def extra_time_ranges(self, bags):
        """Collect the extra time ranges of all filters of the chain.

        :return: The collected ranges, or None if no filter returned any.
        """
        time_ranges = None
        for f in self.filters:
            extra_ranges = f.extra_time_ranges(bags)
            if extra_ranges is not None:
                if time_ranges is None:
                    time_ranges = TimeRanges([])
                time_ranges.append(extra_ranges.ranges)
        return time_ranges

    def extra_initial_messages(self):
        """Yield the extra initial messages of all filters, in the order of the filters."""
        for f in self.filters:
            for m in f.extra_initial_messages():
                yield m

    def extra_final_messages(self):
        """Yield the extra final messages of all filters, in the order of the filters."""
        for f in self.filters:
            for m in f.extra_final_messages():
                yield m

    def on_filtering_start(self):
        for f in self.filters:
            f.on_filtering_start()
        super(FilterChain, self).on_filtering_start()

    def on_filtering_end(self):
        for f in self.filters:
            f.on_filtering_end()
        super(FilterChain, self).on_filtering_end()

    def reset(self):
        for f in self.filters:
            f.reset()
        super(FilterChain, self).reset()

    def __str__(self):
        return '%s(%s)' % (self.__class__.__name__, ', '.join(str(f) for f in self.filters))

    def __iadd__(self, other):
        """Append a filter (or all filters of another chain) to this chain in place. None is ignored."""
        if other is None:
            return self
        assert isinstance(other, MessageFilter)
        if isinstance(other, FilterChain):
            self.filters += other.filters
        else:
            self.filters.append(other)
        return self

    def __add__(self, other):
        """Return a chain made of the filters of this chain followed by `other`.

        :param other: A filter, another chain (its filters are appended one by one), or None.
        :return: A new chain; this very chain if `other` is None.
        :rtype: FilterChain
        """
        if other is None:
            return self
        assert isinstance(other, MessageFilter)
        if isinstance(other, FilterChain):
            return FilterChain(self.filters + other.filters)
        # a plain filter is appended itself; it has no `filters` attribute to read from
        return FilterChain(self.filters + [other])


def fix_connection_header(header, topic, datatype, md5sum, pytype):
    """Make sure the connection header describes the given topic and message type.

    :param dict header: The connection header to check. It is never modified.
    :param str topic: Topic of the message.
    :param str datatype: ROS datatype of the message (as string).
    :param str md5sum: MD5 sum of the datatype.
    :param type pytype: ROS datatype of the message (as Python type). Its `_full_text` is used as the message
                        definition when the header has to be fixed.
    :return: The very same header instance if its "type", "topic" and "md5sum" already match. Otherwise a copy with
             "topic", "message_definition", "md5sum" and "type" overwritten.
    :rtype: dict
    """
    # .get() treats a missing key as a mismatch, so incomplete headers get completed instead of raising KeyError
    is_consistent = (
        header.get("type") == datatype and header.get("topic") == topic and header.get("md5sum") == md5sum)
    if is_consistent:
        return header
    header = dict(header)  # make a copy so that we don't alter the original message header instance
    header["topic"] = topic
    header["message_definition"] = pytype._full_text
    header["md5sum"] = md5sum
    header["type"] = datatype
    return header


def normalize_topic(topic):
    """Prepend a slash to the topic and pass the result through rospy.names.canonicalize_name().

    :param str topic: The topic name.
    :return: The value returned by rospy.names.canonicalize_name().
    """
    prefixed = '/' + topic
    return rospy.names.canonicalize_name(prefixed)


def filter_message(topic,  # type: STRING_TYPE
                   msg,  # type: Union[RawMessage, genpy.Message]
                   stamp,  # type: rospy.Time
                   connection_header,  # type: ConnectionHeader
                   tags,  # type: Set[STRING_TYPE]
                   filter,  # type: MessageFilter
                   raw_output=True  # type: bool
                   ):
    # type: (...) -> Optional[Union[AnyMessageData, List[AnyMessageData]]]
    """Run a single message through the given filter.

    The message is converted to the form the filter works with (raw or deserialized) and the result is converted
    to the form requested by `raw_output`. A message the filter does not consider is returned unfiltered, but
    still with a normalized topic and a fixed connection header.

    :param topic: The message topic.
    :param msg: The message (either a deserialized message or a raw message as 4-tuple).
    :param stamp: Receive timestamp of the message.
    :param connection_header: Connection header.
    :param tags: Message tags. You should pass at least {MessageTags.ORIGINAL} if the message is directly from bag.
    :param filter: The filter to apply.
    :param raw_output: Whether to output a raw message or a deserialized one.
    :return: None if the message should be discarded, or a message, or a list of messages.
    """
    extras = []  # messages the filter produced on top of the processed one
    if filter.is_raw:
        # A raw filter needs the 4-tuple; whatever cannot be unpacked as one is serialized first
        try:
            datatype, data, md5sum, pytype = msg
        except (ValueError, TypeError):
            datatype, data, md5sum, pytype = msg_to_raw(msg)
        if filter.consider_message(topic, datatype, stamp, connection_header, tags):
            result = normalize_filter_result(
                filter(topic, datatype, data, md5sum, pytype, stamp, connection_header, tags))
            if result[0] is None:
                # The processed message was dropped; only the extra messages (if any) remain.
                # NOTE(review): here the extras are returned without msg_long_to_short(), unlike below - confirm.
                if len(result) == 1:
                    return None
                return result
            topic, datatype, data, md5sum, pytype, stamp, connection_header, tags = result[0]
            extras.extend(map(msg_long_to_short, result[1:]))
        if raw_output:
            out_msg = (datatype, data, md5sum, pytype)
        else:
            out_msg = raw_to_msg(datatype, data, md5sum, pytype)
    else:
        # A deserialized filter needs a decoded message
        if not isinstance(msg, genpy.Message):
            datatype, data, md5sum, pytype = msg
            msg = raw_to_msg(datatype, data, md5sum, pytype)
        if filter.consider_message(topic, msg.__class__._type, stamp, connection_header, tags):
            result = normalize_filter_result(filter(topic, msg, stamp, connection_header, tags))
            if result[0] is None:
                # The processed message was dropped; only the extra messages (if any) remain.
                # NOTE(review): here the extras are returned without msg_long_to_short(), unlike below - confirm.
                if len(result) == 1:
                    return None
                return result
            topic, msg, stamp, connection_header, tags = result[0]
            extras.extend(map(msg_long_to_short, result[1:]))
        # The filter may have changed the message type, so read the type info from the resulting message
        msg_class = msg.__class__
        datatype = msg_class._type
        md5sum = msg_class._md5sum
        pytype = msg_class
        out_msg = msg_to_raw(msg) if raw_output else msg

    topic = normalize_topic(topic)
    # make sure connection header corresponds to the actual data type of the message
    # (if the filter forgot to update it)
    connection_header = fix_connection_header(connection_header, topic, datatype, md5sum, pytype)

    primary = topic, out_msg, stamp, connection_header, tags
    if extras:
        return [primary] + extras
    return primary


def deserialize_header(data, pytype):
    # type: (bytes, Type) -> Optional[Header]
    """Decode just the header of a raw message, leaving the rest of the message untouched.

    A message type is treated as having a header if the first entry of its `__slots__` is named `header`.

    :param data: The raw message whose header should be deserialized.
    :param pytype: The message type (as Python type).
    :return: A deserialized Header instance. If the message type has no header, None is returned.
    """
    slots = pytype.__slots__
    if len(slots) == 0 or slots[0] != 'header':
        return None
    header = Header()
    return header.deserialize(data)


# Public API of this module. The names are taken from the objects themselves so that a rename cannot leave
# a stale string behind.
__all__ = [
    DeserializedMessageFilter.__name__,
    DeserializedMessageFilterWithTF.__name__,
    FilterChain.__name__,
    MessageFilter.__name__,
    MessageTags.__name__,
    NoMessageFilter.__name__,  # public base for message generators; was missing from the export list
    Passthrough.__name__,
    RawMessageFilter.__name__,
    UniversalFilter.__name__,
    cleanup_requeue_tags.__name__,
    decrement_requeue_tag.__name__,
    deserialize_header.__name__,
    filter_message.__name__,
    get_filters.__name__,
    msg_long_to_short.__name__,
    msg_short_to_long.__name__,
    normalize_filter_result.__name__,
    num_used_requeues.__name__,
    tags_for_changed_msg.__name__,
    tags_for_generated_msg.__name__,
    tags_for_requeuing_msg.__name__,
]
