Check fixed size octet

This commit is contained in:
Istvan Ruzman
2022-02-21 20:46:38 +01:00
parent 51b351ce3e
commit e11fe6c71f
2 changed files with 21 additions and 11 deletions

View File

@@ -4,6 +4,7 @@
"""Collections of functions to en- and decode RADIUS Attributes""" """Collections of functions to en- and decode RADIUS Attributes"""
import struct import struct
from functools import partial
from ipaddress import ( from ipaddress import (
IPv4Address, IPv4Address,
IPv4Network, IPv4Network,
@@ -12,9 +13,10 @@ from ipaddress import (
ip_address, ip_address,
ip_network, ip_network,
) )
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Optional, Union
from pyrad3.dictionary import Datatype from pyrad3.dictionary import Datatype
from pyrad3.types import DecodeError
def _from_bytes(value: bytes) -> int: def _from_bytes(value: bytes) -> int:
@@ -25,9 +27,7 @@ def encode_string(string: str) -> bytes:
"""Encode a RADIUS value of type string""" """Encode a RADIUS value of type string"""
if len(string) > 253: if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters") raise ValueError("Can only encode strings of <= 253 characters")
if isinstance(string, str): return string.encode("utf-8")
return string.encode("utf-8")
return string
def encode_octets(string: bytes) -> bytes: def encode_octets(string: bytes) -> bytes:
@@ -171,8 +171,10 @@ def decode_string(string: bytes) -> Union[str, bytes]:
return string return string
def decode_octets(string: bytes) -> bytes: def decode_octets(string: bytes, length: Optional[int] = None) -> bytes:
"""Decode a RADIUS value of type octets""" """Decode a RADIUS value of type octets"""
if length is not None and len(string) != length:
raise DecodeError
return string return string
@@ -234,11 +236,11 @@ ENCODE_MAP: Dict[Datatype, Callable[[Any], bytes]] = {
# TODO: length check (8) # TODO: length check (8)
Datatype.IFID: encode_octets, Datatype.IFID: encode_octets,
Datatype.ABINARY: encode_ascend_binary, Datatype.ABINARY: encode_ascend_binary,
Datatype.BYTE: lambda value: encode_integer(value, "!B"), Datatype.BYTE: partial(encode_integer, struct_format="!B"),
Datatype.SHORT: lambda value: encode_integer(value, "!H"), Datatype.SHORT: partial(encode_integer, struct_format="!H"),
Datatype.SIGNED: lambda value: encode_integer(value, "!i"), Datatype.SIGNED: partial(encode_integer, struct_format="!i"),
Datatype.INTEGER: encode_integer, Datatype.INTEGER: encode_integer,
Datatype.INTEGER64: lambda value: encode_integer(value, "!Q"), Datatype.INTEGER64: partial(encode_integer, struct_format="!Q"),
Datatype.DATE: encode_date, Datatype.DATE: encode_date,
} }
@@ -260,11 +262,11 @@ DECODE_MAP: Dict[Datatype, Callable[[bytes], Any]] = {
Datatype.IPV6PREFIX: decode_ipv6_prefix, Datatype.IPV6PREFIX: decode_ipv6_prefix,
Datatype.COMBOIP: decode_combo_ip, Datatype.COMBOIP: decode_combo_ip,
# TODO: length check (8) # TODO: length check (8)
Datatype.IFID: decode_octets, Datatype.IFID: partial(decode_octets, length=8),
Datatype.ABINARY: decode_ascend_binary, Datatype.ABINARY: decode_ascend_binary,
Datatype.BYTE: decode_integer, Datatype.BYTE: decode_integer,
Datatype.SHORT: decode_integer, Datatype.SHORT: decode_integer,
Datatype.SIGNED: lambda num: decode_integer(num, True), Datatype.SIGNED: partial(decode_integer, signed=True),
Datatype.INTEGER: decode_integer, Datatype.INTEGER: decode_integer,
Datatype.INTEGER64: decode_integer, Datatype.INTEGER64: decode_integer,
Datatype.DATE: decode_date, Datatype.DATE: decode_date,

View File

@@ -12,6 +12,14 @@ from enum import Enum, IntEnum, auto
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
class PyradError(Exception):
pass
class DecodeError(PyradError):
pass
class Code(IntEnum): class Code(IntEnum):
"""Valid RADIUS codes (registered in IANA)""" """Valid RADIUS codes (registered in IANA)"""