#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""
Filter messages on a topic based on a Python expression.

CLI Args: filter [IN_TOPIC [OUT_TOPIC]] FILTER_EXPRESSION
- IN_TOPIC: If specified, it overrides the topic this node subscribes to.
- OUT_TOPIC: If specified, it overrides the topic this node publishes to.
- FILTER_EXPRESSION: String containing the filtering expression. The expression is evaluated using Python `eval()`.
                     Local variable `m` holds the received message. The expression should evaluate to a bool. If it is
                     True, the message is published to `~output` topic, otherwise it is discarded. If you need to
                     call functions of some modules, use `__import__('module').function()` in the expression, as normal
                     `import module` does not work there. Be sure to correctly enclose the expression in quotes when
                     passing from a command line. Do NOT use `return` in the expression.

Subscriptions:
- ~input (any message type): The input data.

Publications:
- ~output (same type as `~input`): The filtered data.

Parameters:
- ~pass_only_header (bool, default False): If True, only the header of the message is deserialized and passed to the
                                           filter expression as the `m` argument. This can save some computations when
                                           you expect to filter out a lot of messages.
                                           If the incoming message type does not have a header, all messages are
                                           discarded and an error is logged.
- ~pass_raw_buffer (bool, default False): If True, the raw serialized buffer is passed to the filter expression as the
                                          `m` argument. This can save some computations when you do not need to look
                                          at the message content at all.
                                          Only one of `pass_raw_buffer` and `pass_only_header` can be true, otherwise
                                          the node exits with an error right at startup.
