#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

# Test helper for testing mux_replay when mux/selected from a running mux is plugged into the replay node controlling
# the same mux.

from __future__ import print_function

import rospy
import rostest
import unittest
import time
import sys

from std_msgs.msg import String
from topic_tools.srv import MuxSelect, MuxSelectRequest

# Identifier of this test: used as the ROS node name in rospy.init_node() and as the test name handed to rostest.run().
NAME = "test_mux_replay_loop"


class MuxReplayLoop(unittest.TestCase):
    """Switch a mux between in1, in2 and __none and check what it forwards.

    The test publishes to the ``in1`` and ``in2`` topics, switches the mux via
    the ``mux/select`` service and observes both the ``out`` topic and the
    ``mux/selected`` topic to verify that exactly the selected input is
    forwarded and that every switch is announced exactly once.
    """

    # Upper bound (seconds) on any wait for a message that is expected to arrive. It is generous so that a slow
    # start-up does not cause spurious failures, yet it turns a missing message into a test failure instead of a hang.
    WAIT_TIMEOUT = 30.0
    # How long (seconds) to watch ``out`` when the mux is expected to forward nothing.
    SILENCE_TIMEOUT = 1.0
    # Pause (seconds) after a successful wait so that unexpected extra messages are counted before asserting.
    SETTLE_TIME = 0.1

    def __init__(self, *args):
        super(MuxReplayLoop, self).__init__(*args)
        # State written by the subscriber callbacks; created here so the callbacks never see missing attributes.
        self._last_msg = self._last_selected_msg = None
        self._num_received = self._num_selected_received = 0
        rospy.init_node(NAME)

    def selectedCb(self, msg):
        """Record a message received on ``mux/selected`` and count it."""
        self._last_selected_msg = msg
        self._num_selected_received += 1

    def outCb(self, msg):
        """Record a message received on ``out`` and count it."""
        self._last_msg = msg
        self._num_received += 1

    def _reset_out(self):
        """Forget everything received on ``out`` so far."""
        self._last_msg = None
        self._num_received = 0

    def _reset_selected(self):
        """Forget everything received on ``mux/selected`` so far."""
        self._last_selected_msg = None
        self._num_selected_received = 0

    def _wait_until(self, condition, timeout=None, period=0.01, on_poll=None):
        """Poll `condition` until it holds, the timeout expires or ROS shuts down.

        :param condition: Callable without arguments returning a truthy value once the wait is over.
        :param timeout: Maximum time to wait in seconds; None means WAIT_TIMEOUT.
        :param period: Sleep between two polls in seconds.
        :param on_poll: Optional callable without arguments invoked before each sleep.
        :return: True if the condition became true, False on timeout or ROS shutdown.
        """
        if timeout is None:
            timeout = self.WAIT_TIMEOUT
        deadline = time.time() + timeout
        while not condition():
            if rospy.is_shutdown() or time.time() >= deadline:
                return False
            if on_poll is not None:
                on_poll()
            rospy.logerr_throttle(1.0, "Waiting")
            time.sleep(period)
        return True

    def _await(self, condition, what):
        """Wait for `condition`, fail the test if it never holds, then let late messages settle.

        :param condition: Callable without arguments, see _wait_until().
        :param what: Human-readable description of the awaited event, used in the failure message.
        """
        if not self._wait_until(condition):
            self.fail("Gave up waiting for %s (timeout or ROS shutdown)" % what)
        time.sleep(self.SETTLE_TIME)

    def _publish_both(self, in1_pub, in2_pub):
        """Reset the ``out`` bookkeeping and publish "in1" to in1 and "in2" to in2."""
        self._reset_out()
        in1_pub.publish(String(data="in1"))
        in2_pub.publish(String(data="in2"))

    def _assert_out(self, expected):
        """Assert that exactly one message with payload `expected` arrived on ``out``."""
        self.assertEqual(self._num_received, 1)
        self.assertIsNotNone(self._last_msg)
        self.assertEqual(self._last_msg.data, expected)

    def _assert_selected(self, expected):
        """Assert that exactly one message with payload `expected` arrived on ``mux/selected``."""
        self.assertEqual(self._num_selected_received, 1)
        self.assertIsNotNone(self._last_selected_msg)
        self.assertEqual(self._last_selected_msg.data, expected)

    def _select_and_check(self, select, topic, expected):
        """Switch the mux to `topic` and verify the switch is announced once as `expected`.

        :param select: Service proxy of the ``mux/select`` service.
        :param topic: Topic name sent in the select request.
        :param expected: Payload expected on ``mux/selected`` after the switch.
        """
        self._reset_selected()
        select(MuxSelectRequest(topic=topic))
        self._await(lambda: self._last_selected_msg is not None, "a message on mux/selected")
        self._assert_selected(expected)

    def _forward_and_check(self, in1_pub, in2_pub, expected):
        """Publish to both inputs and verify only the message `expected` is forwarded to ``out``."""
        self._publish_both(in1_pub, in2_pub)
        self._await(lambda: self._last_msg is not None, "a message on out")
        self._assert_out(expected)

    def test_replay(self):
        """Cycle the mux through in1 -> in2 -> __none -> in1 and check ``out`` and ``mux/selected``."""
        select = rospy.ServiceProxy("mux/select", MuxSelect)
        in1_pub = rospy.Publisher("in1", String, queue_size=1)
        in2_pub = rospy.Publisher("in2", String, queue_size=1)

        self._reset_out()
        self._reset_selected()

        selected_sub = rospy.Subscriber("mux/selected", String, self.selectedCb, queue_size=100)
        out_sub = rospy.Subscriber("out", String, self.outCb, queue_size=100)

        # first make sure the mux lets messages through (sometimes the first message is swallowed);
        # keep re-publishing to in1 until something shows up on out
        primer = String(data="in1")
        primed = self._wait_until(
            lambda: self._last_msg is not None, period=0.1, on_poll=lambda: in1_pub.publish(primer))
        self.assertTrue(primed, "The mux never forwarded a message from in1 to out")

        # use the mux in its initial state and publish to in1 which should get propagated to out;
        # the selected-topic bookkeeping is deliberately not reset here, so the initial announcement is what is checked
        self._publish_both(in1_pub, in2_pub)
        self._await(
            lambda: self._last_msg is not None and self._last_selected_msg is not None,
            "messages on out and mux/selected")
        self._assert_out("in1")
        self._assert_selected("/in1")

        # switch the mux to in2
        self._select_and_check(select, rospy.resolve_name(in2_pub.name), "/in2")

        # publish to in2 which should get propagated to out
        self._forward_and_check(in1_pub, in2_pub, "in2")

        # switch the mux to __none
        self._select_and_check(select, "__none", "__none")

        # publish to in1 and in2, none of which should get propagated to out; running into the timeout is the
        # expected outcome here, so the result of the wait is intentionally ignored
        self._publish_both(in1_pub, in2_pub)
        self._wait_until(lambda: self._last_msg is not None, timeout=self.SILENCE_TIMEOUT)
        time.sleep(self.SETTLE_TIME)
        self.assertEqual(self._num_received, 0)
        self.assertIsNone(self._last_msg)

        # switch the mux to in1
        self._select_and_check(select, rospy.resolve_name(in1_pub.name), "/in1")

        # publish to in1 which should get propagated to out
        self._forward_and_check(in1_pub, in2_pub, "in1")


if __name__ == "__main__":
    # Run the test case through rostest, passing NAME as the test name and forwarding the command-line arguments.
    try:
        rostest.run("rostest", NAME, MuxReplayLoop, sys.argv)
    except KeyboardInterrupt:
        # A keyboard interrupt is not reported as an error; fall through to the final message.
        pass
    print("exiting")
