#!/usr/bin/env python

import numpy as np

import cv_bridge
from jsk_topic_tools import ConnectionBasedTransport
import rospy
from sensor_msgs.msg import Image

from jsk_recognition_msgs.msg import LabelArray
from jsk_recognition_msgs.srv import SetLabels
from jsk_recognition_msgs.srv import SetLabelsResponse


def get_param_uint_list(name, default=None):
    """Fetch a ROS parameter and check it is a list of non-negative ints.

    Args:
        name: Parameter name passed to ``rospy.get_param``.
        default: Value passed to ``rospy.get_param`` as the default. With
            the default of ``None`` a missing parameter fails validation,
            because ``None`` is not a list.

    Returns:
        The parameter value, unchanged, when it passes validation.

    Raises:
        ValueError: If the value is not a list, or any element is not an
            ``int`` or is negative.
    """
    param = rospy.get_param(name, default)
    is_uint_list = isinstance(param, list) and all(
        isinstance(el, int) and el >= 0 for el in param)
    if not is_uint_list:
        # Report the parameter *name* (plus the offending value) so the
        # user can tell which parameter is misconfigured.
        raise ValueError(
            "Elements of '%s' must be integer and >=0 (got: %r)."
            % (name, param))
    return param


class ApplyContextToLabelProbability(ConnectionBasedTransport):
    """Mask a label probability image down to a set of candidate labels.

    Subscribes to ``~input`` (``Image``) and ``~input/candidates``
    (``LabelArray``), offers the ``~update_candidates`` service
    (``SetLabels``), and publishes the renormalized probability image on
    ``~output`` and its per-pixel argmax on ``~output/label``.

    The candidate list is read initially from the ``~candidates`` parameter
    and can be replaced through the service or the candidates topic.
    Labels in ``~candidates_fixed`` are always added to the current
    candidates and are never replaced by updates.
    """

    def __init__(self):
        # Name the class explicitly: ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super(ApplyContextToLabelProbability, self).__init__()

        # One converter is enough; created before any callback can fire.
        self._bridge = cv_bridge.CvBridge()

        # list of label values
        self.candidates = get_param_uint_list('~candidates', [])

        # candidates which are never updated
        self.candidates_fixed = get_param_uint_list('~candidates_fixed', [])

        rospy.Service('~update_candidates', SetLabels, self._update_candidates)
        self.pub_proba = self.advertise('~output', Image, queue_size=1)
        self.pub_label = self.advertise('~output/label', Image, queue_size=1)

    def subscribe(self):
        """Start listening to the input image and candidates topics."""
        self._sub_img = rospy.Subscriber('~input', Image, self._apply)
        self._sub_candidates = rospy.Subscriber(
            '~input/candidates', LabelArray, self._update_candidates_with_topic)

    def unsubscribe(self):
        """Unregister the subscribers created by ``subscribe``."""
        self._sub_img.unregister()
        self._sub_candidates.unregister()

    def _update_candidates(self, req):
        """Service handler: replace the candidates with ``req.labels``.

        The new list is also written back to the ``~candidates`` parameter.

        Returns:
            ``SetLabelsResponse`` with ``success=True``.
        """
        self.candidates = list(req.labels)
        rospy.set_param('~candidates', self.candidates)
        return SetLabelsResponse(success=True)

    def _update_candidates_with_topic(self, candidates_msg):
        """Topic handler: replace the candidates with the message label ids.

        The ``id`` field of every entry in ``candidates_msg.labels`` is
        taken; the new list is also written to the ``~candidates`` parameter.
        """
        self.candidates = [label.id for label in candidates_msg.labels]
        rospy.set_param('~candidates', self.candidates)

    def _apply(self, imgmsg):
        """Mask, renormalize and publish the incoming probability image.

        Channels whose index is not a current candidate are zeroed (only
        when the candidate list is non-empty), every pixel is then divided
        by the sum over its channels, and both the resulting image and its
        channel-wise argmax (as ``32SC1``) are published with the input
        header. Images that are not 3-dimensional are rejected with an
        error log.
        """
        proba_img = self._bridge.imgmsg_to_cv2(imgmsg).copy()  # copy to change it
        if proba_img.ndim != 3:
            rospy.logerr('Image shape must be (height, width, channels).')
            return

        # Read the attribute once: the service/topic callbacks may rebind it
        # while this callback runs, and a list that turns empty after the
        # truthiness check would make max() below fail on an empty set.
        current = self.candidates
        if current:
            # current candidates
            candidates = set(current) | set(self.candidates_fixed)

            n_labels = proba_img.shape[2]
            if max(candidates) >= n_labels:
                rospy.logwarn_throttle(
                    10, "The max label value in '~candidates' exceeds "
                        "the number of input labels.")

            for lbl in range(n_labels):
                if lbl not in candidates:
                    proba_img[:, :, lbl] = 0.

        # do dynamic scaling for the probability image
        total = proba_img.sum(axis=-1, keepdims=True)
        # A pixel whose remaining channels sum to zero would become NaN
        # (0 / 0); dividing by 1 instead leaves it all-zero.
        total[total == 0] = 1
        proba_img /= total

        out_proba_msg = self._bridge.cv2_to_imgmsg(proba_img)
        out_proba_msg.header = imgmsg.header
        self.pub_proba.publish(out_proba_msg)

        label_img = proba_img.argmax(axis=-1).astype(np.int32)
        out_label_msg = self._bridge.cv2_to_imgmsg(label_img, encoding='32SC1')
        out_label_msg.header = imgmsg.header
        self.pub_label.publish(out_label_msg)


if __name__ == '__main__':
    # Start the ROS node, construct the transport (its constructor registers
    # the service and advertises the outputs), then block in rospy.spin().
    rospy.init_node('apply_context_to_label_probablity')
    node = ApplyContextToLabelProbability()
    rospy.spin()
