#!/usr/bin/env python
'''
check_radius.py, v1.0.0

Performs a check on basic radius authentication. Returns perfdata on how long
it takes to return the authentication.

--- 

Basic RADIUS authentication. Minimum necessary to be able to authenticate a
user with or without challenge/response, yet remain RFC2865 compliant (I hope).

py-radius library: http://github.com/btimby/py-radius/
'''
# "py-radius" Library Code:
# Copyright (c) 1999, Stuart Bishop <zen@shangri-la.dropbear.id.au>
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     Redistributions of source code must retain the above copyright
#     notice, this list of conditions and the following disclaimer.
#
#     Redistributions in binary form must reproduce the above copyright
#     notice, this list of conditions and the following disclaimer in the
#     documentation and/or other materials provided with the
#     distribution.
#
#     The name of Stuart Bishop may not be used to endorse or promote
#     products derived from this software without specific prior written
#     permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import socket
import logging
import struct

# Not part of the py-radius library
import sys
import traceback
import optparse
from timeit import default_timer as timer
# -----

from select import select
from random import randint
from contextlib import closing, contextmanager

try:
    from collections import UserDict
except ImportError:
    from UserDict import UserDict

try:
    from hashlib import md5
except ImportError:
    from md5 import new as md5

# py-radius library version is 2.0.2
# plugin version is 1.0.0
__version__ = '1.0.0'

# Module-level logger used by the RADIUS client code below.
LOGGER = logging.getLogger(__name__)

# Networking constants.
# -------------------------------
# Receive buffer size passed to recv(); RADIUS packets are at most 4096 octets.
PACKET_MAX = 4096
# Default server UDP port (the RADIUS authentication port).
DEFAULT_PORT = 1812
# Default number of times a request is sent before giving up.
DEFAULT_RETRIES = 3
# Default seconds to wait for a reply after each send (used with select()).
DEFAULT_TIMEOUT = 5
# -------------------------------

# Protocol specific constants.
# -------------------------------
# Packet type codes: the first octet of every RADIUS packet.
CODE_ACCESS_REQUEST = 1
CODE_ACCESS_ACCEPT = 2
CODE_ACCESS_REJECT = 3
CODE_ACCOUNTING_REQUEST = 4
CODE_ACCOUNTING_RESPONSE = 5
CODE_ACCESS_CHALLENGE = 11
CODE_STATUS_SERVER = 12
CODE_STATUS_CLIENT = 13
# CODE_RESERVED = 255

# Map from numeric packet code to its display name.
CODES = {
    CODE_ACCESS_REQUEST: 'Access-Request',
    CODE_ACCESS_ACCEPT: 'Access-Accept',
    CODE_ACCESS_REJECT: 'Access-Reject',
    CODE_ACCOUNTING_REQUEST: 'Accounting-Request',
    CODE_ACCOUNTING_RESPONSE: 'Accounting-Response',
    CODE_ACCESS_CHALLENGE: 'Access-Challenge',
    CODE_STATUS_SERVER: 'Status-Server',
    CODE_STATUS_CLIENT: 'Status-Client',
}

# Reverse map: lower-cased display name -> numeric packet code.
CODE_NAMES = dict((name.lower(), code) for code, name in CODES.items())

