#!/usr/bin/env python

# SPDX-License-Identifier: BSD-3-Clause
# SPDX-FileCopyrightText: Czech Technical University in Prague

# Test helper for testing mux_replay.

from __future__ import print_function

import rospy
import rostest
import unittest
import time
import sys

from std_msgs.msg import String

NAME = "test_mux_replay"


class MuxReplay(unittest.TestCase):
    """Integration test driving a mux via ``selected`` and checking ``/out``.

    The test publishes to the ``in1`` and ``in2`` topics, switches the mux by
    publishing topic names to ``selected`` and verifies what arrives on
    ``/out`` and ``/mux/selected``.
    """

    def __init__(self, *args):
        super(MuxReplay, self).__init__(*args)
        rospy.init_node(NAME)

    def selectedCb(self, msg):
        """Remember the most recent message received on ``/mux/selected``."""
        self._last_selected_msg = msg

    def outCb(self, msg):
        """Remember the most recent ``/out`` message and count it."""
        self._last_msg = msg
        self._num_received += 1

    def _reset_output(self):
        """Forget everything received on ``/out`` so far."""
        self._last_msg = None
        self._num_received = 0

    def _wait_until(self, done, period=0.01, max_polls=None, before_poll=None):
        """Poll until ``done()`` is true or ROS is shutting down.

        :param done: Zero-argument callable returning True once waiting may stop.
        :param period: Sleep between two polls, in seconds.
        :param max_polls: If not None, give up after this many polls; if None,
                          the wait is unbounded.
        :param before_poll: Optional zero-argument callable run at the start of
                            every poll iteration (e.g. to re-publish a message).
        """
        polls = 0
        while not rospy.is_shutdown() and not done():
            if max_polls is not None and polls >= max_polls:
                break
            if before_poll is not None:
                before_poll()
            rospy.logerr_throttle(1.0, "Waiting")
            time.sleep(period)
            polls += 1

    def _publish_to_both_inputs(self):
        """Reset the ``/out`` bookkeeping and publish one message to each input.

        ``in1`` receives a message with data "in1" and ``in2`` one with "in2",
        so the data of whatever shows up on ``/out`` identifies its source.
        """
        self._reset_output()
        self._in1_pub.publish(String(data="in1"))
        self._in2_pub.publish(String(data="in2"))

    def _assert_single_output(self, expected):
        """Assert that exactly one message with data ``expected`` reached ``/out``.

        :param expected: Expected ``data`` of the last ``/out`` message.
        """
        # Give a possible second (unwanted) message some time to arrive.
        time.sleep(0.1)
        self.assertEqual(self._num_received, 1)
        self.assertIsNotNone(self._last_msg)
        self.assertEqual(self._last_msg.data, expected)

    def _select(self, topic, expected):
        """Ask the mux to switch to ``topic`` and check what it reports back.

        :param topic: String published on ``selected``.
        :param expected: Expected ``data`` of the next ``/mux/selected`` message.
        """
        self._last_selected_msg = None
        self._select_pub.publish(String(data=topic))
        self._wait_until(lambda: self._last_selected_msg is not None)
        self.assertIsNotNone(self._last_selected_msg)
        self.assertEqual(self._last_selected_msg.data, expected)

    def test_replay(self):
        """Walk the mux through in1 -> in2 -> __none -> in1 and check ``/out``."""
        self._select_pub = rospy.Publisher("selected", String, queue_size=1)
        self._in1_pub = rospy.Publisher("in1", String, queue_size=1)
        self._in2_pub = rospy.Publisher("in2", String, queue_size=1)

        # The callbacks touch these attributes, so create them before subscribing.
        self._last_selected_msg = None
        self._reset_output()

        # Stored on the instance so the subscriber objects stay referenced.
        self._selected_sub = rospy.Subscriber(
            "/mux/selected", String, self.selectedCb, queue_size=100)
        self._out_sub = rospy.Subscriber("/out", String, self.outCb, queue_size=100)

        # first make sure the mux lets messages through (sometimes the first message is swallowed)

        self._reset_output()
        warmup = String(data="in1")
        self._wait_until(
            lambda: self._last_msg is not None,
            period=0.1,
            before_poll=lambda: self._in1_pub.publish(warmup))

        # use the mux in its initial state and publish to in1 which should get propagated to out

        self._publish_to_both_inputs()
        self._wait_until(
            lambda: self._last_msg is not None and self._last_selected_msg is not None)
        self._assert_single_output("in1")
        self.assertIsNotNone(self._last_selected_msg)
        self.assertEqual(self._last_selected_msg.data, "/in1")

        # switch the mux to in2

        self._select(rospy.resolve_name(self._in2_pub.name), "/in2")

        # publish to in2 which should get propagated to out

        self._publish_to_both_inputs()
        self._wait_until(lambda: self._last_msg is not None)
        self._assert_single_output("in2")

        # switch the mux to __none

        self._select("__none", "__none")

        # publish to in1 and in2, none of which should get propagated to out

        self._publish_to_both_inputs()
        # Nothing is expected here, so the wait has to be bounded (100 polls of 10 ms).
        self._wait_until(lambda: self._last_msg is not None, max_polls=100)

        time.sleep(0.1)
        self.assertEqual(self._num_received, 0)
        self.assertIsNone(self._last_msg)

        # switch the mux to in1

        self._select(rospy.resolve_name(self._in1_pub.name), "/in1")

        # publish to in1 which should get propagated to out

        self._publish_to_both_inputs()
        self._wait_until(lambda: self._last_msg is not None)
        self._assert_single_output("in1")


if __name__ == "__main__":
    # Short pause before the test run is started.
    time.sleep(0.75)
    try:
        run_args = ('rostest', NAME, MuxReplay, sys.argv)
        rostest.run(*run_args)
    except KeyboardInterrupt:
        # An interrupt is not an error here; continue to the final message.
        pass
    print("exiting")
