diff --git a/src/pyrad3/tools.py b/src/pyrad3/tools.py index 9af8465..6a444f6 100644 --- a/src/pyrad3/tools.py +++ b/src/pyrad3/tools.py @@ -4,6 +4,7 @@ """Collections of functions to en- and decode RADIUS Attributes""" import struct +from functools import partial from ipaddress import ( IPv4Address, IPv4Network, @@ -12,9 +13,10 @@ from ipaddress import ( ip_address, ip_network, ) -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Optional, Union from pyrad3.dictionary import Datatype +from pyrad3.types import DecodeError def _from_bytes(value: bytes) -> int: @@ -25,9 +27,7 @@ def encode_string(string: str) -> bytes: """Encode a RADIUS value of type string""" if len(string) > 253: raise ValueError("Can only encode strings of <= 253 characters") - if isinstance(string, str): - return string.encode("utf-8") - return string + return string.encode("utf-8") def encode_octets(string: bytes) -> bytes: @@ -171,8 +171,10 @@ def decode_string(string: bytes) -> Union[str, bytes]: return string -def decode_octets(string: bytes) -> bytes: +def decode_octets(string: bytes, length: Optional[int] = None) -> bytes: """Decode a RADIUS value of type octets""" + if length is not None and len(string) != length: + raise DecodeError return string @@ -234,11 +236,11 @@ ENCODE_MAP: Dict[Datatype, Callable[[Any], bytes]] = { # TODO: length check (8) Datatype.IFID: encode_octets, Datatype.ABINARY: encode_ascend_binary, - Datatype.BYTE: lambda value: encode_integer(value, "!B"), - Datatype.SHORT: lambda value: encode_integer(value, "!H"), - Datatype.SIGNED: lambda value: encode_integer(value, "!i"), + Datatype.BYTE: partial(encode_integer, struct_format="!B"), + Datatype.SHORT: partial(encode_integer, struct_format="!H"), + Datatype.SIGNED: partial(encode_integer, struct_format="!i"), Datatype.INTEGER: encode_integer, - Datatype.INTEGER64: lambda value: encode_integer(value, "!Q"), + Datatype.INTEGER64: partial(encode_integer, struct_format="!Q"), Datatype.DATE: encode_date, } @@ -260,11 +262,11 @@ DECODE_MAP: Dict[Datatype, Callable[[bytes], Any]] = { Datatype.IPV6PREFIX: decode_ipv6_prefix, Datatype.COMBOIP: decode_combo_ip, # TODO: length check (8) - Datatype.IFID: decode_octets, + Datatype.IFID: partial(decode_octets, length=8), Datatype.ABINARY: decode_ascend_binary, Datatype.BYTE: decode_integer, Datatype.SHORT: decode_integer, - Datatype.SIGNED: lambda num: decode_integer(num, True), + Datatype.SIGNED: partial(decode_integer, signed=True), Datatype.INTEGER: decode_integer, Datatype.INTEGER64: decode_integer, Datatype.DATE: decode_date, diff --git a/src/pyrad3/types.py b/src/pyrad3/types.py index d3a47ee..d00de5c 100644 --- a/src/pyrad3/types.py +++ b/src/pyrad3/types.py @@ -12,6 +12,14 @@ from enum import Enum, IntEnum, auto from typing import Dict, List, Tuple, Union +class PyradError(Exception): + pass + + +class DecodeError(PyradError): + pass + + class Code(IntEnum): """Valid RADIUS codes (registered in IANA)"""