# Attributes that can be part of the RADIUS payload (RFC 2865, section 5).
ATTR_USER_NAME = 1
ATTR_USER_PASSWORD = 2
# RFC 2865 5.3: CHAP-Password is type 3; it must not share 4 with
# NAS-IP-Address, or one name silently overwrites the other in ATTRS.
ATTR_CHAP_PASSWORD = 3
ATTR_NAS_IP_ADDRESS = 4
ATTR_NAS_PORT = 5
ATTR_SERVICE_TYPE = 6
ATTR_FRAMED_PROTOCOL = 7
ATTR_FRAMED_IP_ADDRESS = 8
ATTR_FRAMED_IP_NETMASK = 9
ATTR_FRAMED_ROUTING = 10
ATTR_FILTER_ID = 11
ATTR_FRAMED_MTU = 12
ATTR_FRAMED_COMPRESSION = 13
ATTR_LOGIN_IP_HOST = 14
ATTR_LOGIN_SERVICE = 15
ATTR_LOGIN_TCP_PORT = 16
# ATTR_UNASSIGNED = 17
ATTR_REPLY_MESSAGE = 18
ATTR_CALLBACK_NUMBER = 19
ATTR_CALLBACK_ID = 20
# ATTR_UNASSIGNED = 21
ATTR_FRAMED_ROUTE = 22
ATTR_FRAMED_IPX_NETWORK = 23
ATTR_STATE = 24
ATTR_CLASS = 25
ATTR_VENDOR_SPECIFIC = 26
ATTR_SESSION_TIMEOUT = 27
ATTR_IDLE_TIMEOUT = 28
ATTR_TERMINATION_ACTION = 29
ATTR_CALLED_STATION_ID = 30
ATTR_CALLING_STATION_ID = 31
ATTR_NAS_IDENTIFIER = 32
ATTR_PROXY_STATE = 33
ATTR_LOGIN_LAT_SERVICE = 34
ATTR_LOGIN_LAT_NODE = 35
ATTR_LOGIN_LAT_GROUP = 36
ATTR_FRAMED_APPLETALK_LINK = 37
ATTR_FRAMED_APPLETALK_NETWORK = 38
ATTR_FRAMED_APPLETALK_ZONE = 39
# ATTR_RESERVED = 40-59 (reserved for accounting)
ATTR_CHAP_CHALLENGE = 60
ATTR_NAS_PORT_TYPE = 61
ATTR_PORT_LIMIT = 62
ATTR_LOGIN_LAT_PORT = 63

# Map from numeric attribute type to its display name.
ATTRS = {
    ATTR_USER_NAME: 'User-Name',
    ATTR_USER_PASSWORD: 'User-Password',
    ATTR_CHAP_PASSWORD: 'CHAP-Password',
    ATTR_NAS_IP_ADDRESS: 'NAS-IP-Address',
    ATTR_NAS_PORT: 'NAS-Port',
    ATTR_SERVICE_TYPE: 'Service-Type',
    ATTR_FRAMED_PROTOCOL: 'Framed-Protocol',
    ATTR_FRAMED_IP_ADDRESS: 'Framed-IP-Address',
    ATTR_FRAMED_IP_NETMASK: 'Framed-IP-NetMask',
    ATTR_FRAMED_ROUTING: 'Framed-Routing',
    ATTR_FILTER_ID: 'Filter-Id',
    ATTR_FRAMED_MTU: 'Framed-MTU',
    ATTR_FRAMED_COMPRESSION: 'Framed-Compression',
    ATTR_LOGIN_IP_HOST: 'Login-IP-Host',
    ATTR_LOGIN_SERVICE: 'Login-Service',
    ATTR_LOGIN_TCP_PORT: 'Login-TCP-Port',
    ATTR_REPLY_MESSAGE: 'Reply-Message',
    ATTR_CALLBACK_NUMBER: 'Callback-Number',
    ATTR_CALLBACK_ID: 'Callback-Id',
    ATTR_FRAMED_ROUTE: 'Framed-Route',
    ATTR_FRAMED_IPX_NETWORK: 'Framed-IPX-Network',
    ATTR_STATE: 'State',
    ATTR_CLASS: 'Class',
    ATTR_VENDOR_SPECIFIC: 'Vendor-Specific',
    ATTR_SESSION_TIMEOUT: 'Session-Timeout',
    ATTR_IDLE_TIMEOUT: 'Idle-Timeout',
    ATTR_TERMINATION_ACTION: 'Termination-Action',
    ATTR_CALLED_STATION_ID: 'Called-Station-Id',
    ATTR_CALLING_STATION_ID: 'Calling-Station-Id',
    ATTR_NAS_IDENTIFIER: 'NAS-Identifier',
    ATTR_PROXY_STATE: 'Proxy-State',
    ATTR_LOGIN_LAT_SERVICE: 'Login-LAT-Service',
    ATTR_LOGIN_LAT_NODE: 'Login-LAT-Node',
    ATTR_LOGIN_LAT_GROUP: 'Login-LAT-Group',
    ATTR_FRAMED_APPLETALK_LINK: 'Framed-AppleTalk-Link',
    ATTR_FRAMED_APPLETALK_NETWORK: 'Framed-AppleTalk-Network',
    ATTR_FRAMED_APPLETALK_ZONE: 'Framed-AppleTalk-Zone',
    ATTR_CHAP_CHALLENGE: 'CHAP-Challenge',
    ATTR_NAS_PORT_TYPE: 'NAS-Port-Type',
    ATTR_PORT_LIMIT: 'Port-Limit',
    ATTR_LOGIN_LAT_PORT: 'Login-LAT-Port',
}

