Files
pyrad3/src/pyrad3/packet.py
2020-08-12 10:22:01 +02:00

346 lines
11 KiB
Python

# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Class for RADIUS Packet"""
from collections import OrderedDict
from secrets import token_bytes
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import hashlib
import hmac
from pyrad3.host import Host
from pyrad3.code import Code
from pyrad3.utils import (
PacketError,
Attribute,
parse_header,
parse_attributes,
calculate_authenticator,
validate_pap_password,
validate_chap_password,
)
# Short alias for the HMAC constructor used when computing Message-Authenticators.
from hmac import new as HMAC
class Packet(OrderedDict):
    """Generic RADIUS Packet

    Usually this should be used as a Super-Class for specific RADIUS Packets.
    The mapping part of the object holds the attributes keyed by name;
    packets built by ``from_raw`` store every value as a list (one entry per
    occurrence of the attribute, in wire order).
    """

    def __init__(
        self,
        host: Host,
        code: Code,
        radius_id: int,
        *,
        request: Optional["Packet"] = None,
        **attributes
    ):
        super().__init__(**attributes)
        self.code = code
        self.identifier = radius_id
        self.host = host
        # The packet this one answers (set for replies, e.g. by from_raw_reply).
        self.request = request
        # Parsed attributes in wire order; filled by from_raw.
        self.ordered_attributes: Sequence[Attribute] = []
        # Wire representation, when the packet was parsed (or encoded).
        self.raw_packet: Optional[bytearray] = None
        self.authenticator: Optional[bytes] = None

    @staticmethod
    def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
        """Parse the given bytearray to a RADIUS Packet.

        Every attribute value is collected into a list, so attributes that
        occur several times are all kept, in wire order.
        """
        (code, radius_id, _length, authenticator) = parse_header(raw_packet)
        ordered_attrs = parse_attributes(host.dictionary, raw_packet)
        # Can we do better than Any with type hinting?
        attrs: Dict[str, Any] = {}
        for attr in ordered_attrs:
            attrs.setdefault(attr.name, []).append(attr.value)
        parsed_packet = Packet(host, code, radius_id)
        # update() instead of **attrs: attribute names must never collide
        # with constructor parameters such as ``request``.
        parsed_packet.update(attrs)
        parsed_packet.authenticator = authenticator
        parsed_packet.raw_packet = raw_packet
        parsed_packet.ordered_attributes = ordered_attrs
        return parsed_packet

    def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
        """Parse and verify ``raw_packet`` as the reply to this packet.

        Raises:
            PacketError: if the identifier or Response Authenticator does not
                match, the Message-Authenticator is wrong, or an EAP-Message
                is present without a Message-Authenticator.
        """
        self.verify_reply(raw_packet)
        reply = Packet.from_raw(self.host, raw_packet)
        reply.request = self
        if "Message-Authenticator" in reply:
            if not reply.validate_message_authenticator():
                raise PacketError("Packet has a wrong message authenticator")
        elif "EAP-Message" in reply:
            # RFC 3579: a packet carrying EAP-Message MUST also carry a
            # Message-Authenticator.
            raise PacketError("Packet is missing a message authenticator")
        return reply

    def send(self):
        """Send the packet to the Client/Server."""
        self.host._send_packet(self)  # pylint: disable=protected-access

    def verify_reply(self, raw_reply: bytes):
        """Verify identifier and Response Authenticator of a raw reply.

        Raises:
            PacketError: if the reply is shorter than a RADIUS header, has a
                different identifier, or a wrong Response Authenticator.
        """
        # The fixed RADIUS header (code, id, length, authenticator) is 20 bytes.
        if len(raw_reply) < 20:
            raise PacketError("Reply packet is too short")
        if self.identifier != raw_reply[1]:
            raise PacketError("Response has a wrong id")
        # self.authenticator MUST be set: this packet was sent, so by
        # definition self.authenticator is bytes and not None.
        radius_hash = calculate_authenticator(
            self.host.secret,
            self.authenticator,  # type: ignore
            raw_reply,
        )
        # Constant-time comparison: the reply comes from the network.
        if not hmac.compare_digest(radius_hash, raw_reply[4:20]):
            raise PacketError("Reply Packet has a wrong authenticator")

    def validate_message_authenticator(self) -> bool:
        """Validate the Message-Authenticator within this RADIUS Packet.

        Returns False when the attribute occurs more than once or does not
        match the locally computed value.

        Raises:
            KeyError: if the packet has no Message-Authenticator.
        """
        received = self["Message-Authenticator"]
        if isinstance(received, list):
            # A packet MUST NOT contain more than one Message-Authenticator.
            if len(received) != 1:
                return False
            received = received[0]
        if not isinstance(received, (bytes, bytearray)):
            return False
        ma_attribute = self.find_first_attribute("Message-Authenticator")
        generated = self._generate_message_authenticator(ma_attribute)
        return hmac.compare_digest(received, generated)

    def _generate_message_authenticator(self, ma_attr: Attribute) -> bytes:
        """Calculate the Message-Authenticator for this packet.

        HMAC-MD5 over Type, Identifier, Length, (Request) Authenticator and
        the attributes, with the Message-Authenticator value treated as 16
        zero bytes. ``raw_packet`` is restored afterwards.
        """
        assert self.raw_packet is not None
        # Skip the attribute's type and length octets.
        start_pos = ma_attr.pos + 2
        end_pos = start_pos + 16
        original_ma = bytes(self.raw_packet[start_pos:end_pos])

        if self.code in (Code.AccessRequest, Code.StatusServer):
            assert self.authenticator is not None
            authenticator = self.authenticator
        elif self.code in (
            Code.AccessAccept,
            Code.AccessChallenge,
            Code.AccessReject,
        ):
            # Replies are authenticated with the request's authenticator.
            assert self.request is not None
            assert self.request.authenticator is not None
            authenticator = self.request.authenticator
        else:
            authenticator = 16 * b"\00"

        self.raw_packet[start_pos:end_pos] = 16 * b"\00"
        try:
            hmac_builder = HMAC(self.host.secret, digestmod=hashlib.md5)
            # Type, Identifier, Length
            hmac_builder.update(self.raw_packet[:4])
            hmac_builder.update(authenticator)
            # Attributes
            hmac_builder.update(self.raw_packet[20:])
        finally:
            self.raw_packet[start_pos:end_pos] = original_ma
        return hmac_builder.digest()

    def add_message_authenticator(self):
        """Add (or refresh) the Message-Authenticator of the RADIUS packet.

        A 16 zero byte placeholder is inserted if the attribute is missing,
        the packet is encoded, and the computed value is written both into
        ``raw_packet`` and into the attribute mapping.
        """
        if "Message-Authenticator" not in self:
            self["Message-Authenticator"] = 16 * b"\00"
        # NOTE(review): relies on _encode_packet() filling raw_packet and
        # ordered_attributes; the current stub only clears raw_packet.
        self._encode_packet()
        assert self.raw_packet is not None
        attr = self.find_first_attribute("Message-Authenticator")
        generated = self._generate_message_authenticator(attr)
        start_pos = attr.pos + 2
        self.raw_packet[start_pos:start_pos + 16] = generated
        # Keep the mapping in sync with the wire format (list or scalar).
        if isinstance(self["Message-Authenticator"], list):
            self["Message-Authenticator"] = [generated]
        else:
            self["Message-Authenticator"] = generated

    def refresh_message_authenticator(self):
        """Refresh an existing Message-Authenticator

        This method is equivalent to add_message_authenticator, but
        the name can provide more context information.
        """
        self.add_message_authenticator()

    def find_first_attribute(
        self, attr_type_name: Union[str, int, Tuple[int, ...]]
    ) -> Attribute:
        """Find the first attribute with the given name or code.

        Raises:
            KeyError: if no parsed attribute matches.
        """
        for attr in self.ordered_attributes:
            # Names are strings, codes are ints/tuples, so a single
            # membership test covers both kinds of lookup.
            if attr_type_name in (attr.name, attr.type):
                return attr
        raise KeyError(attr_type_name)

    def _encode_packet(self):
        """Encode the attributes into ``raw_packet``.

        NOTE(review): not implemented yet; it only discards any previous
        wire representation.
        """
        self.raw_packet = None

    def serialize(self) -> bytes:
        """Serialize the Packet to the RADIUS Line Format"""
        if self.host.message_authenticator:
            self.add_message_authenticator()
        # TODO: This is not an abstract method
        raise NotImplementedError
