Save progress

This commit is contained in:
Istvan Ruzman
2020-08-13 09:17:20 +02:00
parent 93193ee18d
commit 6023ec948a
5 changed files with 147 additions and 62 deletions

View File

@@ -1,2 +1,2 @@
[flake8] [flake8]
max-line-length = 100 max-line-length = 160

View File

@@ -2,4 +2,4 @@
ignore=tests,examples ignore=tests,examples
[MESSAGES_CONTROL] [MESSAGES_CONTROL]
disable=bad-continuation disable=bad-continuation,fixme

View File

@@ -3,7 +3,7 @@
"""Collections of functions to en- and decode RADIUS Attributes""" """Collections of functions to en- and decode RADIUS Attributes"""
from typing import Union from typing import Optional, Union
from ipaddress import ( from ipaddress import (
IPv4Address, IPv4Address,
IPv4Network, IPv4Network,
@@ -25,19 +25,23 @@ def encode_string(string: str) -> bytes:
return string return string
def encode_octets(string: bytes) -> bytes: def encode_octets(string: bytes, explen: Optional[int]) -> bytes:
"""Encode a RADIUS value of type octet""" """Encode a RADIUS value of type octet"""
if len(string) > 253: if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters") raise ValueError("Can only encode strings of <= 253 characters")
if explen is not None and len(string) != explen:
raise ValueError(
f"Expected a value length of {explen} got {len(string)}"
)
return string return string
def encode_address(addr: Union[str, IPv4Address]) -> bytes: def encode_ipv4_address(addr: Union[str, IPv4Address]) -> bytes:
"""Encode a RADIUS value of type ipaddr""" """Encode a RADIUS value of type ipaddr"""
return IPv4Address(addr).packed return IPv4Address(addr).packed
def encode_network(network: Union[str, IPv4Network]) -> bytes: def encode_ipv4_prefix(network: Union[str, IPv4Network]) -> bytes:
"""Encode a RADIUS value of type ipv4prefix""" """Encode a RADIUS value of type ipv4prefix"""
address = IPv4Network(network) address = IPv4Network(network)
return ( return (
@@ -45,6 +49,11 @@ def encode_network(network: Union[str, IPv4Network]) -> bytes:
) )
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type ipv6addr"""
return IPv6Address(addr).packed
def encode_ipv6_prefix(network: Union[str, IPv6Network]) -> bytes: def encode_ipv6_prefix(network: Union[str, IPv6Network]) -> bytes:
"""Encode a RADIUS value of type ipv6prefix""" """Encode a RADIUS value of type ipv6prefix"""
address = IPv6Network(network) address = IPv6Network(network)
@@ -53,11 +62,6 @@ def encode_ipv6_prefix(network: Union[str, IPv6Network]) -> bytes:
) + address.network_address.packed.rstrip(b"\0") ) + address.network_address.packed.rstrip(b"\0")
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type ipv6addr"""
return IPv6Address(addr).packed
def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes: def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type combo-ip""" """Encode a RADIUS value of type combo-ip"""
return ip_address(addr).packed return ip_address(addr).packed
@@ -168,16 +172,33 @@ def decode_string(string: bytes) -> Union[str, bytes]:
return string return string
def decode_octets(string: bytes) -> bytes: def decode_octets(string: bytes, explen: Optional[int] = None) -> bytes:
"""Decode a RADIUS value of type octet""" """Decode a RADIUS value of type octet"""
if explen is not None and len(string) != explen:
raise ValueError(
f"Expected a value length of {explen} got {len(string)}"
)
return string return string
def decode_address(addr: bytes) -> IPv4Address: def decode_ipv4_address(addr: bytes) -> IPv4Address:
"""Decode a RADIUS value of type ipaddr""" """Decode a RADIUS value of type ipaddr"""
return IPv4Address(addr) return IPv4Address(addr)
def decode_ipv4_prefix(addr: bytes) -> IPv4Network:
"""Decode a RADIUS value of type ipv6prefix"""
prefix = addr[:1]
addr = addr[1:]
return IPv4Network((prefix, addr))
def decode_ipv6_address(addr: bytes) -> IPv6Address:
"""Decode a RADIUS value of type ipv6addr"""
addr = addr + b"\x00" * (16 - len(addr))
return IPv6Address(addr)
def decode_ipv6_prefix(addr: bytes) -> IPv6Network: def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
"""Decode a RADIUS value of type ipv6prefix""" """Decode a RADIUS value of type ipv6prefix"""
addr = addr + b"\x00" * (18 - len(addr)) addr = addr + b"\x00" * (18 - len(addr))
@@ -186,18 +207,12 @@ def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
return IPv6Network((prefix, addr)) return IPv6Network((prefix, addr))
def decode_ipv6_address(addr: bytes) -> IPv6Address:
"""Decode a RADIUS value of type ipv6addr"""
addr = addr + b"\x00" * (16 - len(addr))
return IPv6Address(addr)
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]: def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
"""Decode a RADIUS value of type combo-ip""" """Decode a RADIUS value of type combo-ip"""
return ip_address(addr).packed return ip_address(addr).packed
def decode_ascend_binary(string): def decode_ascend_binary(string: bytes):
"""Decode a RADIUS value of type abinary""" """Decode a RADIUS value of type abinary"""
raise NotImplementedError raise NotImplementedError
@@ -207,7 +222,7 @@ def decode_integer(num: bytes, struct_format="!I") -> int:
return (struct.unpack(struct_format, num))[0] return (struct.unpack(struct_format, num))[0]
def decode_date(num): def decode_date(num: bytes) -> int: # TODO: type
"""Decode a RADIUS value of type date""" """Decode a RADIUS value of type date"""
return (struct.unpack("!I", num))[0] return (struct.unpack("!I", num))[0]
@@ -215,14 +230,17 @@ def decode_date(num):
ENCODE_MAP = { ENCODE_MAP = {
"string": encode_string, "string": encode_string,
"octets": encode_octets, "octets": encode_octets,
"integer": encode_integer, "ipaddr": encode_ipv4_address,
"ipaddr": encode_address, "ipv4prefix": encode_ipv4_prefix,
"ipv6prefix": encode_ipv6_prefix,
"ipv6addr": encode_ipv6_address, "ipv6addr": encode_ipv6_address,
"ipv6prefix": encode_ipv6_prefix,
"comboip": encode_combo_ip,
"ifid": lambda value: encode_octets(value, 8),
"abinary": encode_ascend_binary, "abinary": encode_ascend_binary,
"signed": lambda value: encode_integer(value, "!i"),
"short": lambda value: encode_integer(value, "!H"),
"byte": lambda value: encode_integer(value, "!B"), "byte": lambda value: encode_integer(value, "!B"),
"short": lambda value: encode_integer(value, "!H"),
"signed": lambda value: encode_integer(value, "!i"),
"integer": encode_integer,
"integer64": lambda value: encode_integer(value, "!Q"), "integer64": lambda value: encode_integer(value, "!Q"),
"date": encode_date, "date": encode_date,
} }
@@ -239,14 +257,17 @@ def encode_attr(datatype, value):
DECODE_MAP = { DECODE_MAP = {
"string": decode_string, "string": decode_string,
"octets": decode_octets, "octets": decode_octets,
"integer": decode_integer, "ipaddr": decode_ipv4_address,
"ipaddr": decode_address, "ipv4prefix": decode_ipv4_prefix,
"ipv6prefix": decode_ipv6_prefix,
"ipv6addr": decode_ipv6_address, "ipv6addr": decode_ipv6_address,
"ipv6prefix": decode_ipv6_prefix,
"comboip": decode_combo_ip,
"ifid": lambda value: decode_octets(value, 8),
"abinary": decode_ascend_binary, "abinary": decode_ascend_binary,
"signed": lambda value: decode_integer(value, "!i"),
"short": lambda value: decode_integer(value, "!H"),
"byte": lambda value: decode_integer(value, "!B"), "byte": lambda value: decode_integer(value, "!B"),
"short": lambda value: decode_integer(value, "!H"),
"signed": lambda value: decode_integer(value, "!i"),
"integer": decode_integer,
"integer64": lambda value: decode_integer(value, "!Q"), "integer64": lambda value: decode_integer(value, "!Q"),
"date": decode_date, "date": decode_date,
} }

View File

@@ -4,7 +4,7 @@
"""Collection of functions to deal with RADIUS packet en- and decoding.""" """Collection of functions to deal with RADIUS packet en- and decoding."""
from collections import namedtuple from collections import namedtuple
from typing import List, Union from typing import List, Tuple, Union
import hashlib import hashlib
import secrets import secrets
@@ -99,6 +99,7 @@ def parse_vendor_attributes(
raise PacketError raise PacketError
vendor_id = int.from_bytes(vendor_value[:4], "big") vendor_id = int.from_bytes(vendor_value[:4], "big")
vendor_dict = rad_dict.vendor[vendor_id] vendor_dict = rad_dict.vendor[vendor_id]
vendor_prefix = [26, vendor_id]
vendor_name = vendor_dict.name vendor_name = vendor_dict.name
attributes = [] attributes = []
@@ -118,14 +119,16 @@ def parse_vendor_attributes(
] ]
else: else:
offset = offset - len(vendor_tlv) + length offset = offset - len(vendor_tlv) + length
key = parse_key(vendor_dict, key) key = parse_key(rad_dict, tuple(vendor_prefix + [key]))
attribute = parse_value(vendor_dict, key, offset, vendor_tlv) attribute = parse_value(vendor_dict, key, offset, vendor_tlv)
attributes.extend(attribute) attributes.extend(attribute)
vendor_tlv = vendor_tlv[length:] vendor_tlv = vendor_tlv[length:]
return attributes return attributes
def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]: def parse_key(
rad_dict: Dictionary, key_id: Union[int, Tuple[int, ...]]
) -> Union[str, int, Tuple[int, ...]]:
"""Parse the key in the Dictionary Context""" """Parse the key in the Dictionary Context"""
try: try:
return rad_dict.attrindex[key_id].name return rad_dict.attrindex[key_id].name
@@ -171,8 +174,17 @@ def password_encode(
secret: bytes, authenticator: bytes, password: bytes secret: bytes, authenticator: bytes, password: bytes
) -> bytes: ) -> bytes:
"""Obfuscate the plaintext Password for RADIUS""" """Obfuscate the plaintext Password for RADIUS"""
password += b"\x00" * (16 - (len(password) % 16)) buf = password + b"\x00" * (16 - (len(password) % 16))
return obfuscation_algorithm(secret, authenticator, password) last = authenticator
results = []
while buf:
cur_hash = MD5(secret + last).digest()
tmp = [cbuf ^ chash for cbuf, chash in zip(buf, cur_hash)]
results += tmp
(last, buf) = (bytes(tmp), buf[16:])
return bytes(results)
def password_decode( def password_decode(
@@ -185,33 +197,33 @@ def password_decode(
If the original password had some trailing \\x00 it will get lost. Therefore it is If the original password had some trailing \\x00 it will get lost. Therefore it is
not recommended to use (trailing) \\x00 in passwords. not recommended to use (trailing) \\x00 in passwords.
""" """
deobfuscated = obfuscation_algorithm( buf = obfuscated_password
secret, authenticator, obfuscated_password
)
return deobfuscated.rstrip(b"\x00").decode("utf-8")
def obfuscation_algorithm(
secret: bytes, authenticator: bytes, password: bytes
) -> bytes:
"""Obfuscate the plaintext password.
This function does not deal with the padding (which the
RADIUS Protocol requires.)
The User has to pad the password themself, or better use
the `password_encode` or `password_decode` function.
"""
result = b""
buf = password
last = authenticator last = authenticator
results = []
while buf: while buf:
cur_hash = MD5(secret + last).digest() cur_hash = MD5(secret + last).digest()
for cbuf, chash in zip(buf, cur_hash): results += [cbuf ^ chash for cbuf, chash in zip(buf, cur_hash)]
result += bytes([cbuf ^ chash])
(last, buf) = (buf[:16], buf[16:]) (last, buf) = (buf[:16], buf[16:])
return result return bytes(results).rstrip(b"\x00").decode("utf-8")
def create_chap_password(
chap_id: bytes, challenge: bytes, plaintext_password: bytes,
) -> bytes:
"""Create the CHAP Password with the chap_id and challenge.
The resulting CHAP Password includes the chap-id, as specified
in RFC 2865.
This function should be used to create the value for the
CHAP-PASSWORD Attribute in an Access-Request.
The CHAP-Password should be validated on the Server side
by using the validate_chap_password function.
"""
return chap_id + MD5(chap_id + plaintext_password + challenge).digest()
def validate_chap_password( def validate_chap_password(
@@ -221,8 +233,8 @@ def validate_chap_password(
plaintext_password: bytes, plaintext_password: bytes,
) -> bool: ) -> bool:
"""Validate the CHAP password against the given plaintext password""" """Validate the CHAP password against the given plaintext password"""
return ( return chap_password == create_chap_password(
chap_password == MD5(chap_id + plaintext_password + challenge).digest() chap_id, challenge, plaintext_password
) )
@@ -234,12 +246,12 @@ def salt_encrypt(secret: bytes, authenticator: bytes, value: bytes) -> bytes:
salted_auth = authenticator + salt salted_auth = authenticator + salt
return obfuscation_algorithm(secret, salted_auth, value) return password_encode(secret, salted_auth, value)
def salt_decrypt( def salt_decrypt(
secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes
) -> bytes: ) -> str:
"""Decrypt the given value""" """Decrypt the given value"""
salted_auth = authenticator + salt salted_auth = authenticator + salt
return obfuscation_algorithm(secret, salted_auth, encrypted_value) return password_decode(secret, salted_auth, encrypted_value)

View File

@@ -11,6 +11,8 @@ import pytest
# def test_valid_header(header): # def test_valid_header(header):
# utils.parse_header(header) # utils.parse_header(header)
SECRET = b"secret"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"header", "header",
@@ -24,3 +26,53 @@ import pytest
def test_invalid_header(header): def test_invalid_header(header):
with pytest.raises(utils.PacketError): with pytest.raises(utils.PacketError):
utils.parse_header(header) utils.parse_header(header)
@pytest.mark.parametrize(
"plaintext, obfuscated, authenticator",
[
(
b"short_password",
"ed9b49281f9de8edefae1b09b04beb86",
"7b19486d8372b8c136ccf2444d0a5b2c",
),
(
b"superlongpassword_exeeding_16_bytes",
"1f123e277869997fdfb93f6df037024463918d29064c9fcd5831c57dccd9308ac6b835e6d8f70995d1498a6c5a2a5b71",
"12441ce350ce269c04f650f7923058e1",
),
],
)
def test_password(plaintext, obfuscated, authenticator):
obfuscated = bytes.fromhex(obfuscated)
authenticator = bytes.fromhex(authenticator)
encoded = utils.password_encode(SECRET, authenticator, plaintext)
assert len(encoded) == len(obfuscated)
assert encoded == obfuscated
decoded = utils.password_decode(SECRET, authenticator, encoded)
plaintext_str = plaintext.decode("utf-8")
assert len(decoded) == len(plaintext_str)
assert decoded == plaintext_str
assert utils.validate_pap_password(
SECRET, authenticator, encoded, plaintext
)
@pytest.mark.parametrize(
"plaintext, chap, challenge",
[
(
b"short_password",
bytes.fromhex("2302a92821f675a52df8e5a3b10e49b0ab"),
b"1234567890ABCDEF",
),
],
)
def test_chap_password(plaintext, chap, challenge):
chapid = chap[:1]
encoded = utils.create_chap_password(chapid, challenge, plaintext)
assert len(encoded) == len(chap)
assert encoded == chap
assert utils.validate_chap_password(chapid, challenge, chap, plaintext)