# Reverse map: lower-cased attribute name -> numeric attribute type.
ATTR_NAMES = dict((name.lower(), code) for code, name in ATTRS.items())
# -------------------------------


class Error(Exception):
    """
    Root of the exception hierarchy raised by this RADIUS client.

    Catch this to handle every client-specific failure at once.
    """


class NoResponse(Error):
    """
    No valid reply was received from the RADIUS server.
    """


class ChallengeResponse(Error):
    """
    Signals that the server answered with an Access-Challenge.

    Attributes:
        messages: list of zero or more challenge messages. A single
            non-list message is wrapped in a one-element list; None
            becomes an empty list.
        state: the State value to echo back, or None if none was sent.
    """
    def __init__(self, msg=None, state=None):
        if isinstance(msg, list):
            collected = msg
        elif msg is None:
            collected = []
        else:
            collected = [msg]
        self.messages = collected
        self.state = state


class SocketError(NoResponse):
    """
    A network-level failure prevented any reply from arriving.

    Subclasses NoResponse so callers that only catch NoResponse also
    handle socket failures.
    """


# Python 2 yields 1-character str objects when indexing or iterating a byte
# string; Python 3 yields ints. On Python 3 these module-level shims let code
# written for the Python 2 behaviour keep working: ord() passes the int
# through unchanged and chr() turns it back into a one-byte bytes object.
if sys.version_info[0] >= 3:
    def ord(s):
        """Return s unchanged; byte elements are already ints on Python 3."""
        return s

    def chr(s):
        """Return a length-1 bytes object holding the byte value s."""
        return bytes((s,))


def bytes_safe(s, e='utf-8'):
    """
    Return s encoded with encoding e, or s itself if it cannot be encoded.

    Objects without an ``encode`` method (e.g. Python 3 bytes, None) and
    Python 2 byte strings that fail to decode are returned unchanged.
    """
    try:
        encoded = s.encode(e)
    except (AttributeError, UnicodeDecodeError):
        return s
    return encoded


def join(items):
    """
    Concatenate an iterable of byte strings into one byte string.
    """
    return bytes().join(items)


def authenticate(secret, username, password, host=None, port=None,
                 retries=DEFAULT_RETRIES, timeout=DEFAULT_TIMEOUT, **kwargs):
    """
    One-shot helper: build a Radius client and authenticate a user with it.

    :param secret: shared secret for the server.
    :param username: user name to authenticate.
    :param password: clear-text password.
    :param host: server host; when falsy, Radius's own default is used.
    :param port: server port; when falsy, Radius's own default is used.
    :param retries: number of send attempts.
    :param timeout: seconds to wait for a reply after each send.
    :param kwargs: forwarded to Radius.authenticate() (e.g. ``attributes``).
    :returns: True on Access-Accept, False on any other non-challenge reply.
    :raises ChallengeResponse: when the server replies with a challenge.
    :raises NoResponse: (or its subclass SocketError) if no valid reply.
    """
    # Only forward host/port when set, so Radius keeps its own defaults.
    endpoint = {}
    for option, value in (('host', host), ('port', port)):
        if value:
            endpoint[option] = value
    client = Radius(secret, retries=retries, timeout=timeout, **endpoint)
    return client.authenticate(username, password, **kwargs)