class AuthPacket(Packet):  # pylint: disable=abstract-method
    """Generic RADIUS Authentication Packet

    Used for Access-Requests and the Access-Accept, Access-Reject and
    Access-Challenge replies created from them.
    """

    def __init__(
        self,
        host: Host,
        radius_id: int,
        auth_type,
        *,
        code: Code = Code.AccessRequest,
        request: Optional[Packet] = None,
        **attributes
    ):
        super().__init__(host, code, radius_id, request=request, **attributes)
        self.auth_type = auth_type
        if code == Code.AccessRequest:
            # Each new Access-Request gets a random Request Authenticator.
            self.authenticator = token_bytes(16)

    def _create_reply(self, code: Code, attributes: Dict[str, Any]) -> "AuthPacket":
        """Build a reply of type ``code`` answering this packet."""
        return AuthPacket(
            self.host,
            self.identifier,
            self.auth_type,
            request=self,
            code=code,
            **attributes
        )

    def create_accept(self, **attributes):
        """Create an Access-Accept for a given Access-Request"""
        return self._create_reply(Code.AccessAccept, attributes)

    def create_reject(self, **attributes):
        """Create an Access-Reject for a given Access-Request"""
        return self._create_reply(Code.AccessReject, attributes)

    def create_challenge(self, **attributes):
        """Create an Access-Challenge for a given Access-Request"""
        return self._create_reply(Code.AccessChallenge, attributes)

    # Backwards compatible alias for the original (misspelled) name.
    create_challange = create_challenge

    def _single_value(self, name: str):
        """Return the value of attribute ``name``.

        Parsed packets store values as lists; a one-element list is
        unwrapped so callers always receive the raw value.

        Raises:
            KeyError: if the attribute is missing.
            PacketError: if the attribute does not occur exactly once.
        """
        value = self[name]
        if isinstance(value, list):
            if len(value) != 1:
                raise PacketError(f"Expected exactly one {name} attribute")
            value = value[0]
        return value

    def validate_password(self, password: bytes) -> bool:
        """Validate a password of an Access-Request (PAP first, then CHAP)"""
        try:
            return self.validate_pap(password)
        except KeyError:
            pass
        # Will throw KeyError if no chap password exists
        return self.validate_chap(password)

    def validate_pap(self, password: bytes) -> bool:
        """Validate a PAP-Password of an Access-Request"""
        packet_password = self._single_value("User-Password")
        return validate_pap_password(
            self.host.secret,
            self.authenticator,  # type: ignore
            packet_password,
            password,
        )

    def validate_chap(self, password: bytes) -> bool:
        """Validate the CHAP-Password of an Access-Request.

        The first octet of the attribute is the CHAP ident, the rest the
        response. Without a Chap-Challenge attribute the Request
        Authenticator is used as challenge.
        """
        packet_password = self._single_value("Chap-Password")
        chap_id = packet_password[:1]
        chap_password = packet_password[1:]
        try:
            challenge = self._single_value("Chap-Challenge")
        except KeyError:
            challenge = self.authenticator
        return validate_chap_password(
            chap_id, challenge, chap_password, password,  # type: ignore
        )
class AcctPacket(Packet):  # pylint: disable=abstract-method
    """Generic RADIUS Accounting Packet"""

    def __init__(
        self,
        host: Host,
        radius_id: int,
        *,
        code: Code = Code.AccountingRequest,
        request: Optional[Packet] = None,
        **attributes
    ):
        super().__init__(host, code, radius_id, request=request, **attributes)

    def create_response(self, **attributes):
        """Create an Accounting-Response to a given Accounting-Request"""
        return AcctPacket(
            self.host,
            self.identifier,
            code=Code.AccountingResponse,
            request=self,
            **attributes
        )

    def increase_acct_delay_time(self, delay_time: float):
        """Increase the Acct-Delay-Time attribute by ``delay_time`` seconds.

        Fractions of a second are truncated. Nothing happens if the packet
        has no Acct-Delay-Time attribute. No check for the Accounting Codes
        is made; an Accounting-Response should not contain the
        Acct-Delay-Time Attribute.
        """
        try:
            current = self["Acct-Delay-Time"]
        except KeyError:
            return
        increment = int(delay_time)
        if isinstance(current, list):
            # Parsed packets store every attribute value as a list.
            self["Acct-Delay-Time"] = [value + increment for value in current]
        else:
            self["Acct-Delay-Time"] = current + increment
class CoAPacket(Packet):  # pylint: disable=abstract-method
    """Generic RADIUS CoA (Change-of-Authorization) Packet.

    Defaults to a CoA-Request; replies are built with create_ack and
    create_nack.
    """

    def __init__(
        self,
        host: Host,
        radius_id: int,
        *,
        code: Code = Code.CoARequest,
        request: Optional[Packet] = None,
        **attributes
    ):
        super().__init__(host, code, radius_id, request=request, **attributes)

    def _answer(self, reply_code: Code, attributes) -> "CoAPacket":
        """Build a reply with ``reply_code`` that answers this packet."""
        return CoAPacket(
            self.host,
            self.identifier,
            code=reply_code,
            request=self,
            **attributes
        )

    def create_ack(self, **attributes):
        """Return a CoA-ACK answering this packet."""
        return self._answer(Code.CoAACK, attributes)

    def create_nack(self, **attributes):
        """Return a CoA-NAK answering this packet."""
        return self._answer(Code.CoANAK, attributes)