#!/usr/bin/env python

# The script was originally written by Brian Gerkey under
# the works of the Virtual Robotics Challenge
#
# Copyright Open Source Robotics Foundation
#

from __future__ import print_function
import unittest
import rostest
import subprocess
import sys
import time
import re
import rospy

class Tester(unittest.TestCase):
    """Compare the running ROS graph with expected topics and services.

    No ``test_*`` methods are defined in the class body. The helpers below
    take the expected description as an argument and query the graph through
    the ``rostopic`` and ``rosservice`` command-line tools.
    """

    # Patterns used to pick apart `rostopic info` / `rosservice info` output.
    # Raw strings keep the backslashes from being read as string escapes.
    _TOPIC_TYPE_RE = re.compile(r'\w*Type: (.*)')
    _SERVICE_TYPE_RE = re.compile(r'Type: (.*)')
    _PUBLISHERS_RE = re.compile(r'Publishers:.*')
    _SUBSCRIBERS_RE = re.compile(r'Subscribers:.*')
    _ENDPOINT_RE = re.compile(r' *\* *([^ ]*).*')

    def _run_command(self, cmd, label):
        """Run *cmd* and return its stdout; fail the test on non-zero exit.

        cmd: argument list handed to subprocess.Popen.
        label: name of the tool, used as the prefix of the failure message.
        """
        po = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, text=True)
        out, err = po.communicate()
        self.assertEqual(po.returncode, 0,
                         '%s failed (%s). stdout: %s stderr: %s'
                         % (label, cmd, out, err))
        return out

    def _test_extra_topics(self, topics):
        """Fail if `rostopic list` reports a topic that is not in *topics*.

        topics: list of dicts, each with a 'topic' key holding the name.
        """
        out = self._run_command(['rostopic', 'list'], 'rostopic')
        # Strip each line so stray whitespace cannot produce a false "extra".
        topics_actual = {line.strip() for line in out.split('\n')} - {''}
        topics_expected = {x['topic'] for x in topics}
        self.assertEqual(topics_actual - topics_expected, set())

    def _test_extra_services(self, services):
        """Fail if `rosservice list` reports a service not in *services*.

        services: list of dicts, each with a 'service' key holding the name.
        """
        out = self._run_command(['rosservice', 'list'], 'rosservice')
        services_actual = {line.strip() for line in out.split('\n')} - {''}
        services_expected = {x['service'] for x in services}
        self.assertEqual(services_actual - services_expected, set())

    def _test_topic(self, t):
        """Check one topic's type and endpoint counts via `rostopic info`.

        t: dict with the keys 'topic', 'type', 'num_publishers' and
           'num_subscribers'.
        """
        self.assertIn('topic', t)
        self.assertIn('type', t)
        self.assertIn('num_publishers', t)
        self.assertIn('num_subscribers', t)

        out = self._run_command(['rostopic', 'info', t['topic']],
                                'rostopic info')
        self._parse_rostopic(t, out)

    def _parse_rostopic(self, t, out):
        """Validate `rostopic info` output *out* against the expectation *t*.

        The first line must carry the message type. Lines of the form
        ``* name ...`` are counted as publishers or subscribers depending on
        which section header precedes them. A negative expected count
        disables the corresponding check.
        """
        # Should probably do this through a library API instead...

        # Step 0: make sure we have enough output
        outsplit = out.split('\n')
        self.assertGreaterEqual(len(outsplit), 5)

        # Step 1: check type. Assert on the match itself so unexpected output
        # is reported as a test failure rather than an AttributeError.
        m = self._TOPIC_TYPE_RE.match(outsplit[0])
        self.assertIsNotNone(
            m, 'no type found in rostopic info output: %r' % (outsplit[0],))
        self.assertEqual(m.group(1), t['type'])

        # Step 2: check num_publishers and num_subscribers
        state = None
        pubs = 0
        subs = 0
        for line in outsplit:
            if self._PUBLISHERS_RE.match(line):
                state = 'in_pubs'
            elif self._SUBSCRIBERS_RE.match(line):
                state = 'in_subs'
            elif self._ENDPOINT_RE.match(line):
                if state == 'in_pubs':
                    pubs += 1
                elif state == 'in_subs':
                    subs += 1
        if t['num_publishers'] >= 0:
            self.assertEqual(pubs, t['num_publishers'])
        if t['num_subscribers'] >= 0:
            self.assertEqual(subs, t['num_subscribers'])

    def _test_service(self, s):
        """Check one service's type via `rosservice info`.

        s: dict with the keys 'service' and 'type'.
        """
        self.assertIn('service', s)
        self.assertIn('type', s)

        out = self._run_command(['rosservice', 'info', s['service']],
                                'rosservice info')
        self._parse_rosservice(s, out)

    def _parse_rosservice(self, s, out):
        """Validate `rosservice info` output *out* against expectation *s*.

        The service type is read from the third line of the output.
        """
        # Should probably do this through a library API instead...

        # Step 0: make sure we have enough output
        outsplit = out.split('\n')
        self.assertGreaterEqual(len(outsplit), 4)

        # Step 1: check type; fail cleanly if the line is not a type line.
        m = self._SERVICE_TYPE_RE.match(outsplit[2])
        self.assertIsNotNone(
            m, 'no type found in rosservice info output: %r' % (outsplit[2],))
        self.assertEqual(m.group(1), s['type'])