def radcrypt(secret, authenticator, password):
    """
    Hide a User-Password value (RFC 2865, section 5.2).

    The password is padded with NUL octets up to a multiple of 16 octets
    (at least 16, so an empty password still yields one block). Each
    16-octet chunk is XORed with MD5(secret + previous). For the first
    chunk, previous is the request authenticator; after that it is the
    preceding ciphertext chunk.

    :param secret: shared secret (bytes).
    :param authenticator: 16-octet request authenticator (bytes).
    :param password: clear-text password (bytes).
    :returns: the hidden password (bytes), as long as the padded input.
    :raises ValueError: if the padded password exceeds 128 octets.
    """
    # Round up to a multiple of 16. A length that is already a multiple of
    # 16 needs no extra block; an empty password becomes one NUL block.
    padded_len = max(16, (len(password) + 15) // 16 * 16)
    password = password + b'\0' * (padded_len - len(password))

    if len(password) > 128:
        raise ValueError('Password exceeds maximun of 128 bytes')

    chunks, last = [], authenticator
    for offset in range(0, len(password), 16):
        digest = bytearray(md5(secret + last).digest())
        plain = bytearray(password[offset:offset + 16])
        # bytearray iterates as ints on both Python 2 and 3, so no
        # ord()/chr() juggling is needed for the XOR.
        last = bytes(bytearray(d ^ p for d, p in zip(digest, plain)))
        chunks.append(last)
    return b''.join(chunks)


class Attributes(UserDict):
    """
    Dictionary-style container for RADIUS attributes.

    Values can be read or set by attribute name (case-insensitive, e.g.
    'User-Name') or by numeric code. Items are stored internally under their
    numeric code, and each code maps to a list of values, because an
    attribute may appear more than once in a packet. Assigning to a key
    appends to that list rather than replacing it.
    """
    def __init__(self, initialdata=None):
        UserDict.__init__(self, {})
        # Set keys via update() to invoke validation.
        if initialdata is not None:
            self.update(initialdata)

    def __getkeys(self, value):
        """
        Return ``(code, name)`` for an attribute given by code or name.

        Raises KeyError for codes or names that are not listed in ATTRS.
        """
        if isinstance(value, int):
            return value, ATTRS[value]
        code = ATTR_NAMES[value.lower()]
        return code, ATTRS[code]

    def __contains__(self, key):
        """
        Return True if at least one value is stored for key (name or code).

        Unknown attribute names or codes are reported as absent instead of
        raising KeyError.
        """
        try:
            code = self.__getkeys(key)[0]
        except KeyError:
            return False
        return UserDict.__contains__(self, code)

    def __getitem__(self, key):
        """
        Return the list of values stored for key (name or code).

        Raises KeyError if the attribute is unknown or has no values.
        """
        for k in self.__getkeys(key):
            try:
                return UserDict.__getitem__(self, k)
            except KeyError:
                continue
        raise KeyError(key)

    def __setitem__(self, key, value):
        """
        Append value to the attribute identified by key (name or code).

        Raises ValueError if key is not a known attribute.
        """
        try:
            code = self.__getkeys(key)[0]
        except KeyError:
            raise ValueError('Invalid radius attribute: %s' % key)
        values = self.get(code, [])
        values.append(value)
        UserDict.__setitem__(self, code, values)

    def update(self, data):
        """
        Add every item of data via __setitem__() so keys are validated.

        When data is another Attributes instance, its per-attribute value
        lists are copied element by element instead of being nested as a
        single list value.
        """
        if isinstance(data, Attributes):
            for code, values in data.items():
                for value in values:
                    self[code] = value
        else:
            for k, v in data.items():
                self[k] = v

    def nameditems(self):
        """
        Yield ``(name, values)`` pairs, using names instead of codes.
        """
        for k, v in self.items():
            yield self.__getkeys(k)[1], v

    def pack(self):
        """
        Pack all attribute values into the RADIUS wire format.

        Each value becomes type (1 octet), length (1 octet, including these
        two header octets) and the value bytes.

        :raises ValueError: if an encoded value exceeds 253 octets.
        """
        data = []
        for key, values in self.items():
            for value in values:
                # Encode first so the length octet counts encoded bytes,
                # not characters (matters for non-ASCII text).
                raw = bytes_safe(value)
                if len(raw) > 253:
                    raise ValueError(
                        'Attribute %s value too long (%d > 253 bytes)'
                        % (ATTRS[key], len(raw)))
                data.append(struct.pack('BB%ds' % len(raw), key,
                                        len(raw) + 2, raw))
        return join(data)

    @staticmethod
    def unpack(data):
        """
        Parse the attribute section of a packet into an Attributes instance.

        Repeated attributes keep all their values, in order. Attribute
        types not listed in ATTRS (e.g. vendor extensions or
        Message-Authenticator) are skipped rather than rejected, so a valid
        reply that contains them can still be read.

        :raises ValueError: if an attribute header or length is malformed.
        """
        attrs = Attributes()
        pos, end = 0, len(data)
        while pos < end:
            if pos + 2 > end:
                raise ValueError('Truncated attribute header at offset %d'
                                 % pos)
            code, length = struct.unpack('BB', data[pos:pos + 2])
            # A length below 2 would never advance pos (infinite loop); one
            # running past the buffer means the packet is corrupt.
            if length < 2 or pos + length > end:
                raise ValueError('Invalid attribute length %d at offset %d'
                                 % (length, pos))
            value = data[pos + 2:pos + length]
            if code in ATTRS:
                attrs[code] = value
            else:
                LOGGER.debug('Ignoring unsupported attribute type %d', code)
            pos += length
        return attrs


class Message(object):
    """
    Represents a radius protocol packet.

    This class can be used for requests and replies. The RFC dictates the
    format.

     0                   1                   2                   3
     0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |     Code      |  Identifier   |            Length             |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |                                                               |
    |                     Response Authenticator                    |
    |                                                               |
    |                                                               |
    +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    |  Attributes ...
    +-+-+-+-+-+-+-+-+-+-+-+-+-

    Code - one octet, see CODES enum.
    Identifier - one octet, unique value that represents request/response pair.
                 Provided by client and echoed by server.
    Length - two octets, the length of the packet up to the max of 4096.
    Authenticator - sixteen octets; random in a request, and in a reply an
                    MD5 signature that verify() checks.
    """

    def __init__(self, secret, code, id=None, authenticator=None,
                 attributes=None):
        """
        :param secret: shared secret (bytes), used by verify().
        :param code: packet type, one of the CODE_* constants.
        :param id: identifier 0-255; a random one is chosen if omitted.
        :param authenticator: 16-byte authenticator; random if omitted.
        :param attributes: Attributes instance or plain dict of attributes.
        """
        self.code = code
        self.secret = secret
        # Compare with None rather than truthiness: identifier 0 is valid
        # and must survive, e.g. when unpack() rebuilds a received packet.
        self.id = id if id is not None else randint(0, 255)
        self.authenticator = authenticator if authenticator else os.urandom(16)
        if isinstance(attributes, dict):
            attributes = Attributes(attributes)
        self.attributes = attributes if attributes else Attributes()

    def pack(self):
        """Pack the packet into binary form for transport."""
        # First pack the attributes, since we need to know their length.
        attrs = self.attributes.pack()
        data = []
        # Now pack the code, id, total length, authenticator
        data.append(struct.pack('!BBH16s', self.code, self.id,
                    len(attrs) + 20, self.authenticator))
        # Attributes take up the remainder of the message.
        data.append(attrs)
        return join(data)

    @staticmethod
    def unpack(secret, data):
        """
        Unpack raw packet data into a Message.

        Only the first Length octets (from the header) are parsed as
        attributes. A mismatch between Length and the received size is
        logged.
        """
        code, id, l, authenticator = struct.unpack('!BBH16s', data[:20])
        if l < len(data):
            LOGGER.warning('Too much data!')
        elif l > len(data):
            LOGGER.warning('Packet truncated: length field is %s but only '
                           '%s bytes received', l, len(data))
        attrs = Attributes.unpack(data[20:l])
        return Message(secret, code, id, authenticator, attrs)

    def verify(self, data):
        """
        Verify and unpack a response.

        Ensures that a message is a valid response to this message, then
        unpacks it. The Response Authenticator must equal
        MD5(Code + ID + Length + Request Authenticator + Attributes + Secret),
        computed over the packet as bounded by its Length field.

        :raises AssertionError: if the reply is too short, has the wrong
            identifier or an invalid Length, or fails the authenticator
            check. Radius.send_message() discards replies that raise it.
        """
        # Explicit raises instead of `assert` statements: asserts are
        # stripped under `python -O`, which would accept spoofed replies.
        if len(data) < 20:
            raise AssertionError('Response too short (%s bytes)' % len(data))
        _, id, length = struct.unpack('!BBH', data[:4])
        if self.id != id:
            raise AssertionError('ID mismatch (%s != %s)' % (self.id, id))
        if length < 20 or length > len(data):
            raise AssertionError('Invalid length field (%s, %s bytes received)'
                                 % (length, len(data)))
        # Octets past Length are padding and are excluded from the signature.
        signature = md5(
            data[:4] + self.authenticator + data[20:length] + self.secret
        ).digest()
        if signature != data[4:20]:
            raise AssertionError('Invalid authenticator')
        return Message.unpack(self.secret, data)


class Radius(object):
    """
    Radius client implementation.

    Sends Access-Request packets over UDP to a single server and returns the
    verified reply. Holds the shared secret (encoded to bytes), the server
    address, and the retry/timeout policy.
    """

    def __init__(self, secret, host='radius', port=DEFAULT_PORT,
                 retries=DEFAULT_RETRIES, timeout=DEFAULT_TIMEOUT):
        """
        :param secret: shared secret; text values are UTF-8 encoded.
        :param host: server hostname or IPv4 address.
        :param port: server UDP port.
        :param retries: how many times a request is sent before giving up.
        :param timeout: seconds to wait for a reply after each send.
        """
        self._secret = bytes_safe(secret)
        self.retries = retries
        self.timeout = timeout
        self._host = host
        self._port = port

    @property
    def host(self):
        """Server hostname or address."""
        return self._host

    @property
    def port(self):
        """Server UDP port."""
        return self._port

    @property
    def secret(self):
        """Shared secret as bytes."""
        return self._secret

    @contextmanager
    def connect(self):
        """Yield an IPv4 UDP socket connected to (host, port); closed on exit."""
        with closing(socket.socket(socket.AF_INET, socket.SOCK_DGRAM)) as sock:
            sock.connect((self.host, self.port))
            LOGGER.debug('Connected to %s:%s', self.host, self.port)
            yield sock

    @staticmethod
    def _hexdump(data):
        """Render bytes as colon-separated hex pairs for debug logging."""
        # bytearray iterates as ints on Python 2 and 3 alike.
        return ':'.join('%02x' % octet for octet in bytearray(data))

    def send_message(self, message):
        """
        Send message and return the first verified reply.

        The packet is sent up to self.retries times. After each send, the
        client waits up to self.timeout seconds for a reply. Replies that
        fail Message.verify() (AssertionError) are silently discarded, as
        the RFC requires.

        :raises SocketError: on a network error.
        :raises NoResponse: if no valid reply arrives within the retries.
        """
        send = message.pack()

        try:
            with self.connect() as sock:
                for attempt in range(self.retries):
                    # Guarded so the hex dump is only built when it is logged.
                    if LOGGER.isEnabledFor(logging.DEBUG):
                        LOGGER.debug('Sending (as hex): %s',
                                     self._hexdump(send))

                    sock.send(send)

                    readable, _, _ = select([sock], [], [], self.timeout)
                    if sock not in readable:
                        # No data available on our socket. Try again.
                        LOGGER.warning('Timeout expired on try %s', attempt)
                        continue
                    recv = sock.recv(PACKET_MAX)

                    if LOGGER.isEnabledFor(logging.DEBUG):
                        LOGGER.debug('Received (as hex): %s',
                                     self._hexdump(recv))

                    try:
                        return message.verify(recv)
                    except AssertionError as e:
                        LOGGER.warning('Invalid response discarded %s', e)
                        # Silently discard invalid replies (as RFC states).
                        continue

        except socket.error as e:  # SocketError
            LOGGER.debug('Socket error', exc_info=True)
            raise SocketError(e)

        # Uses self.retries rather than the loop variable so retries == 0
        # does not trip over an unbound name.
        error_msg = 'request timed out after %s tries' % self.retries
        raise NoResponse(error_msg)

    def access_request_message(self, username, password, **kwargs):
        """
        Build an Access-Request carrying User-Name and a hidden User-Password.

        :param kwargs: forwarded to Message() (e.g. id, attributes).
        """
        username = bytes_safe(username)
        password = bytes_safe(password)

        message = Message(self.secret, CODE_ACCESS_REQUEST, **kwargs)
        message.attributes['User-Name'] = username
        # The password is hidden with the request's own authenticator.
        message.attributes['User-Password'] = \
            radcrypt(self.secret, message.authenticator, password)

        return message

    def authenticate(self, username, password, **kwargs):
        """
        Attempt to authenticate with the given username and password.

           Returns True on Access-Accept
           Returns False on any other reply that is not a challenge
           Raises ChallengeResponse (with Reply-Message values and State)
               on Access-Challenge
           Raises a NoResponse (or its subclass SocketError) exception if no
               responses or no valid responses are received
        """
        reply = self.send_message(
            self.access_request_message(username, password, **kwargs))

        if reply.code == CODE_ACCESS_ACCEPT:
            LOGGER.info('Access accepted')
            return True

        elif reply.code == CODE_ACCESS_CHALLENGE:
            LOGGER.info('Access challenged')
            messages = reply.attributes.get('Reply-Message', None)
            state = reply.attributes.get('State', [None])[0]
            raise ChallengeResponse(messages, state)

        LOGGER.info('Access rejected')
        return False

#
# ---------------------------------------------
# End of py-radius python module
#


# Parse the command line options for the required values
def parse_args():
    """
    Parse and validate the plugin's command-line options.

    Prints the version and exits with status 0 for -V/--version. Exits with
    status 3 (UNKNOWN in the Nagios plugin convention) when the shared
    secret, hostname, username or password is missing.

    :returns: the optparse options object.
    """
    version = 'check_radius.py v%s' % __version__
    parser = optparse.OptionParser()

    # Options required by nagios-plugins
    parser.add_option("-V", "--version", action="store_true",
                      help="Print the version number of the plugin")
    parser.add_option("-v", "--verbose", default=False, action="store_true",
                      help="Print out verbose output")

    # Connection and credential options
    parser.add_option("-H", "--hostname", help="The hostname of the RADIUS "
                      "server to connect to")
    parser.add_option("-P", "--port", default=DEFAULT_PORT,
                      help="The port of the RADIUS server")
    parser.add_option("-u", "--username", help="The username to authenticate")
    parser.add_option("-p", "--password", help="The password of the auth user")
    parser.add_option("-s", "--secret",
                      help="The shared secret for the RADIUS server")
    parser.add_option("-t", "--timeout", default=DEFAULT_TIMEOUT,
                      help="Seconds to wait for a reply on each try before "
                      "retrying")
    parser.add_option("-r", "--retries", default=DEFAULT_RETRIES,
                      help="The number of authentication retries")

    # Attribute options
    parser.add_option("-c", "--chresponse", help="Response to challenge message")
    parser.add_option("-a", "--attributes", help="Location of attributes file")

    # Do actual argument parsing and check validity
    options, _args = parser.parse_args()

    if options.version:
        print(version)
        sys.exit(0)

    # Check to make sure we have the proper options
    if not options.secret:
        print("You must pass the shared secret with -s|--secret")
        sys.exit(3)

    if not options.hostname or not options.username or not options.password:
        print("You must specify a -H|--hostname, -u|--username, -p|--password")
        sys.exit(3)

    return options


# Run the actual check
def main():
    """
    Run the check and exit with a Nagios status code.

    Exit codes: 0 (OK) when authentication succeeds, 2 (CRITICAL) when it
    fails or errors, 3 (UNKNOWN) for unusable numeric options or an
    unreadable attributes file. Success and failure output include the
    elapsed authentication time in milliseconds as perfdata.
    """
    options = parse_args()

    host = options.hostname
    secret = options.secret
    username = options.username
    password = options.password
    verbose = options.verbose
    try:
        port = int(options.port)
        retries = int(options.retries)
        # select() accepts fractional seconds, so allow e.g. "-t 2.5".
        timeout = float(options.timeout)
    except ValueError as e:
        print('UNKNOWN: Invalid numeric option (' + str(e) + ')')
        sys.exit(3)

    if verbose:
        # Let -v also enable the client's debug logging (packet hex dumps).
        LOGGER.setLevel(logging.DEBUG)

    # Optional extra attributes, one "attr=value" per line.
    attrs = None
    if options.attributes:
        try:
            attrs = _read_attributes(options.attributes)
        except (IOError, OSError, ValueError) as e:
            print('UNKNOWN: Could not load attributes file (' + str(e) + ')')
            sys.exit(3)

    def _do_output(status):
        # Report the result with elapsed time (ms) as perfdata, then exit.
        elapsed_ms = round((timer() - start) * 1000, 3)
        perfdata = " | 'auth_time'=" + str(elapsed_ms) + "ms"
        if status:
            print("OK: Authentication succeeded" + perfdata)
            sys.exit(0)
        else:
            print("CRITICAL: Authentication failed" + perfdata)
            sys.exit(2)

    conn = dict(host=host, port=port, timeout=timeout, retries=retries)

    # Try to connect
    start = timer()
    try:
        if attrs is not None:
            status = authenticate(secret, username, password,
                                  attributes=attrs, **conn)
        else:
            status = authenticate(secret, username, password, **conn)
    except ChallengeResponse as challenge:
        # Answer the challenge with the -c response, echoing the server's
        # State attribute when one was sent (challenge.messages holds any
        # Reply-Message values).
        a = Attributes()
        if challenge.state:
            a['State'] = challenge.state
        try:
            status = authenticate(secret, username, options.chresponse,
                                  attributes=a, **conn)
        except Exception:
            if verbose:
                traceback.print_exc()
            status = False
    except Exception as e:
        if verbose:
            traceback.print_exc()
        print('CRITICAL: Authentication failed (' + str(e) + ')')
        sys.exit(2)

    _do_output(status)


def _read_attributes(path):
    """
    Load RADIUS attributes from a text file of ``attr=value`` lines.

    Surrounding whitespace (including the line terminator) is stripped.
    Blank lines and lines starting with '#' are skipped. Only the first '='
    splits name from value, so values may contain '='. Lines with an empty
    name or value are ignored.

    :returns: an Attributes instance.
    :raises ValueError: for a line without '=' or an unknown attribute name.
    :raises IOError: if the file cannot be opened.
    """
    attrs = Attributes()
    with open(path) as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            key, sep, val = line.partition('=')
            if not sep:
                raise ValueError('Expected "attr=value", got %r' % line)
            key, val = key.strip(), val.strip()
            if key and val:
                attrs[key] = val
    return attrs


if __name__ == '__main__':
    # Route the client's log records to a stream handler (stderr by
    # default), showing only errors unless something raises the level.
    _stderr_handler = logging.StreamHandler()
    LOGGER.addHandler(_stderr_handler)
    LOGGER.setLevel(logging.ERROR)
    main()