- ~in_queue_size (unsigned int, default 10): Queue size of the input subscriber.
- ~out_queue_size (unsigned int, default 10): Queue size of the output publisher.
"""

from __future__ import print_function

import sys
import time

import genpy
import rospy
from roslib.message import get_message_class
from std_msgs.msg import Header

import cras

pub = None


def cb(anymsg):
    """
    Callback called when an input message is received. It evaluates the filter function and publishes to the output
    topic if the filter function allows.

    The output publisher is created lazily on the first received message, because the message type and the latching
    flag are read from the connection header of the incoming message.

    :param rospy.AnyMsg anymsg: The incoming message.
    """
    global pub
    global type_check_ok

    # type_check_ok is tri-state: None = not checked yet, True = messages can be processed, False = discard everything.
    if type_check_ok is False:
        return

    # Create the publisher if it doesn't exist yet
    if pub is None:
        # noinspection PyProtectedMember
        connection_header = anymsg._connection_header
        msg_type = connection_header['type']
        cls = get_message_class(msg_type)
        if cls is None:
            # The message class could not be loaded. Without it, no publisher can be created, so report the problem
            # once and discard everything instead of failing in the callback on every message.
            type_check_ok = False
            rospy.logerr("Could not load message class for type '%s'. Discarding all messages." % (msg_type,))
            return
        latch = connection_header.get("latching", "0") == "1"
        pub = rospy.Publisher(out_topic, cls, queue_size=out_queue_size, latch=latch)
        time.sleep(0.1 if latch else 0.01)  # Give the publisher time to register with subscribers
    else:
        cls = pub.data_class

    # One-time check that the message type is compatible with the requested mode of operation.
    if type_check_ok is None:
        if pass_only_header and not has_header(cls):
            type_check_ok = False
            rospy.logerr("Requested to pass only header to filter function but message type '%s' does not have a "
                         "header. Discarding all messages." % (cls,))
            return
        type_check_ok = True

    # noinspection PyProtectedMember
    buff = anymsg._buff

    # Deserialize only as much as the filter expression needs.
    if pass_raw_buffer:
        publish = filter_fn(buff)
    elif pass_only_header:
        header = Header().deserialize(buff)
        publish = filter_fn(header)
    else:
        # noinspection PyCallingNonCallable
        msg = cls().deserialize(buff)
        publish = filter_fn(msg)

    # The original AnyMsg is forwarded, so the message is not serialized again.
    if publish:
        pub.publish(anymsg)


def has_header(msg):
    """
    Tell whether the given message or message type has a `header` attribute.

    :param msg: A message instance, a message class, or anything else, which is passed to `get_message_class()` to be
                resolved to a message class.
    :return: None if it cannot be told (`msg` is a `rospy.AnyMsg` instance or the message class could not be
             resolved), otherwise whether the message or class has a `header` attribute.
    :rtype: bool or None
    """
    if isinstance(msg, rospy.AnyMsg):
        return None
    if isinstance(msg, (type, genpy.Message)):
        return hasattr(msg, "header")

    cls = get_message_class(msg)
    if cls is None:
        # Unknown message type. Recursing with None would only fail inside the type lookup.
        return None
    return has_header(cls)


if __name__ == "__main__":
    rospy.init_node('filter', anonymous=True)

    # Positional arguments: [IN_TOPIC [OUT_TOPIC]] FILTER_EXPRESSION (argv[0] is the program name).
    argv = rospy.myargv()
    if not 2 <= len(argv) <= 4:
        print("Usage: filter [IN_TOPIC [OUT_TOPIC]] 'FILTER_EXPRESSION'\n\n"
              "Topics:\n"
              "  $IN_TOPIC (default '~input'): input data\n"
              "  $OUT_TOPIC (default '~output'): output data\n\n"
              "FILTER_EXPRESSION should evaluate to a boolean (and not return it).\n"
              "Use __import__('module').function() if you need to import modules", file=sys.stderr)
        sys.exit(1)

    in_topic = "~input"
    out_topic = "~output"

    # The filter expression is always the last argument; the optional topic names precede it.
    if len(argv) == 2:
        filter_fn_str = argv[1]
    elif len(argv) == 3:
        in_topic = argv[1]
        filter_fn_str = argv[2]
    else:
        in_topic = argv[1]
        out_topic = argv[2]
        filter_fn_str = argv[3]

    # Tri-state flag used by cb(): None = type not checked yet, True = ok, False = discard all messages.
    type_check_ok = None
    pass_raw_buffer = cras.get_param("~pass_raw_buffer", False)
    pass_only_header = cras.get_param("~pass_only_header", False)
    in_queue_size = cras.get_param("~in_queue_size", 10, "messages")
    out_queue_size = cras.get_param("~out_queue_size", 10, "messages")

    if pass_only_header and pass_raw_buffer:
        rospy.logfatal("Cannot set both pass_raw_buffer and pass_only_header.")
        sys.exit(1)

    # Compile the expression once instead of letting eval() parse the string again for every message. This also
    # reports an invalid expression right at startup instead of on every received message.
    try:
        # eval() ignores whitespace around a string expression; strip it so the same expressions remain valid here.
        filter_fn_code = compile(filter_fn_str.strip(), "<filter_expression>", "eval")
    except (SyntaxError, ValueError, TypeError) as e:
        rospy.logfatal("Filter expression '%s' is not a valid Python expression: %s" % (filter_fn_str, str(e)))
        sys.exit(1)

    def filter_fn(m):
        """
        The filtering function.

        NOTE: The expression is executed via `eval()` and can run arbitrary code. It must only come from a trusted
        source (it is taken from the command line of this node).

        :param m: The data passed to the expression as variable `m`: the deserialized message, only its header
                  (if `pass_only_header` is set) or the raw serialized buffer (if `pass_raw_buffer` is set).
        :return: Whether the message should be sent further. False if the evaluation failed.
        :rtype: bool
        """
        try:
            # No explicit namespaces are passed, so the expression sees `m` and the globals of this module.
            return bool(eval(filter_fn_code))
        except Exception as e:
            rospy.logerr_throttle(1.0, "Evaluation of the filter function failed due to the following error: '%s'. "
                                       "Skipping message." % (str(e),))
            return False

    # Subscribe as AnyMsg so that any message type is accepted and deserialization can be done selectively in cb().
    sub = rospy.Subscriber(in_topic, rospy.AnyMsg, cb, queue_size=in_queue_size)

    try:
        rospy.spin()
    except rospy.ROSInterruptException:
        pass