def load_config(files):
    """Load the expected topics and services from YAML config files.

    files: iterable of command-line arguments. Entries starting with '--'
        or '__' are skipped (args passed in by rostest); every other entry
        is opened as a YAML file.

    Each file must contain a 'topics' list; 'services' and 'strict' are
    optional.

    Returns a tuple (topics, services, strict): the concatenated topic and
    service entries of all files, and True if any file sets 'strict' to a
    true value.
    """
    import yaml
    topics = []
    services = []
    strict = False
    for f in files:
        # Ignore args passed in by rostest
        if f.startswith(('--', '__')):
            continue
        # Let parsing exceptions leak out; we'll catch them by noticing the
        # absence of a test result file.
        # safe_load is used because yaml.load() without an explicit Loader is
        # deprecated and rejected by recent PyYAML releases; the `with` block
        # makes sure the file is closed again.
        with open(f) as stream:
            y = yaml.safe_load(stream)
        topics.extend(y['topics'])
        # Tolerate a 'services:' key that is present but empty.
        services.extend(y.get('services') or ())

        # TODO: This logic will enforce strictness if any of the provided files
        # sets strict to true, which isn't necessarily the right thing.
        if y.get('strict'):
            strict = True
    return topics, services, strict

def generate_topic_test(t):
    """Return a one-argument function that calls ``_test_topic(t)`` on it.

    The argument of the returned function takes the place of ``self``, so
    the function can be attached to a test-case class as a method.
    """
    def check_topic(case):
        # `t` is captured from the enclosing call.
        case._test_topic(t)
    return check_topic

def generate_service_test(t):
    """Return a one-argument function that calls ``_test_service(t)`` on it.

    The argument of the returned function takes the place of ``self``, so
    the function can be attached to a test-case class as a method.
    """
    def check_service(case):
        # `t` is captured from the enclosing call.
        case._test_service(t)
    return check_service

def generate_test_extra_topics(topics):
    """Return a one-argument function calling ``_test_extra_topics(topics)``.

    The argument of the returned function takes the place of ``self``, so
    the function can be attached to a test-case class as a method.
    """
    def check_extra_topics(case):
        # `topics` is captured from the enclosing call.
        case._test_extra_topics(topics)
    return check_extra_topics

def generate_test_extra_services(services):
    """Return a one-argument function calling ``_test_extra_services(services)``.

    The argument of the returned function takes the place of ``self``, so
    the function can be attached to a test-case class as a method.
    """
    def check_extra_services(case):
        # `services` is captured from the enclosing call.
        case._test_extra_services(services)
    return check_extra_services

def add_tests(topics, services, strict):
    """Attach generated test methods to the Tester class.

    One ``test_topic_<name>`` method is added per entry of *topics* and one
    ``test_service_<name>`` method per entry of *services*, where <name> is
    the topic/service name with every '/' replaced by '_'. When *strict* is
    true, ``test_extra_topics`` and ``test_extra_services`` are added too.
    """
    for topic_spec in topics:
        method_name = "test_topic_%s"%(topic_spec['topic'].replace('/','_'))
        setattr(Tester, method_name, generate_topic_test(topic_spec))

    for service_spec in services:
        method_name = "test_service_%s"%(service_spec['service'].replace('/','_'))
        setattr(Tester, method_name, generate_service_test(service_spec))

    if not strict:
        return

    # Strict mode: also reject anything on the graph that was not listed.
    setattr(Tester, "test_extra_topics", generate_test_extra_topics(topics))
    setattr(Tester, "test_extra_services",
            generate_test_extra_services(services))

if __name__ == '__main__':
    rospy.init_node('rosapi_checker', anonymous=True)

    # Read the expectations from the config files named on the command line.
    topics, services, strict = load_config(sys.argv[1:])

    # The rostest node itself will advertise a couple of services, so they
    # have to be part of the expected set.
    node_name = rospy.get_name()
    services.extend([
        {'service': '%s/get_loggers'%(node_name),
         'type': 'roscpp/GetLoggers'},
        {'service': '%s/set_logger_level'%(node_name),
         'type': 'roscpp/SetLoggerLevel'},
    ])

    # Dynamically generate test cases and stuff them into the Tester class
    add_tests(topics, services, strict)

    # Block while ROS time still reads zero, i.e. until /clock is being
    # published; this can take an unpredictable amount of time when we're
    # downloading models.
    while rospy.Time.now().to_sec() == 0.0:
        print('Waiting for Gazebo to start...')
        time.sleep(1.0)

    # Take an extra nap, to allow plugins to be loaded
    time.sleep(5.0)
    print('OK, starting test.')

    rostest.run('gazebo_ros', 'api_check', Tester, sys.argv)
