Not so great value decoding code

This commit is contained in:
Istvan Ruzman
2020-09-03 13:32:34 +02:00
parent 9091f2eb2c
commit f508e1b614
9 changed files with 191 additions and 155 deletions

View File

@@ -3,34 +3,10 @@
"""Python RADIUS client code.
pyrad is an implementation of a RADIUS client as described in RFC2865.
pyrad3 is framework for RADIUS as described in RFC2865.
It takes care of all the details like building RADIUS packets, sending
them and decoding responses.
Here is an example of doing a authentication request::
import pyrad.packet
from pyrad.client import Client
from pyrad.dictionary import Dictionary
srv = Client(server="radius.my.domain", secret="s3cr3t",
dict = Dictionary("dicts/dictionary", "dictionary.acc"))
req = srv.CreatePacket(code=pyrad.packet.AccessRequest,
User_Name = "wichert", NAS_Identifier="localhost")
req["User-Password"] = req.PwCrypt("password")
reply = srv.SendPacket(req)
if reply.code = =pyrad.packet.AccessAccept:
print "access accepted"
else:
print "access denied"
print "Attributes returned by server:"
for key, value in reply.items():
print f'{key}: {value}')
This package contains four modules:
- client: RADIUS client code

View File

@@ -9,8 +9,8 @@ import time
from ipaddress import IPv4Address, IPv6Address
from typing import Optional, Union, cast
import pyrad3.host as H
import pyrad3.packet as P
from pyrad3 import host
from pyrad3.dictionary import Dictionary
SUPPORTED_SEND_TYPES = [
@@ -34,7 +34,7 @@ class UnsupportedPacketType(Exception):
"""Exception for received packets"""
class Client(host.Host):
class Client(H.Host):
"""A simple and extensible RADIUS Client."""
def __init__(
@@ -55,6 +55,8 @@ class Client(host.Host):
"""Bind the Address to some socket"""
self._socket_close()
self._socket_open()
# This should be always true, if there is no socket,
assert self._socket is not None
self._socket.bind(addr)
def _socket_open(self):
@@ -62,7 +64,7 @@ class Client(host.Host):
if self._socket is not None:
return
try:
family = socket.getaddrinfo(self.server, "www")[0][0]
family = socket.getaddrinfo(str(self.server), "www")[0][0]
except socket.gaierror:
family = socket.AF_INET
self._socket = socket.socket(family, socket.SOCK_DGRAM)
@@ -74,15 +76,14 @@ class Client(host.Host):
self._socket.setsockopt(
socket.SOL_SOCKET,
socket.SO_BINDTODEVICE,
self.interface,
len(self.interface),
self.interface.encode("utf-8"),
)
self._poll = select.poll()
self._poll.register(self._socket, select.POLLIN)
def _socket_close(self):
"""Close the Client socket"""
if self._socket is not None:
if self._socket is not None and self._poll is not None:
self._poll.unregister(self._socket)
self._socket.close()
self._socket = None

View File

@@ -122,10 +122,10 @@ def dict_parser(
if first_tok == "$INCLUDE":
try:
inner_filename = tokens[1]
except IndexError:
except IndexError as exc:
raise ParseError(
filename, "$INCLUDE is missing a filename", line_num,
)
) from exc
if not isabs(tokens[1]):
path = dirname(filename)
inner_filename = normpath(join(path, inner_filename))
@@ -269,12 +269,12 @@ class Dictionary:
continuation = True
except IndexError:
pass
except ValueError:
except ValueError as exc:
raise ParseError(
filename,
f"Syntax error in specification for vendor {vendor_name}",
line_num,
)
) from exc
except IndexError:
# no format definition
t_len, l_len = 1, 1
@@ -301,12 +301,12 @@ class Dictionary:
try:
vendor_id = self.vendor_lookup_id_by_name[tokens[1]]
self.cur_vendor = self.vendor[vendor_id]
except KeyError:
except KeyError as exc:
raise ParseError(
filename,
f"Unknown vendor {tokens[1]} in begin-vendor statement",
line_num,
)
) from exc
def _parse_end_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the END-VENDOR line of (Free)RADIUS dictionaries."""
@@ -357,12 +357,12 @@ class Dictionary:
if value == "0":
raise ValueError
encrypt = Encrypt(int(value)) # type: ignore
except (ValueError, TypeError):
except (ValueError, TypeError) as exc:
raise ParseError(
filename,
f"Illegal attribute encryption {value}",
line_num,
)
) from exc
else:
raise ParseError(
filename, "Unknown attribute flag {key}", line_num
@@ -377,10 +377,10 @@ class Dictionary:
for code in attr_code.split("."):
try:
code_num = _parse_number(code)
except ValueError:
except ValueError as exc:
raise ParseError(
filename, f'invalid attribute code {attr_code}""', line_num
)
) from exc
if 2 ** (8 * tlength) <= code_num:
raise ParseError(
filename,
@@ -430,8 +430,10 @@ class Dictionary:
base_datatype = datatype.split("[")[0].replace("-", "")
try:
attribute_type = Datatype[base_datatype]
except KeyError:
raise ParseError(filename, f"Illegal type: {datatype}", line_num)
except KeyError as exc:
raise ParseError(
filename, f"Illegal type: {datatype}", line_num
) from exc
attribute = Attribute(
name,
@@ -457,7 +459,9 @@ class Dictionary:
)
else:
LOG.info("Register Attribute %s", attribute.name)
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
attrcode: Union[int, Tuple[int, ...]] = codes[0] if len(
codes
) == 1 else tuple(codes)
self.attrindex[attrcode] = attribute
self.attrindex[name] = attribute
@@ -477,19 +481,19 @@ class Dictionary:
# quick and dirty way to make floats values an error
raise ValueError
value = _parse_number(vvalue)
except ValueError:
except ValueError as exc:
raise ParseError(
filename, f"Invalid number {vvalue} for VALUE {key}", line_num
)
) from exc
try:
attribute = self.attrindex[attr_name]
except KeyError:
except KeyError as exc:
raise ParseError(
filename,
f"ATTRIBUTE {attr_name} has not been defined yet",
line_num,
)
) from exc
try:
datatype = str(attribute.datatype).split(".")[1]
lmin, lmax = INTEGER_TYPES[datatype]
@@ -499,13 +503,13 @@ class Dictionary:
f"VALUE {key}({value}) is not in the limit of type {datatype}",
line_num,
)
except KeyError:
except KeyError as exc:
raise ParseError(
filename,
f"only attributes with integer typed datatypes can have"
f"value definitions {attribute.datatype}",
line_num,
)
) from exc
attribute.values[value] = key
attribute.values[key] = value

