#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

"""
Remap `mux/selected` from a bagfile to `selected` and `mux` to the name of a mux, and this node will
re-trigger the mux selection commands as they happened when the recorded system was running, so the
mux will be switching the same way it did during runtime.

This node is protected in case you create a loop by remapping the output of mux/selected of the
live mux to the `selected` input topic of this node. It should recognize such situations and react
only to messages that do not originate from the mux. This logic can, however, get broken in many
ways, so always test it prior to usage in a particular scenario. Third-party nodes will observe
all messages on the `mux/selected` topic doubled in case a loop exists.

Subscriptions:
- selected (std_msgs/String): Name of a selected output from a mux.

Service clients:
- mux/select (topic_tools/MuxSelect): Service called to switch the mux.
"""

import rospy
from std_msgs.msg import String
from topic_tools.srv import MuxSelect, MuxSelectRequest


def cb(msg):
    """
    Process the incoming "selected" message and replay the selection on the live mux.

    Messages whose publisher is the node currently providing the mux select service are ignored, so
    that a loop created by remapping the live mux's `mux/selected` output to this node's `selected`
    input does not make the node re-trigger the mux's own announcements.

    A failed service call is logged and the message is dropped; it does not raise out of the callback.

    :param String msg: The message containing name of the selected topic.
    """

    # the connection header of the received message names the node that published it
    # noinspection PyProtectedMember
    callerid = msg._connection_header['callerid']

    # update the service server in case it changed
    # NOTE(review): an empty service_server (provider could not be determined) is never refreshed
    # here, so loop detection stays disabled in that case - confirm this is intended.
    global service_server
    if service_server != "" and service_server != callerid:
        service_server = get_service_server()

    # do not react to messages published by the mux to the mux/selected topic in case the user
    # creates a loop
    if callerid == service_server:
        rospy.logerr_once("Detected loop between mux %s and mux replay node %s! Please, use a separate "
                          "'selected' topic for the replay node." % (rospy.resolve_name("mux"), rospy.get_name()))
        return

    req = MuxSelectRequest()
    req.topic = msg.data
    rospy.logdebug("Selecting " + req.topic + " on mux " + rospy.resolve_name("mux"))
    try:
        select(req)
    except rospy.ServiceException as e:
        # the mux may be temporarily unavailable or may reject the topic; report it with context
        # instead of letting it surface as a generic "bad callback" traceback
        rospy.logerr("Failed to select " + req.topic + " on mux " + rospy.resolve_name("mux") +
                     ": " + str(e))


def get_service_server():
    """
    Look up which node currently provides the mux select service (module-level `service`).

    The ROS master is asked for its system state and the first provider registered for the service
    is reported.

    :return: Name of the providing node, or empty string if the master reply does not have the
             expected layout or lists no provider for the service.
    :rtype: str
    """
    reply = rospy.get_master().getSystemState()

    # presumably reply is (code, status message, system state) and the system state is
    # (publishers, subscribers, services) as in the ROS Master API - bail out on any other layout
    if len(reply) < 3 or len(reply[2]) < 3:
        return ""

    # each services entry pairs a service name with the list of nodes providing it
    providers = dict(reply[2][2]).get(service, [])
    if len(providers) >= 1:
        return providers[0]
    return ""


if __name__ == "__main__":
    rospy.init_node('mux_replay', anonymous=True)

    # name of the select service of the mux this node controls; also read by get_service_server()
    service = rospy.resolve_name("mux") + "/select"

    # poll in 1 s slices so that a periodic warning is printed while the mux is not running yet
    service_found = False
    while not rospy.is_shutdown():
        try:
            rospy.wait_for_service(service, rospy.Duration(1))
            service_found = True
            break
        except rospy.ROSInterruptException:
            # node is shutting down; this has to be caught before ROSException (its base class) so
            # that the shutdown is not reported as a missing service
            break
        except rospy.ROSException:
            rospy.logwarn("Waiting for service " + service + " to appear")

    # only set the node up when the service really appeared (not when the wait ended by shutdown)
    if service_found:
        rospy.loginfo("Service " + service + " found")

        # node providing the select service; cb() uses it to detect selected->select loops
        service_server = get_service_server()

        select = rospy.ServiceProxy(service, MuxSelect)
        sub = rospy.Subscriber('selected', String, cb, queue_size=10)

        try:
            rospy.spin()
        except rospy.ROSInterruptException:
            pass