View File

@@ -48,3 +48,6 @@ class Host: # pylint: disable=too-many-arguments,too-many-instance-attributes
def create_coa_packet(self, **kwargs):
"""Create an CoA packet (requset per default)"""
return packet.CoAPacket(self, **kwargs)
def _send_packet(self, _packet: packet.Packet):
raise NotImplementedError

View File

@@ -77,9 +77,11 @@ class Packet(OrderedDict):
try:
if not reply.validate_message_authenticator():
raise PacketError("Packet has a wrong message authenticator")
except KeyError:
except KeyError as exc:
if "EAP-Message" in reply:
raise PacketError("Packet is missing a message authenticator")
raise PacketError(
"Packet is missing a message authenticator"
) from exc
return reply
def send(self):
@@ -147,8 +149,9 @@ class Packet(OrderedDict):
def add_message_authenticator(self):
"""Add a Message-Authenticator to the RADIUS packet"""
self["Message-Authenticator"] = b"\x00" * 16
self._encode_packet()
self._generate_message_authenticator(self)
self._generate_message_authenticator(self["Message-Authenticator"])
try:
# quick lookup before we iterate over the whole packet
_ = self["Message-Authenticator"]

View File

@@ -12,7 +12,13 @@ from ipaddress import (
ip_address,
ip_network,
)
from typing import Optional, Union
from typing import Any, Callable, Dict, Union
from pyrad3.dictionary import Datatype
def _from_bytes(value: bytes) -> int:
return int.from_bytes(value, byteorder="big")
def encode_string(string: str) -> bytes:
@@ -24,14 +30,10 @@ def encode_string(string: str) -> bytes:
return string
def encode_octets(string: bytes, explen: Optional[int]) -> bytes:
def encode_octets(string: bytes) -> bytes:
"""Encode a RADIUS value of type octet"""
if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters")
if explen is not None and len(string) != explen:
raise ValueError(
f"Expected a value length of {explen} got {len(string)}"
)
return string
@@ -171,12 +173,8 @@ def decode_string(string: bytes) -> Union[str, bytes]:
return string
def decode_octets(string: bytes, explen: Optional[int] = None) -> bytes:
"""Decode a RADIUS value of type octet"""
if explen is not None and len(string) != explen:
raise ValueError(
f"Expected a value length of {explen} got {len(string)}"
)
def decode_octets(string: bytes) -> bytes:
"""Decode a RADIUS value of type octets"""
return string
@@ -187,9 +185,9 @@ def decode_ipv4_address(addr: bytes) -> IPv4Address:
def decode_ipv4_prefix(addr: bytes) -> IPv4Network:
"""Decode a RADIUS value of type ipv6prefix"""
prefix = addr[:1]
prefix = _from_bytes(addr[:1])
addr = addr[1:]
return IPv4Network((prefix, addr))
return IPv4Network((addr, prefix))
def decode_ipv6_address(addr: bytes) -> IPv6Address:
@@ -201,9 +199,10 @@ def decode_ipv6_address(addr: bytes) -> IPv6Address:
def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
"""Decode a RADIUS value of type ipv6prefix"""
addr = addr + b"\x00" * (18 - len(addr))
prefix = addr[:2]
addr = addr[2:]
return IPv6Network((prefix, addr))
# ignoring the reserved field at addr[0]
prefix = _from_bytes(addr[1:2])
addr = addr[2:] + b"\x00" * (16 - len(addr[2:]))
return IPv6Network((addr, prefix))
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
@@ -226,55 +225,57 @@ def decode_date(num: bytes) -> int: # TODO: type
return (struct.unpack("!I", num))[0]
ENCODE_MAP = {
"string": encode_string,
"octets": encode_octets,
"ipaddr": encode_ipv4_address,
"ipv4prefix": encode_ipv4_prefix,
"ipv6addr": encode_ipv6_address,
"ipv6prefix": encode_ipv6_prefix,
"comboip": encode_combo_ip,
"ifid": lambda value: encode_octets(value, 8),
"abinary": encode_ascend_binary,
"byte": lambda value: encode_integer(value, "!B"),
"short": lambda value: encode_integer(value, "!H"),
"signed": lambda value: encode_integer(value, "!i"),
"integer": encode_integer,
"integer64": lambda value: encode_integer(value, "!Q"),
"date": encode_date,
ENCODE_MAP: Dict[Datatype, Callable[[Any], bytes]] = {
Datatype.string: encode_string,
Datatype.octets: encode_octets,
Datatype.ipaddr: encode_ipv4_address,
Datatype.ipv4prefix: encode_ipv4_prefix,
Datatype.ipv6addr: encode_ipv6_address,
Datatype.ipv6prefix: encode_ipv6_prefix,
Datatype.comboip: encode_combo_ip,
# TODO: length check (8)
Datatype.ifid: encode_octets,
Datatype.abinary: encode_ascend_binary,
Datatype.byte: lambda value: encode_integer(value, "!B"),
Datatype.short: lambda value: encode_integer(value, "!H"),
Datatype.signed: lambda value: encode_integer(value, "!i"),
Datatype.integer: encode_integer,
Datatype.integer64: lambda value: encode_integer(value, "!Q"),
Datatype.date: encode_date,
}
def encode_attr(datatype, value):
def encode_attr(datatype: Datatype, value: bytes) -> bytes:
"""Encode a RADIUS attribute"""
try:
return ENCODE_MAP[datatype](value)
except KeyError:
raise ValueError(f"Unknown attribute type {datatype}")
except KeyError as exc:
raise ValueError(f"Unknown attribute type {datatype}") from exc
DECODE_MAP = {
"string": decode_string,
"octets": decode_octets,
"ipaddr": decode_ipv4_address,
"ipv4prefix": decode_ipv4_prefix,
"ipv6addr": decode_ipv6_address,
"ipv6prefix": decode_ipv6_prefix,
"comboip": decode_combo_ip,
"ifid": lambda value: decode_octets(value, 8),
"abinary": decode_ascend_binary,
"byte": lambda value: decode_integer(value, "!B"),
"short": lambda value: decode_integer(value, "!H"),
"signed": lambda value: decode_integer(value, "!i"),
"integer": decode_integer,
"integer64": lambda value: decode_integer(value, "!Q"),
"date": decode_date,
DECODE_MAP: Dict[Datatype, Callable[[bytes], Any]] = {
Datatype.string: decode_string,
Datatype.octets: decode_octets,
Datatype.ipaddr: decode_ipv4_address,
Datatype.ipv4prefix: decode_ipv4_prefix,
Datatype.ipv6addr: decode_ipv6_address,
Datatype.ipv6prefix: decode_ipv6_prefix,
Datatype.comboip: decode_combo_ip,
# TODO: length check (8)
Datatype.ifid: decode_octets,
Datatype.abinary: decode_ascend_binary,
Datatype.byte: lambda value: decode_integer(value, "!B"),
Datatype.short: lambda value: decode_integer(value, "!H"),
Datatype.signed: lambda value: decode_integer(value, "!i"),
Datatype.integer: decode_integer,
Datatype.integer64: lambda value: decode_integer(value, "!Q"),
Datatype.date: decode_date,
}
def decode_attr(datatype, value):
def decode_attr(datatype: Datatype, value: bytes):
"""Decode a RADIUS attribute"""
try:
return DECODE_MAP[datatype](value)
except KeyError:
raise ValueError(f"Unknown attribute type {datatype}")
except KeyError as exc:
raise ValueError(f"Unknown attribute type {datatype}") from exc

View File

@@ -10,7 +10,8 @@ from collections import namedtuple
from typing import List, Optional, Tuple, Union
from pyrad3.code import Code
from pyrad3.dictionary import Dictionary
from pyrad3.dictionary import Attribute as DictAttr, Datatype, Dictionary
from pyrad3.tools import decode_attr
RANDOM_GENERATOR = secrets.SystemRandom()
MD5 = hashlib.md5
@@ -23,13 +24,15 @@ class PacketError(Exception):
Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"])
Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"])
PreParsedAttribute = List[Tuple[int, int, DictAttr]]
def parse_header(raw_packet: bytes) -> Header:
"""Parse the Header of a RADIUS Packet."""
try:
header = struct.unpack("!BBH16s", raw_packet)
except struct.error:
raise PacketError("Packet header is corrupt")
except struct.error as exc:
raise PacketError("Packet header is corrupt") from exc
length = header[3]
if len(raw_packet) != length:
@@ -41,8 +44,8 @@ def parse_header(raw_packet: bytes) -> Header:
try:
Code(header[1])
except ValueError:
PacketError(f"Unknown RADIUS Code {header[1]}")
except ValueError as exc:
raise PacketError(f"Unknown RADIUS Code {header[1]}") from exc
return Header(*header)
@@ -60,8 +63,8 @@ def parse_attributes(
while packet:
try:
(key, length) = struct.unpack("!BB", packet[0:2])
except struct.error:
raise PacketError("Attribute header is corrupt")
except struct.error as exc:
raise PacketError("Attribute header is corrupt") from exc
if length < 2:
raise PacketError(f"Attribute length ({length}) is too small")
@@ -77,14 +80,26 @@ def parse_attributes(
Attribute(
name="Unknown-Attribute",
pos=offset,
type="octets",
type=Datatype.octets,
length=int(packet[1]),
value=packet[2:],
)
)
else:
key = parse_key(rad_dict, key)
attributes.extend(parse_value(rad_dict, key, offset, value))
try:
attr_def = rad_dict.attrindex[key]
except KeyError as exc:
# TODO
raise exc
attributes.append(
Attribute(
name=attr_def.name,
pos=offset,
type=attr_def.datatype,
length=length - 2,
value=decode_attr(attr_def.datatype, packet[2:length],),
)
)
packet = packet[length:]
return attributes
@@ -97,33 +112,78 @@ def parse_vendor_attributes(
if len(vendor_value) < 4:
raise PacketError
vendor_id = int.from_bytes(vendor_value[:4], "big")
vendor_dict = rad_dict.vendor[vendor_id]
vendor_prefix = [26, vendor_id]
attributes = []
vendor_tlv = vendor_value[4:]
offset += 4
while vendor_tlv:
try:
(key, length) = struct.unpack("!BB", vendor_tlv[0:2])
except struct.error:
attribute = [
(key, attr_length) = struct.unpack("!BB", vendor_tlv[0:2])
print(attr_length)
keystack = vendor_prefix + [key]
attr_def = rad_dict.attrindex[tuple(keystack)]
except (struct.error, KeyError) as exc:
attributes.append(
Attribute(
name="Unknown-Attribute",
pos=offset - len(vendor_value),
type="octets",
type=Datatype.octets,
length=len(vendor_value) - 4,
value=vendor_value,
)
]
)
if exc is struct.error:
break
else:
offset = offset - len(vendor_tlv) + length
key = parse_key(rad_dict, tuple(vendor_prefix + [key]))
attribute = parse_value(vendor_dict, key, offset, vendor_tlv)
attributes.extend(attribute)
vendor_tlv = vendor_tlv[length:]
if attr_def.datatype == Datatype.tlv:
attr_pos = _parse_tlv(rad_dict, vendor_tlv, 2, keystack)
else:
attr_pos = [(2, attr_length, attr_def)]
parsed_attributes: List[Attribute] = []
for local_offset, length, attr_def in attr_pos:
length -= 2
parsed_attributes.append(
Attribute(
name=attr_def.name,
pos=offset + local_offset,
type=attr_def.datatype,
length=length,
value=decode_attr(
attr_def.datatype,
vendor_tlv[local_offset : local_offset + length],
),
)
)
offset = offset + attr_length
attributes.extend(parsed_attributes)
vendor_tlv = vendor_tlv[attr_length:]
return attributes
def _parse_tlv(
rad_dict: Dictionary, block: bytes, local_offset: int, key_stack: List[int]
) -> PreParsedAttribute:
# get a list of flattened radius attributes
ret = []
while block:
(key, length) = struct.unpack("!BB", block[0:2])
attr_def = rad_dict.attrindex[tuple(key_stack)]
if attr_def.datatype == Datatype.tlv:
key_stack.append(key)
block = block[:length]
ret.extend(
_parse_tlv(rad_dict, block[2:], local_offset + 2, key_stack)
)
key_stack.pop()
else:
ret.append((local_offset + 2, length, attr_def))
local_offset += length
return ret
def parse_key(
rad_dict: Dictionary, key_id: Union[int, Tuple[int, ...]]
) -> Union[str, int, Tuple[int, ...]]:
@@ -134,18 +194,6 @@ def parse_key(
return key_id
def parse_value(*_):
"""Parse the Value in the given Key/Dictionary Context"""
raise NotImplementedError
# def parse_value(
# rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
# ) -> List[Attribute]:
# """Parse the Value in the given Key/Dictionary Context"""
# return []
def calculate_authenticator(
secret: bytes, authenticator: bytes, raw_packet: bytes
) -> bytes: