safe progress

This commit is contained in:
Istvan Ruzman
2020-08-06 18:04:24 +02:00
parent 3254bc66e0
commit fd16436c3c
53 changed files with 2167 additions and 4589 deletions

View File

@@ -1,3 +1,2 @@
[flake8]
max-complexity = 10
max-line-length = 100

View File

@@ -2,10 +2,13 @@
python.pkgs.buildPythonPackage rec {
pname = "pyrad";
version = "3.0-alpha";
version = "1.0-alpha";
buildInputs = with python.pkgs; [ netaddr six ];
buildInputs = with python.pkgs; [ ];
checkInputs = with python.pkgs; [
black
pytest
# pytest-cov
];
}

View File

@@ -15,41 +15,47 @@ def send_accounting_packet(srv, req):
try:
srv.SendPacket(req)
except pyrad.client.Timeout:
print('RADIUS server does not reply')
print("RADIUS server does not reply")
sys.exit(1)
except socket.error as error:
print('Network error: ' + error[1])
print("Network error: " + error[1])
sys.exit(1)
def main(path_to_dictionary):
srv = Client(server='127.0.0.1',
secret=b'Kah3choteereethiejeimaeziecumi',
dict=Dictionary(path_to_dictionary))
srv = Client(
server="127.0.0.1",
secret=b"Kah3choteereethiejeimaeziecumi",
dict=Dictionary(path_to_dictionary),
)
req = srv.CreateAcctPacket(**{
'User-Name': 'wichert',
'NAS-IP-Address': '192.168.1.10',
'NAS-Port': 0,
'NAS-Identifier': 'trillian',
'Called-Station-Id': '00-04-5F-00-0F-D1',
'Calling-Station-Id': '00-01-24-80-B3-9C',
'Framed-IP-Address': '10.0.0.100',
})
req = srv.CreateAcctPacket(
**{
"User-Name": "wichert",
"NAS-IP-Address": "192.168.1.10",
"NAS-Port": 0,
"NAS-Identifier": "trillian",
"Called-Station-Id": "00-04-5F-00-0F-D1",
"Calling-Station-Id": "00-01-24-80-B3-9C",
"Framed-IP-Address": "10.0.0.100",
}
)
print('Sending accounting start packet')
req['Acct-Status-Type'] = 'Start'
print("Sending accounting start packet")
req["Acct-Status-Type"] = "Start"
send_accounting_packet(srv, req)
print('Sending accounting stop packet')
req['Acct-Status-Type'] = 'Stop'
req['Acct-Input-Octets'] = random.randrange(2**10, 2**30)
req['Acct-Output-Octets'] = random.randrange(2**10, 2**30)
req['Acct-Session-Time'] = random.randrange(120, 3600)
req['Acct-Terminate-Cause'] = random.choice(['User-Request', 'Idle-Timeout'])
print("Sending accounting stop packet")
req["Acct-Status-Type"] = "Stop"
req["Acct-Input-Octets"] = random.randrange(2 ** 10, 2 ** 30)
req["Acct-Output-Octets"] = random.randrange(2 ** 10, 2 ** 30)
req["Acct-Session-Time"] = random.randrange(120, 3600)
req["Acct-Terminate-Cause"] = random.choice(
["User-Request", "Idle-Timeout"]
)
send_accounting_packet(srv, req)
if __name__ == '__main__':
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
if __name__ == "__main__":
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary)

View File

@@ -11,43 +11,46 @@ from pyrad.dictionary import Dictionary
def main(path_to_dictionary):
srv = Client(server='127.0.0.1',
secret=b'Kah3choteereethiejeimaeziecumi',
dict=Dictionary(path_to_dictionary))
srv = Client(
server="127.0.0.1",
secret=b"Kah3choteereethiejeimaeziecumi",
dict=Dictionary(path_to_dictionary),
)
req = srv.CreateAuthPacket(
code=pyrad.packet.AccessRequest,
**{
'User-Name': 'wichert',
'NAS-IP-Address': '192.168.1.10',
'NAS-Port': 0,
'Service-Type': 'Login-User',
'NAS-Identifier': 'trillian',
'Called-Station-Id': '00-04-5F-00-0F-D1',
'Calling-Station-Id': '00-01-24-80-B3-9C',
'Framed-IP-Address': '10.0.0.100',
})
"User-Name": "wichert",
"NAS-IP-Address": "192.168.1.10",
"NAS-Port": 0,
"Service-Type": "Login-User",
"NAS-Identifier": "trillian",
"Called-Station-Id": "00-04-5F-00-0F-D1",
"Calling-Station-Id": "00-01-24-80-B3-9C",
"Framed-IP-Address": "10.0.0.100",
},
)
try:
print('Sending authentication request')
print("Sending authentication request")
reply = srv.SendPacket(req)
except pyrad.client.Timeout:
print('RADIUS server does not reply')
print("RADIUS server does not reply")
sys.exit(1)
except socket.error as error:
print('Network error: ' + error[1])
print("Network error: " + error[1])
sys.exit(1)
if reply.code == pyrad.packet.AccessAccept:
print('Access accepted')
print("Access accepted")
else:
print('Access denied')
print("Access denied")
print('Attributes returned by server:')
print("Attributes returned by server:")
for key, value in reply.items():
print(f'{key} {value}')
print(f"{key} {value}")
if __name__ == '__main__':
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
if __name__ == "__main__":
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary)

View File

@@ -10,49 +10,58 @@ from pyrad.client_async import ClientAsync
from pyrad.dictionary import Dictionary
from pyrad.packet import AccessAccept
logging.basicConfig(level='DEBUG',
format='%(asctime)s [%(levelname)-8s] %(message)s')
logging.basicConfig(
level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s"
)
def create_request(client, user):
return client.CreateAuthPacket(**{
'User-Name': user,
'NAS-IP-Address': '192.168.1.10',
'NAS-Port': 0,
'Service-Type': 'Login-User',
'NAS-Identifier': 'trillian',
'Called-Station-Id': '00-04-5F-00-0F-D1',
'Calling-Station-Id': '00-01-24-80-B3-9C',
'Framed-IP-Address': '10.0.0.100',
})
return client.CreateAuthPacket(
**{
"User-Name": user,
"NAS-IP-Address": "192.168.1.10",
"NAS-Port": 0,
"Service-Type": "Login-User",
"NAS-Identifier": "trillian",
"Called-Station-Id": "00-04-5F-00-0F-D1",
"Calling-Station-Id": "00-01-24-80-B3-9C",
"Framed-IP-Address": "10.0.0.100",
}
)
def print_reply(reply):
if reply.code == AccessAccept:
print('Access accepted')
print("Access accepted")
else:
print('Access denied')
print("Access denied")
print('Attributes returned by server:')
print("Attributes returned by server:")
for key, value in reply.items():
print(f'{key}: {value}')
print(f"{key}: {value}")
def initialize_transport(loop, client):
loop.run_until_complete(
asyncio.ensure_future(
client.initialize_transports(enable_auth=True,
local_addr='127.0.0.1',
local_auth_port=8000,
enable_acct=True,
enable_coa=True)))
client.initialize_transports(
enable_auth=True,
local_addr="127.0.0.1",
local_auth_port=8000,
enable_acct=True,
enable_coa=True,
)
)
)
def main(path_to_dictionary):
client = ClientAsync(server='localhost',
secret=b'Kah3choteereethiejeimaeziecumi',
timeout=4,
dict=Dictionary(path_to_dictionary))
client = ClientAsync(
server="localhost",
secret=b"Kah3choteereethiejeimaeziecumi",
timeout=4,
dict=Dictionary(path_to_dictionary),
)
loop = asyncio.get_event_loop()
@@ -62,41 +71,41 @@ def main(path_to_dictionary):
requests = []
for i in range(255):
req = create_request(client, f'user{i}')
req = create_request(client, f"user{i}")
future = client.SendPacket(req)
requests.append(future)
# Send auth requests asynchronously to the server
loop.run_until_complete(asyncio.ensure_future(
asyncio.gather(
*requests,
return_exceptions=True
loop.run_until_complete(
asyncio.ensure_future(
asyncio.gather(*requests, return_exceptions=True)
)
))
)
for future in requests:
if future.exception():
print('EXCEPTION ', future.exception())
print("EXCEPTION ", future.exception())
else:
reply = future.result()
print_reply(reply)
# Close transports
loop.run_until_complete(asyncio.ensure_future(
client.deinitialize_transports()))
print('END')
loop.run_until_complete(
asyncio.ensure_future(client.deinitialize_transports())
)
print("END")
except Exception as exc:
print('Error: ', exc)
print("Error: ", exc)
traceback.print_exc()
# Close transports
loop.run_until_complete(asyncio.ensure_future(
client.deinitialize_transports()))
loop.run_until_complete(
asyncio.ensure_future(client.deinitialize_transports())
)
loop.close()
if __name__ == '__main__':
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
if __name__ == "__main__":
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary)

View File

@@ -14,21 +14,21 @@ from pyrad.server import Server, RemoteHost
def print_attributes(packet):
print('Attributes')
print("Attributes")
for key, value in packet.items():
print(f'{key}: {value}')
print(f"{key}: {value}")
class FakeCoA(Server):
def HandleCoaPacket(self, packet):
'''Accounting packet handler.
"""Accounting packet handler.
Function that is called when a valid
accounting packet has been received.
:param packet: packet to process
:type packet: Packet class instance
'''
print('Received a coa request %d' % packet.code)
"""
print("Received a coa request %d" % packet.code)
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
@@ -38,7 +38,7 @@ class FakeCoA(Server):
self.SendReplyPacket(packet.fd, reply)
def HandleDisconnectPacket(self, packet):
print('Received a disconnect request %d' % packet.code)
print("Received a disconnect request %d" % packet.code)
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
@@ -52,27 +52,27 @@ def main(path_to_dictionary, coa_port):
# create server/coa only and read dictionary
# bind and listen only on 127.0.0.1:argv[1]
coa = FakeCoA(
addresses=['127.0.0.1'],
addresses=["127.0.0.1"],
dict=Dictionary(path_to_dictionary),
coaport=coa_port,
auth_enabled=False,
acct_enabled=False,
coa_enabled=True)
coa_enabled=True,
)
# add peers (address, secret, name)
coa.hosts['127.0.0.1'] = RemoteHost(
'127.0.0.1',
b'Kah3choteereethiejeimaeziecumi',
'localhost')
coa.hosts["127.0.0.1"] = RemoteHost(
"127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost"
)
# start
coa.Run()
if __name__ == '__main__':
if __name__ == "__main__":
if len(sys.argv) != 2:
print('usage: client-coa.py {portnumber}')
print("usage: client-coa.py {portnumber}")
sys.exit(1)
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary, int(sys.argv[1]))

View File

@@ -11,27 +11,29 @@ from pyrad.dictionary import Dictionary
def main(path_to_dictionary, coa_type, nas_identifier):
# create coa client
client = Client(server='127.0.0.1',
secret=b'Kah3choteereethiejeimaeziecumi',
dict=Dictionary(path_to_dictionary))
client = Client(
server="127.0.0.1",
secret=b"Kah3choteereethiejeimaeziecumi",
dict=Dictionary(path_to_dictionary),
)
# set coa timeout
client.timeout = 30
# create coa request packet
attributes = {
'Acct-Session-Id': '1337',
'NAS-Identifier': nas_identifier,
"Acct-Session-Id": "1337",
"NAS-Identifier": nas_identifier,
}
if coa_type == 'coa':
if coa_type == "coa":
# create coa request
request = client.CreateCoAPacket(**attributes)
elif coa_type == 'dis':
elif coa_type == "dis":
# create disconnect request
request = client.CreateCoAPacket(
code=pyrad.packet.DisconnectRequest,
**attributes)
code=pyrad.packet.DisconnectRequest, **attributes
)
else:
sys.exit(1)
@@ -41,11 +43,11 @@ def main(path_to_dictionary, coa_type, nas_identifier):
print(result.code)
if __name__ == '__main__':
if __name__ == "__main__":
if len(sys.argv) != 3:
print('usage: coa.py {coa|dis} daemon-1234')
print("usage: coa.py {coa|dis} daemon-1234")
sys.exit(1)
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary, sys.argv[1], sys.argv[2])

View File

@@ -8,46 +8,52 @@ import pyrad.packet
from pyrad import server
from pyrad.dictionary import Dictionary
logging.basicConfig(filename='pyrad.log', level='DEBUG',
format='%(asctime)s [%(levelname)-8s] %(message)s')
logging.basicConfig(
filename="pyrad.log",
level="DEBUG",
format="%(asctime)s [%(levelname)-8s] %(message)s",
)
def print_attributes(packet):
print('Attributes')
print("Attributes")
for key, value in packet.items():
print(f'{key}: {value}')
print(f"{key}: {value}")
class FakeServer(server.Server):
def HandleAuthPacket(self, packet):
print('Received an authentication request')
print("Received an authentication request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet, **{
'Service-Type': 'Framed-User',
'Framed-IP-Address': '192.168.0.1',
'Framed-IPv6-Prefix': 'fc66::/64'
})
reply = self.CreateReplyPacket(
packet,
**{
"Service-Type": "Framed-User",
"Framed-IP-Address": "192.168.0.1",
"Framed-IPv6-Prefix": "fc66::/64",
},
)
reply.code = pyrad.packet.AccessAccept
self.SendReplyPacket(packet.fd, reply)
def HandleAcctPacket(self, packet):
print('Received an accounting request')
print("Received an accounting request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
self.SendReplyPacket(packet.fd, reply)
def HandleCoaPacket(self, packet):
print('Received an coa request')
print("Received an coa request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
self.SendReplyPacket(packet.fd, reply)
def HandleDisconnectPacket(self, packet):
print('Received an disconnect request')
print("Received an disconnect request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
@@ -58,20 +64,18 @@ class FakeServer(server.Server):
def main(path_to_dictionary):
# create server and read dictionary
srv = FakeServer(dict=Dictionary(path_to_dictionary),
coa_enabled=True)
srv = FakeServer(dict=Dictionary(path_to_dictionary), coa_enabled=True)
# add clients (address, secret, name)
srv.hosts['127.0.0.1'] = server.RemoteHost(
'127.0.0.1',
b'Kah3choteereethiejeimaeziecumi',
'localhost')
srv.BindToAddress('0.0.0.0')
srv.hosts["127.0.0.1"] = server.RemoteHost(
"127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost"
)
srv.BindToAddress("0.0.0.0")
# start server
srv.Run()
if __name__ == '__main__':
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
if __name__ == "__main__":
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary)

View File

@@ -12,57 +12,67 @@ from pyrad.server import RemoteHost
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except:
pass
logging.basicConfig(level='DEBUG',
format='%(asctime)s [%(levelname)-8s] %(message)s')
logging.basicConfig(
level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s"
)
def print_attributes(packet):
print('Attributes returned by server:')
print("Attributes returned by server:")
for key, value in packet.items():
print(f'{key}: {value}')
print(f"{key}: {value}")
class FakeServer(ServerAsync):
def __init__(self, loop, dictionary):
ServerAsync.__init__(self, loop=loop, dictionary=dictionary,
enable_pkt_verify=True, debug=True)
ServerAsync.__init__(
self,
loop=loop,
dictionary=dictionary,
enable_pkt_verify=True,
debug=True,
)
def handle_auth_packet(self, protocol, packet, addr):
print('Received an authentication request with id ', packet.id)
print('Authenticator ', packet.authenticator.hex())
print('Secret ', packet.secret)
print("Received an authentication request with id ", packet.id)
print("Authenticator ", packet.authenticator.hex())
print("Secret ", packet.secret)
print_attributes(packet)
reply = self.CreateReplyPacket(packet, **{
'Service-Type': 'Framed-User',
'Framed-IP-Address': '192.168.0.1',
'Framed-IPv6-Prefix': 'fc66::/64'
})
reply = self.CreateReplyPacket(
packet,
**{
"Service-Type": "Framed-User",
"Framed-IP-Address": "192.168.0.1",
"Framed-IPv6-Prefix": "fc66::/64",
},
)
reply.code = AccessAccept
protocol.send_response(reply, addr)
def handle_acct_packet(self, protocol, packet, addr):
print('Received an accounting request')
print("Received an accounting request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
protocol.send_response(reply, addr)
def handle_coa_packet(self, protocol, packet, addr):
print('Received an coa request')
print("Received an coa request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
protocol.send_response(reply, addr)
def handle_disconnect_packet(self, protocol, packet, addr):
print('Received an disconnect request')
print("Received an disconnect request")
print_attributes(packet)
reply = self.CreateReplyPacket(packet)
@@ -77,17 +87,19 @@ def main(path_to_dictionary):
server = FakeServer(loop=loop, dictionary=Dictionary(path_to_dictionary))
# add clients (address, secret, name)
server.hosts['127.0.0.1'] = RemoteHost('127.0.0.1',
b'Kah3choteereethiejeimaeziecumi',
'localhost')
server.hosts["127.0.0.1"] = RemoteHost(
"127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost"
)
try:
# Initialize transports
loop.run_until_complete(
asyncio.ensure_future(
server.initialize_transports(enable_auth=True,
enable_acct=True,
enable_coa=True)))
server.initialize_transports(
enable_auth=True, enable_acct=True, enable_coa=True
)
)
)
try:
# start server
@@ -96,20 +108,22 @@ def main(path_to_dictionary):
pass
# Close transports
loop.run_until_complete(asyncio.ensure_future(
server.deinitialize_transports()))
loop.run_until_complete(
asyncio.ensure_future(server.deinitialize_transports())
)
except Exception as exc:
print('Error: ', exc)
print("Error: ", exc)
traceback.print_exc()
# Close transports
loop.run_until_complete(asyncio.ensure_future(
server.deinitialize_transports()))
loop.run_until_complete(
asyncio.ensure_future(server.deinitialize_transports())
)
loop.close()
if __name__ == '__main__':
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
if __name__ == "__main__":
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary)

View File

@@ -11,32 +11,33 @@ from pyrad.dictionary import Dictionary
def main(path_to_dictionary):
srv = Client(server='localhost',
authport=18121,
secret=b'test',
dict=Dictionary(path_to_dictionary))
srv = Client(
server="localhost",
authport=18121,
secret=b"test",
dict=Dictionary(path_to_dictionary),
)
req = srv.CreateAuthPacket(
code=pyrad.packet.StatusServer,
FreeRADIUS_Statistics_Type='All',
code=pyrad.packet.StatusServer, FreeRADIUS_Statistics_Type="All",
)
req.add_message_authenticator()
try:
print('Sending FreeRADIUS status request')
print("Sending FreeRADIUS status request")
reply = srv.SendPacket(req)
except pyrad.client.Timeout:
print('RADIUS server does not reply')
print("RADIUS server does not reply")
sys.exit(1)
except socket.error as error:
print('Network error: ' + error[1])
print("Network error: " + error[1])
sys.exit(1)
print('Attributes returned by server:')
print("Attributes returned by server:")
for key, value in reply.items():
print(f'{key}: {value}')
print(f"{key}: {value}")
if __name__ == '__main__':
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
if __name__ == "__main__":
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
main(dictionary)

View File

@@ -21,13 +21,11 @@ classifiers = [
"Programming Language :: Python :: 3.8",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: System :: Systems Administration :: Authentication/Directory"
]
packages = [
{ include = "pyrad3"},
{ include = "src/pyrad3"},
]
include = [
"CHANGELOG.md",
"LICENSE-APACHE",
"LICENSE-MIT",
"README.md",
@@ -40,9 +38,16 @@ repository = "https://github.com/pyradius/pyrad3"
[tool.poetry.dev-dependencies]
pytest = "^5.4"
pytest-cov = "^2.5"
pytest-black = "^0.30"
pytest-cov = "^2.10"
pytest-mypy = "^0.6"
pytest-pylint = "^0.17"
[tool.black]
line-length = 100
line-length = 80
include = '\.py'
[tool.pylint.messages_control]
disable = "bad-continuation"

View File

@@ -1,221 +0,0 @@
# client.py
#
# Copyright 2002-2007 Wichert Akkerman <wichert@wiggy.net>
__docformat__ = "epytext en"
import hashlib
import select
import socket
import time
import struct
from pyrad import host
from pyrad import packet
EAP_CODE_REQUEST = 1
EAP_CODE_RESPONSE = 2
EAP_TYPE_IDENTITY = 1
class Timeout(Exception):
"""Simple exception class which is raised when a timeout occurs
while waiting for a RADIUS server to respond."""
class Client(host.Host):
"""Basic RADIUS client.
This class implements a basic RADIUS client. It can send requests
to a RADIUS server, taking care of timeouts and retries, and
validate its replies.
:ivar retries: number of times to retry sending a RADIUS request
:type retries: integer
:ivar timeout: number of seconds to wait for an answer
:type timeout: float
"""
def __init__(self, server, authport=1812, acctport=1813,
coaport=3799, secret=b'', dict=None, retries=3,
timeout=5):
"""Constructor.
:param server: hostname or IP address of RADIUS server
:type server: string
:param authport: port to use for authentication packets
:type authport: integer
:param acctport: port to use for accounting packets
:type acctport: integer
:param coaport: port to use for CoA packets
:type coaport: integer
:param secret: RADIUS secret
:type secret: string
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary
"""
host.Host.__init__(self, authport, acctport, coaport, dict)
self.server = server
self.secret = secret
self._socket = None
self.retries = retries
self.timeout = timeout
self._poll = select.poll()
def bind(self, addr):
"""Bind socket to an address.
Binding the socket used for communicating to an address can be
usefull when working on a machine with multiple addresses.
:param addr: network address (hostname or IP) and port to bind to
:type addr: host,port tuple
"""
self._close_socket()
self._socket_open()
self._socket.bind(addr)
def _socket_open(self):
try:
family = socket.getaddrinfo(self.server, 'www')[0][0]
except:
family = socket.AF_INET
if not self._socket:
self._socket = socket.socket(family,
socket.SOCK_DGRAM)
self._socket.setsockopt(socket.SOL_SOCKET,
socket.SO_REUSEADDR, 1)
self._poll.register(self._socket, select.POLLIN)
def _close_socket(self):
if self._socket:
self._poll.unregister(self._socket)
self._socket.close()
self._socket = None
def CreateAuthPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.AuthPacket
"""
return host.Host.CreateAuthPacket(self, secret=self.secret, **args)
def CreateAcctPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
return host.Host.CreateAcctPacket(self, secret=self.secret, **args)
def CreateCoAPacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS packet which can
be used to communicate with the RADIUS server this client
talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
return host.Host.CreateCoAPacket(self, secret=self.secret, **args)
def _send_packet(self, pkt, port):
"""Send a packet to a RADIUS server.
:param pkt: the packet to send
:type pkt: pyrad.packet.Packet
:param port: UDP port to send packet to
:type port: integer
:return: the reply packet received
:rtype: pyrad.packet.Packet
:raise Timeout: RADIUS server does not reply
"""
self._socket_open()
for attempt in range(self.retries):
if attempt and pkt.code == packet.AccountingRequest:
if "Acct-Delay-Time" in pkt:
pkt["Acct-Delay-Time"] = \
pkt["Acct-Delay-Time"][0] + self.timeout
else:
pkt["Acct-Delay-Time"] = self.timeout
now = time.time()
waitto = now + self.timeout
self._socket.sendto(pkt.RequestPacket(), (self.server, port))
while now < waitto:
ready = self._poll.poll((waitto - now) * 1000)
if ready:
rawreply = self._socket.recv(4096)
else:
now = time.time()
continue
try:
return pkt.VerifyReply(rawreply)
except packet.PacketError:
# TODO: report or error out maybe?
pass
now = time.time()
raise Timeout
def SendPacket(self, pkt):
"""Send a packet to a RADIUS server.
:param pkt: the packet to send
:type pkt: pyrad.packet.Packet
:return: the reply packet received
:rtype: pyrad.packet.Packet
:raise Timeout: RADIUS server does not reply
"""
if isinstance(pkt, packet.AuthPacket):
if pkt.auth_type == 'eap-md5':
# Creating EAP-Identity
password = pkt[2][0] if 2 in pkt else pkt[1][0]
pkt[79] = [struct.pack('!BBHB%ds' % len(password),
EAP_CODE_RESPONSE,
packet.CurrentID,
len(password) + 5,
EAP_TYPE_IDENTITY,
password)]
reply = self._send_packet(pkt, self.authport)
if (reply
and reply.code == packet.AccessChallenge
and pkt.auth_type == 'eap-md5'
):
# Got an Access-Challenge
_eap_code, eap_id, _eap_size, _eap_type, eap_md5 = struct.unpack(
'!BBHB%ds' % (len(reply[79][0]) - 5), reply[79][0]
)
# Sending back an EAP-Type-MD5-Challenge
# Thank god for http://www.secdev.org/python/eapy.py
client_pw = pkt[2][0] if 2 in pkt else pkt[1][0]
md5_challenge = hashlib.md5(
struct.pack('!B', eap_id) + client_pw + eap_md5[1:]
).digest()
pkt[79] = [
struct.pack('!BBHBB', 2, eap_id, len(md5_challenge) + 6,
4, len(md5_challenge)) + md5_challenge
]
# Copy over Challenge-State
pkt[24] = reply[24]
reply = self._send_packet(pkt, self.authport)
return reply
if isinstance(pkt, packet.CoAPacket):
return self._send_packet(pkt, self.coaport)
return self._send_packet(pkt, self.acctport)

View File

@@ -1,81 +0,0 @@
# curved.py
#
# Copyright 2002 Wichert Akkerman <wichert@wiggy.net>
"""Twisted integration code
"""
__docformat__ = 'epytext en'
from twisted.internet import protocol
from twisted.internet import reactor
from twisted.python import log
import sys
from pyrad import dictionary
from pyrad import host
from pyrad import packet
class PacketError(Exception):
"""Exception class for bogus packets
PacketError exceptions are only used inside the Server class to
abort processing of a packet.
"""
class RADIUS(host.Host, protocol.DatagramProtocol):
def __init__(self, hosts={}, dict=dictionary.Dictionary()):
host.Host.__init__(self, dict=dict)
self.hosts = hosts
def processPacket(self, pkt):
pass
def createPacket(self, **kwargs):
raise NotImplementedError('Attempted to use a pure base class')
def datagramReceived(self, datagram, source):
host, port = source
try:
pkt = self.CreatePacket(packet=datagram)
except packet.PacketError as err:
log.msg('Dropping invalid packet: ' + str(err))
return
if host not in self.hosts:
log.msg('Dropping packet from unknown host ' + host)
return
pkt.source = (host, port)
try:
self.processPacket(pkt)
except PacketError as err:
log.msg('Dropping packet from %s: %s' % (host, str(err)))
class RADIUSAccess(RADIUS):
def createPacket(self, **kwargs):
self.CreateAuthPacket(**kwargs)
def processPacket(self, pkt):
if pkt.code != packet.AccessRequest:
raise PacketError(
'non-AccessRequest packet on authentication socket')
class RADIUSAccounting(RADIUS):
def createPacket(self, **kwargs):
self.CreateAcctPacket(**kwargs)
def processPacket(self, pkt):
if pkt.code != packet.AccountingRequest:
raise PacketError(
'non-AccountingRequest packet on authentication socket')
if __name__ == '__main__':
log.startLogging(sys.stdout, 0)
reactor.listenUDP(1812, RADIUSAccess())
reactor.listenUDP(1813, RADIUSAccounting())
reactor.run()

View File

@@ -1,116 +0,0 @@
# dictfile.py
#
# Copyright 2009 Kristoffer Gronlund <kristoffer.gronlund@purplescout.se>
""" Dictionary File
Implements an iterable file format that handles the
RADIUS $INCLUDE directives behind the scene.
"""
import os
class _Node():
"""Dictionary file node
A single dictionary file.
"""
__slots__ = ('name', 'lines', 'current', 'length', 'dir')
def __init__(self, fd, name, parentdir):
self.lines = fd.readlines()
self.length = len(self.lines)
self.current = 0
self.name = os.path.basename(name)
path = os.path.dirname(name)
if os.path.isabs(path):
self.dir = path
else:
self.dir = os.path.join(parentdir, path)
def next(self):
if self.current >= self.length:
return None
self.current += 1
return self.lines[self.current - 1]
class DictFile():
"""Dictionary file class
An iterable file type that handles $INCLUDE
directives internally.
"""
__slots__ = ['stack']
def __init__(self, fil):
"""
@param fil: a dictionary file to parse
@type fil: string or file
"""
self.stack = []
self.__read_node(fil)
def __read_node(self, fil):
node = None
parentdir = self.__cur_dir()
if isinstance(fil, str):
fname = None
if os.path.isabs(fil):
fname = fil
else:
fname = os.path.join(parentdir, fil)
fd = open(fname, "rt")
node = _Node(fd, fil, parentdir)
fd.close()
else:
node = _Node(fil, '', parentdir)
self.stack.append(node)
def __cur_dir(self):
if self.stack:
return self.stack[-1].dir
return os.path.realpath(os.curdir)
@staticmethod
def __get_include(line):
line = line.split("#", 1)[0].strip()
tokens = line.split()
if tokens and tokens[0].upper() == '$INCLUDE':
return " ".join(tokens[1:])
return None
def line(self):
"""Returns line number of current file
"""
try:
return self.stack[-1].current
except (AttributeError, IndexError):
return -1
def file(self):
"""Returns name of current file
"""
try:
return self.stack[-1].name
except (AttributeError, IndexError):
return ''
def __iter__(self):
return self
def __next__(self):
while self.stack:
line = self.stack[-1].next()
if line is None:
self.stack.pop()
else:
inc = DictFile.__get_include(line)
if inc:
self.__read_node(inc)
else:
return line
raise StopIteration
next = __next__ # BBB for python <3

View File

@@ -1,393 +0,0 @@
# dictionary.py
#
# Copyright 2002,2005,2007,2016 Wichert Akkerman <wichert@wiggy.net>
"""
RADIUS uses dictionaries to define the attributes that can
be used in packets. The Dictionary class stores the attribute
definitions from one or more dictionary files.
Dictionary files are textfiles with one command per line.
Comments are specified by starting with a # character, and empty
lines are ignored.
The commands supported are::
ATTRIBUTE <attribute> <code> <type> [<vendor>]
specify an attribute and its type
VALUE <attribute> <valuename> <value>
specify a value attribute
VENDOR <name> <id>
specify a vendor ID
BEGIN-VENDOR <vendorname>
begin definition of vendor attributes
END-VENDOR <vendorname>
end definition of vendor attributes
The datatypes currently supported are:
+---------------+----------------------------------------------+
| type | description |
+===============+==============================================+
| string | ASCII string |
+---------------+----------------------------------------------+
| ipaddr | IPv4 address |
+---------------+----------------------------------------------+
| date | 32 bits UNIX |
+---------------+----------------------------------------------+
| octets | arbitrary binary data |
+---------------+----------------------------------------------+
| abinary | ascend binary data |
+---------------+----------------------------------------------+
| ipv6addr | 16 octets in network byte order |
+---------------+----------------------------------------------+
| ipv6prefix | 18 octets in network byte order |
+---------------+----------------------------------------------+
| integer | 32 bits unsigned number |
+---------------+----------------------------------------------+
| signed | 32 bits signed number |
+---------------+----------------------------------------------+
| short | 16 bits unsigned number |
+---------------+----------------------------------------------+
| byte | 8 bits unsigned number |
+---------------+----------------------------------------------+
| tlv | Nested tag-length-value |
+---------------+----------------------------------------------+
| integer64 | 64 bits unsigned number |
+---------------+----------------------------------------------+
These datatypes are parsed but not supported:
+---------------+----------------------------------------------+
| type | description |
+===============+==============================================+
| ifid | 8 octets in network byte order |
+---------------+----------------------------------------------+
| ether | 6 octets of hh:hh:hh:hh:hh:hh |
| | where 'h' is hex digits, upper or lowercase. |
+---------------+----------------------------------------------+
"""
from copy import copy
from pyrad import bidict
from pyrad import tools
from pyrad import dictfile
__docformat__ = 'epytext en'
DATATYPES = frozenset(['string', 'ipaddr', 'integer', 'date', 'octets',
'abinary', 'ipv6addr', 'ipv6prefix', 'short', 'byte',
'signed', 'ifid', 'ether', 'tlv', 'integer64'])
class ParseError(Exception):
"""Dictionary parser exceptions.
:ivar msg: Error message
:type msg: string
:ivar linenumber: Line number on which the error occurred
:type linenumber: integer
"""
def __init__(self, msg=None, **data):
super().__init__()
self.msg = msg
self.file = data.get('file', '')
self.line = data.get('line', -1)
def __str__(self):
line = f'({self.line})' if self.line > -1 else ''
return f'{self.file}{line}: ParseError: {self.msg}'
class Attribute():
def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', values=None,
encrypt=0, has_tag=False):
if datatype not in DATATYPES:
raise ValueError('Invalid data type')
self.name = name
self.code = code
self.type = datatype
self.vendor = vendor
self.encrypt = encrypt
self.has_tag = has_tag
self.values = bidict.BiDict()
self.sub_attributes = {}
self.parent = None
self.is_sub_attribute = is_sub_attribute
if values:
for (key, value) in values.items():
self.values.add(key, value)
class Dictionary():
"""RADIUS dictionary class.
This class stores all information about vendors, attributes and their
values as defined in RADIUS dictionary files.
:ivar vendors: bidict mapping vendor name to vendor code
:type vendors: bidict
:ivar attrindex: bidict mapping
:type attrindex: bidict
:ivar attributes: bidict mapping attribute name to attribute class
:type attributes: bidict
"""
def __init__(self, dict=None, *dicts):
"""
:param dict: path of dictionary file or file-like object to read
:type dict: string or file
:param dicts: list of dictionaries
:type dicts: sequence of strings or files
"""
self.vendors = bidict.BiDict()
self.vendors.add('', 0)
self.attrindex = bidict.BiDict()
self.attributes = {}
self.defer_parse = []
if dict:
self.read_dictionary(dict)
for i in dicts:
self.read_dictionary(i)
def __len__(self):
return len(self.attributes)
def __getitem__(self, key):
return self.attributes[key]
def __contains__(self, key):
return key in self.attributes
has_key = __contains__
def __parse_attribute(self, state, tokens):
if not len(tokens) in [4, 5]:
raise ParseError(
'Incorrect number of tokens for attribute definition',
name=state['file'],
line=state['line'])
vendor = state['vendor']
has_tag = False
encrypt = 0
if len(tokens) >= 5:
def keyval(o):
kv = o.split('=')
if len(kv) == 2:
return (kv[0], kv[1])
else:
return (kv[0], None)
options = [keyval(o) for o in tokens[4].split(',')]
for (key, val) in options:
if key == 'has_tag':
has_tag = True
elif key == 'encrypt':
if val not in ['1', '2', '3']:
raise ParseError(
f'Illegal attribute encryption: {val}',
file=state['file'],
line=state['line'])
encrypt = int(val)
if (not has_tag) and encrypt == 0:
vendor = tokens[4]
if not self.vendors.has_forward(vendor):
if vendor == "concat":
# ignore attributes with concat (freeradius compat.)
return None
else:
raise ParseError('Unknown vendor ' + vendor,
file=state['file'],
line=state['line'])
(attribute, code, datatype) = tokens[1:4]
codes = code.split('.')
# Codes can be sent as hex, or octal or decimal string representations.
tmp = []
for c in codes:
if c.startswith('0x'):
tmp.append(int(c, 16))
elif c.startswith('0o'):
tmp.append(int(c, 8))
else:
tmp.append(int(c, 10))
codes = tmp
is_sub_attribute = (len(codes) > 1)
if len(codes) == 2:
code = int(codes[1])
parent_code = int(codes[0])
elif len(codes) == 1:
code = int(codes[0])
parent_code = None
else:
raise ParseError('nested tlvs are not supported')
datatype = datatype.split("[")[0]
if datatype not in DATATYPES:
raise ParseError('Illegal type: ' + datatype,
file=state['file'],
line=state['line'])
if vendor:
if is_sub_attribute:
key = (self.vendors.get_forward(vendor), parent_code, code)
else:
key = (self.vendors.get_forward(vendor), code)
else:
if is_sub_attribute:
key = (parent_code, code)
else:
key = code
self.attrindex.add(attribute, key)
self.attributes[attribute] = Attribute(attribute, code, datatype, is_sub_attribute,
vendor, encrypt=encrypt, has_tag=has_tag)
if datatype == 'tlv':
# save attribute in tlvs
state['tlvs'][code] = self.attributes[attribute]
if is_sub_attribute:
# save sub attribute in parent tlv and update their parent field
state['tlvs'][parent_code].sub_attributes[code] = attribute
self.attributes[attribute].parent = state['tlvs'][parent_code]
def __parse_value(self, state, tokens, defer):
if len(tokens) != 4:
raise ParseError('Incorrect number of tokens for value definition',
file=state['file'],
line=state['line'])
(attr, key, value) = tokens[1:]
try:
adef = self.attributes[attr]
except KeyError:
if defer:
self.defer_parse.append((copy(state), copy(tokens)))
return
raise ParseError('Value defined for unknown attribute ' + attr,
file=state['file'],
line=state['line'])
if adef.type in ['integer', 'signed', 'short', 'byte', 'integer64']:
value = int(value, 0)
value = tools.EncodeAttr(adef.type, value)
self.attributes[attr].values.add(key, value)
def __parse_vendor(self, state, tokens):
if len(tokens) not in [3, 4]:
raise ParseError(
'Incorrect number of tokens for vendor definition',
file=state['file'],
line=state['line'])
# Parse format specification, but do
# nothing about it for now
if len(tokens) == 4:
fmt = tokens[3].split('=')
if fmt[0] != 'format':
raise ParseError(
f"Unknown option '{fmt[0]}' for vendor definition",
file=state['file'],
line=state['line'])
try:
(t, l) = tuple(int(a) for a in fmt[1].split(','))
if t not in [1, 2, 4] or l not in [0, 1, 2]:
raise ParseError(
f'Unknown vendor format specification {fmt[1]}',
file=state['file'],
line=state['line'])
except ValueError:
raise ParseError(
'Syntax error in vendor specification',
file=state['file'],
line=state['line'])
(vendorname, vendor) = tokens[1:3]
self.vendors.add(vendorname, int(vendor, 0))
def __parse_begin_vendor(self, state, tokens):
if len(tokens) != 2:
raise ParseError(
'Incorrect number of tokens for begin-vendor statement',
file=state['file'],
line=state['line'])
vendor = tokens[1]
if not self.vendors.has_forward(vendor):
raise ParseError(
f'Unknown vendor {vendor} in begin-vendor statement',
file=state['file'],
line=state['line'])
state['vendor'] = vendor
def __parse_end_vendor(self, state, tokens):
if len(tokens) != 2:
raise ParseError(
'Incorrect number of tokens for end-vendor statement',
file=state['file'],
line=state['line'])
vendor = tokens[1]
if state['vendor'] != vendor:
raise ParseError(
'Ending non-open vendor' + vendor,
file=state['file'],
line=state['line'])
state['vendor'] = ''
def read_dictionary(self, file):
"""Parse a dictionary file.
Reads a RADIUS dictionary file and merges its contents into the
class instance.
:param file: Name of dictionary file to parse or a file-like object
:type file: string or file-like object
"""
fil = dictfile.DictFile(file)
state = {}
state['vendor'] = ''
state['tlvs'] = {}
self.defer_parse = []
for line in fil:
state['file'] = fil.file()
state['line'] = fil.line()
line = line.split('#', 1)[0].strip()
tokens = line.split()
if not tokens:
continue
key = tokens[0].upper()
if key == 'ATTRIBUTE':
self.__parse_attribute(state, tokens)
elif key == 'VALUE':
self.__parse_value(state, tokens, True)
elif key == 'VENDOR':
self.__parse_vendor(state, tokens)
elif key == 'BEGIN-VENDOR':
self.__parse_begin_vendor(state, tokens)
elif key == 'END-VENDOR':
self.__parse_end_vendor(state, tokens)
for state, tokens in self.defer_parse:
key = tokens[0].upper()
if key == 'VALUE':
self.__parse_value(state, tokens, False)
self.defer_parse = []

View File

@@ -1,100 +0,0 @@
# host.py
#
# Copyright 2003,2007 Wichert Akkerman <wichert@wiggy.net>
from pyrad import packet
class Host(object):
"""Generic RADIUS capable host.
:ivar dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary
:ivar authport: port to listen on for authentication packets
:type authport: integer
:ivar acctport: port to listen on for accounting packets
:type acctport: integer
"""
def __init__(self, authport=1812, acctport=1813, coaport=3799, dict=None):
"""Constructor
:param authport: port to listen on for authentication packets
:type authport: integer
:param acctport: port to listen on for accounting packets
:type acctport: integer
:param coaport: port to listen on for CoA packets
:type coaport: integer
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary
"""
self.dict = dict
self.authport = authport
self.acctport = acctport
self.coaport = coaport
def CreatePacket(self, **args):
"""Create a new RADIUS packet.
This utility function creates a new RADIUS authentication
packet which can be used to communicate with the RADIUS server
this client talks to. This is initializing the new packet with
the dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.Packet
"""
return packet.Packet(dict=self.dict, **args)
def CreateAuthPacket(self, **args):
"""Create a new authentication RADIUS packet.
This utility function creates a new RADIUS authentication
packet which can be used to communicate with the RADIUS server
this client talks to. This is initializing the new packet with
the dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.AuthPacket
"""
return packet.AuthPacket(dict=self.dict, **args)
def CreateAcctPacket(self, **args):
"""Create a new accounting RADIUS packet.
This utility function creates a new accouting RADIUS packet
which can be used to communicate with the RADIUS server this
client talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.AcctPacket
"""
return packet.AcctPacket(dict=self.dict, **args)
def CreateCoAPacket(self, **args):
"""Create a new CoA RADIUS packet.
This utility function creates a new CoA RADIUS packet
which can be used to communicate with the RADIUS server this
client talks to. This is initializing the new packet with the
dictionary and secret used for the client.
:return: a new empty packet instance
:rtype: pyrad.packet.CoAPacket
"""
return packet.CoAPacket(dict=self.dict, **args)
def SendPacket(self, fd, pkt):
"""Send a packet.
:param fd: socket to send packet with
:type fd: socket class instance
:param pkt: packet to send
:type pkt: Packet class instance
"""
fd.sendto(pkt.Packet(), pkt.source)
def SendReplyPacket(self, fd, pkt):
"""Send a packet.
:param fd: socket to send packet with
:type fd: socket class instance
:param pkt: packet to send
:type pkt: Packet class instance
"""
fd.sendto(pkt.ReplyPacket(), pkt.source)

View File

@@ -1,29 +0,0 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
from pyrad3 import new_packet as packet
class Host:
def __init__(self, secret, radius_dict,
authport=1812, acctport=1813, coaport=3799,
timeout=30, retries=3):
self.secret = secret
self.dictionary = radius_dict
self.authport = authport
self.acctport = acctport
self.coaport = coaport
self.timeout = timeout
self.retries = retries
def create_packet(self, **kwargs):
return packet.Packet(self, **kwargs)
def create_auth_packet(self, **kwargs):
return packet.AuthPacket(self, **kwargs)
def create_acct_packet(self, **kwargs):
return packet.AcctPacket(self, **kwargs)
def create_coa_packet(self, **kwargs):
return packet.CoAPacket(self, **kwargs)

View File

@@ -1,143 +0,0 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
from collections import OrderedDict
import hashlib
import secrets
import struct
# Packet codes
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45
class PacketError(Exception):
pass
class AuthError(Exception):
pass
RANDOM_GENERATOR = secrets.SystemRandom()
class Packet(OrderedDict):
def __init__(self, host, code, radius_id, *, request=None, **attributes):
super().__init__(**attributes)
self.code = code
self.id = radius_id
self.host = host
self.request = request
def send(self):
self.host._send_packet(self)
def verify_reply(self, raw_reply):
if self.id != raw_reply[1]:
raise PacketError("Response has a wrong id")
radius_hash = self.calculate_radius_hash(raw_reply)
if radius_hash != raw_reply[4:20]:
raise PacketError("Reply Packet has a wrong authenticator")
return self.parse_raw_packet(raw_reply)
def calculate_radius_hash(self, data, authenticator=None):
if authenticator is None:
authenticator = self.authenticator
return hashlib.md5(data[0:4] + authenticator +
data[20:] + self.secret).digest()
def parse_raw_reply(self, raw_packet):
(code, radius_id, length, authenticator) = self._parse_header(raw_packet)
attrs = self.parse_attributes(raw_packet)
return Packet(self.host, code, radius_id, **attrs, request=self)
def _parse_header(self, raw_packet):
try:
header = struct.unpack('!BBH16s', raw_packet)
except struct.error:
raise PacketError('Packet header is corrupt')
length = header[3]
if len(raw_packet) != length:
raise PacketError('Packet has invalid length')
if length > 4096:
raise PacketError(f'Packet length is too big ({length})')
return header
class AuthPacket(Packet):
def __init__(self, host, radius_id, auth_type, *,
code=AccessRequest, request=None,
**attributes):
super().__init__(host, code, radius_id,
request=request, **attributes)
self.auth_type = auth_type
if code == AccessRequest:
self.authenticator = secrets.token_bytes(16)
def create_accept(self, **attributes):
return AuthPacket(self.host, self.id, self.auth_type,
request=self,
code=AccessAccept
**attributes)
def create_reject(self, **attributes):
return AuthPacket(self.host, self.id, self.auth_type,
request=self,
code=AccessReject,
**attributes)
def create_challange(self, **attributes):
return AuthPacket(self.host, self.id, self.auth_type,
request=self,
code=AccessChallenge,
**attributes)
class AcctPacket(Packet):
def __init__(self, host, radius_id, *,
code=AccountingRequest, request=None,
**attributes):
super().__init__(host, code, radius_id,
request=request, **attributes)
def create_response(self, **attributes):
return AcctPacket(self.host, self.id,
code=AccountingResponse,
request=self,
**attributes)
class CoAPacket(Packet):
def __init__(self, host, radius_id, *,
code=CoARequest, request=None,
**attributes):
super().__init__(host, code, radius_id,
request=request, **attributes)
def create_ack(self, **attributes):
return CoAPacket(self.host, self.id,
code=CoAACK,
request=self,
**attributes)
def create_nack(self, **attributes):
return CoAPacket(self.host, self.id,
code=CoANACK,
request=self,
**attributes)

View File

@@ -1,869 +0,0 @@
# packet.py
#
# Copyright 2002-2005,2007 Wichert Akkerman <wichert@wiggy.net>
#
# A RADIUS packet as defined in RFC 2138
from collections import OrderedDict
import hashlib
import hmac
import secrets
import struct
from pyrad import tools
md5_constructor = hashlib.md5
# Packet codes
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45
# Use cryptographic-safe random generator as provided by the OS.
random_generator = secrets.SystemRandom()
# Current ID
CurrentID = random_generator.randrange(1, 255)
class PacketError(Exception):
pass
class Packet(OrderedDict):
"""Packet acts like a standard python map to provide simple access
to the RADIUS attributes. Since RADIUS allows for repeated
attributes the value will always be a sequence. pyrad makes sure
to preserve the ordering when encoding and decoding packets.
There are two ways to use the map intereface: if attribute
names are used pyrad take care of en-/decoding data. If
the attribute type number (or a vendor ID/attribute type
tuple for vendor attributes) is used you work with the
raw data.
Normally you will not use this class directly, but one of the
:obj:`AuthPacket` or :obj:`AcctPacket` classes.
"""
def __init__(self, code=0, id=None, secret=b'', authenticator=None,
**attributes):
"""Constructor
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary class
:param secret: secret needed to communicate with a RADIUS server
:type secret: string
:param id: packet identification number
:type id: integer (8 bits)
:param code: packet type code
:type code: integer (8bits)
:param packet: raw packet to decode
:type packet: string
"""
super().__init__()
self.code = code
if id is not None:
self.id = id
else:
self.id = CreateID()
if not isinstance(secret, bytes):
raise TypeError('secret must be a binary string')
self.secret = secret
if authenticator is not None and \
not isinstance(authenticator, bytes):
raise TypeError('authenticator must be a binary string')
self.authenticator = authenticator
self.message_authenticator = None
if 'dict' in attributes:
self.dict = attributes['dict']
if 'packet' in attributes:
self.raw_packet = attributes['packet']
self.DecodePacket(attributes['packet'])
if 'message_authenticator' in attributes:
self.message_authenticator = attributes['message_authenticator']
for (key, value) in attributes.items():
if key in [
'dict', 'fd', 'packet',
'message_authenticator',
]:
continue
key = key.replace('_', '-')
self.AddAttribute(key, value)
def add_message_authenticator(self):
self.message_authenticator = True
# Maintain a zero octets content for md5 and hmac calculation.
self['Message-Authenticator'] = 16 * b'\00'
if self.id is None:
self.id = self.CreateID()
if self.authenticator is None and self.code == AccessRequest:
self.authenticator = self.CreateAuthenticator()
self._refresh_message_authenticator()
def get_message_authenticator(self):
self._refresh_message_authenticator()
return self.message_authenticator
def _refresh_message_authenticator(self):
hmac_constructor = hmac.new(self.secret, digestmod=hashlib.md5)
# Maintain a zero octets content for md5 and hmac calculation.
self['Message-Authenticator'] = 16 * b'\00'
attr = self._pkt_encode_attributes()
header = struct.pack('!BBH', self.code, self.id,
(20 + len(attr)))
hmac_constructor.update(header[0:4])
if self.code in (AccountingRequest, DisconnectRequest,
CoARequest, AccountingResponse):
hmac_constructor.update(16 * b'\00')
else:
# NOTE: self.authenticator on reply packet is initialized
# with request authenticator by design.
# For AccessAccept, AccessReject and AccessChallenge
# it is needed use original Authenticator.
# For AccessAccept, AccessReject and AccessChallenge
# it is needed use original Authenticator.
if self.authenticator is None:
raise Exception('No authenticator found')
hmac_constructor.update(self.authenticator)
hmac_constructor.update(attr)
self['Message-Authenticator'] = hmac_constructor.digest()
def verify_message_authenticator(self,
original_authenticator=None,
original_code=None):
"""Verify packet Message-Authenticator.
:return: False if verification failed else True
:rtype: boolean
"""
if self.message_authenticator is None:
raise Exception('No Message-Authenticator AVP present')
prev_ma = self['Message-Authenticator']
self['Message-Authenticator'] = 16 * b'\00'
attr = self._pkt_encode_attributes()
header = struct.pack('!BBH', self.code, self.id,
(20 + len(attr)))
hmac_constructor = hmac.new(self.secret, digestmod=hashlib.md5)
hmac_constructor.update(header)
if self.code in (AccountingRequest, DisconnectRequest,
CoARequest, AccountingResponse):
if original_code is None or original_code != StatusServer:
# TODO: Handle Status-Server response correctly.
hmac_constructor.update(16 * b'\00')
elif self.code in (AccessAccept, AccessChallenge,
AccessReject):
if original_authenticator is None:
if self.authenticator:
# NOTE: self.authenticator on reply packet is initialized
# with request authenticator by design.
original_authenticator = self.authenticator
else:
raise Exception('Missing original authenticator')
hmac_constructor.update(original_authenticator)
else:
# On Access-Request and Status-Server use dynamic authenticator
hmac_constructor.update(self.authenticator)
hmac_constructor.update(attr)
self['Message-Authenticator'] = prev_ma[0]
return prev_ma[0] == hmac_constructor.digest()
def CreateReply(self, **attributes):
"""Create a new packet as a reply to this one. This method
makes sure the authenticator and secret are copied over
to the new instance.
"""
return Packet(id=self.id, secret=self.secret,
authenticator=self.authenticator, dict=self.dict,
**attributes)
def _decode_value(self, attr, value):
try:
return attr.values.get_backward(value)
except KeyError:
return tools.DecodeAttr(attr.type, value)
def _encode_value(self, attr, value):
try:
result = attr.values.get_forward(value)
except KeyError:
result = tools.EncodeAttr(attr.type, value)
if attr.encrypt == 2:
# salt encrypt attribute
result = self.SaltCrypt(result)
return result
def _encode_key_values(self, key, values):
if not isinstance(key, str):
return (key, values)
if not isinstance(values, (list, tuple)):
values = [values]
key, _, tag = key.partition(":")
attr = self.dict.attributes[key]
key = self._encode_key(key)
if attr.has_tag:
tag = '0' if tag == '' else tag
tag = struct.pack('B', int(tag))
if attr.type == "integer":
# When a tagged value has the type int only 3 bytes are used
# the first byte is the tag itself, so we need to shorten our int
return (key, [tag + self._encode_value(attr, v)[1:] for v in values])
else:
return (key, [tag + self._encode_value(attr, v) for v in values])
else:
return (key, [self._encode_value(attr, v) for v in values])
def _encode_key(self, key):
if not isinstance(key, str):
return key
attr = self.dict.attributes[key]
# sub attribute keys don't need vendor
if attr.vendor and not attr.is_sub_attribute:
return (self.dict.vendors.get_forward(attr.vendor), attr.code)
else:
return attr.code
def _decode_key(self, key):
"""Turn a key into a string if possible"""
try:
return self.dict.attrindex.get_backward(key)
except KeyError:
return key
def AddAttribute(self, key, value):
"""Add an attribute to the packet.
:param key: attribute name or identification
:type key: string, attribute code or (vendor code, attribute code)
tuple
:param value: value
:type value: depends on type of attribute
"""
attr = self.dict.attributes[key.partition(':')[0]]
(key, value) = self._encode_key_values(key, value)
if attr.is_sub_attribute:
tlv = self.setdefault(self._encode_key(attr.parent.name), {})
encoded = tlv.setdefault(key, [])
else:
encoded = self.setdefault(key, [])
encoded.extend(value)
def get(self, key, failobj=None):
try:
res = self.__getitem__(key)
except KeyError:
res = failobj
return res
def __getitem__(self, key):
if not isinstance(key, str):
return super().__getitem__(key)
values = super().__getitem__(self._encode_key(key))
attr = self.dict.attributes[key]
if attr.type == 'tlv': # return map from sub attribute code to its values
res = {}
for (sub_attr_key, sub_attr_val) in values.items():
sub_attr_name = attr.sub_attributes[sub_attr_key]
sub_attr = self.dict.attributes[sub_attr_name]
for v in sub_attr_val:
res.setdefault(sub_attr_name, []).append(self._decode_value(sub_attr, v))
return res
else:
res = []
for v in values:
res.append(self._decode_value(attr, v))
return res
def __contains__(self, key):
try:
return super().__contains__(self._encode_key(key))
except KeyError:
return False
has_key = __contains__
def __delitem__(self, key):
super().__delitem__(self._encode_key(key))
def __setitem__(self, key, item):
if isinstance(key, str):
(key, item) = self._encode_key_values(key, item)
super().__setitem__(key, item)
def keys(self):
return [self._decode_key(key) for key in OrderedDict.keys(self)]
@staticmethod
def CreateAuthenticator():
"""Create a packet authenticator. All RADIUS packets contain a sixteen
byte authenticator which is used to authenticate replies from the
RADIUS server and in the password hiding algorithm. This function
returns a suitable random string that can be used as an authenticator.
:return: valid packet authenticator
:rtype: binary string
"""
return secrets.token_bytes(16)
def CreateID(self):
"""Create a packet ID. All RADIUS requests have a ID which is used to
identify a request. This is used to detect retries and replay attacks.
This function returns a suitable random number that can be used as ID.
:return: ID number
:rtype: integer
"""
return int.from_bytes(secrets.token_bytes(1), 'little')
def ReplyPacket(self):
"""Create a ready-to-transmit authentication reply packet.
Returns a RADIUS packet which can be directly transmitted
to a RADIUS server. This differs with Packet() in how
the authenticator is calculated.
:return: raw packet
:rtype: string
"""
assert self.authenticator
assert self.secret is not None
if self.message_authenticator:
self._refresh_message_authenticator()
attr = self._pkt_encode_attributes()
header = struct.pack('!BBH', self.code, self.id, (20 + len(attr)))
authenticator = md5_constructor(header[0:4] + self.authenticator
+ attr + self.secret).digest()
return header + authenticator + attr
def VerifyReply(self, rawreply):
if int(rawreply[1]) != self.id:
raise PacketError("Reply Packet has wrong id")
# The Authenticator field in an Accounting-Response packet is called
# the Response Authenticator, and contains a one-way MD5 hash
# calculated over a stream of octets consisting of the Accounting
# Response Code, Identifier, Length, the Request Authenticator field
# from the Accounting-Request packet being replied to, and the
# response attributes if any, followed by the shared secret. The
# resulting 16 octet MD5 hash value is stored in the Authenticator
# field of the Accounting-Response packet.
hash = md5_constructor(rawreply[0:4] + self.authenticator +
rawreply[20:] + self.secret).digest()
if hash != rawreply[4:20]:
raise PacketError("Reply Packet has a wrong authenticator")
return self.CreateReply(packet=rawreply)
def _pkt_encode_attribute(self, key, value):
if isinstance(key, tuple):
value = struct.pack('!L', key[0]) + \
self._pkt_encode_attribute(key[1], value)
key = 26
return struct.pack('!BB', key, (len(value) + 2)) + value
def _pkt_encode_tlv(self, tlv_key, tlv_value):
tlv_attr = self.dict.attributes[self._decode_key(tlv_key)]
curr_avp = b''
avps = []
max_sub_attribute_len = max(map(lambda item: len(item[1]), tlv_value.items()))
for i in range(max_sub_attribute_len):
sub_attr_encoding = b''
for (code, datalst) in tlv_value.items():
if i < len(datalst):
sub_attr_encoding += self._pkt_encode_attribute(code, datalst[i])
# split above 255. assuming len of one instance of all sub tlvs is lower than 255
if (len(sub_attr_encoding) + len(curr_avp)) < 245:
curr_avp += sub_attr_encoding
else:
avps.append(curr_avp)
curr_avp = sub_attr_encoding
avps.append(curr_avp)
tlv_avps = []
for avp in avps:
value = struct.pack('!BB', tlv_attr.code, (len(avp) + 2)) + avp
tlv_avps.append(value)
if tlv_attr.vendor:
vendor_avps = b''
for avp in tlv_avps:
vendor_avps += struct.pack(
'!BBL', 26, (len(avp) + 6),
self.dict.vendors.get_forward(tlv_attr.vendor)
) + avp
return vendor_avps
else:
return b''.join(tlv_avps)
def _pkt_encode_attributes(self):
result = b''
for (code, datalst) in self.items():
attribute = self.dict.attributes.get(self._decode_key(code))
if attribute and attribute.type == 'tlv':
result += self._pkt_encode_tlv(code, datalst)
else:
for data in datalst:
result += self._pkt_encode_attribute(code, data)
return result
def _pkt_decode_vendor_attribute(self, data):
# Check if this packet is long enough to be in the
# RFC2865 recommended form
if len(data) < 6:
return [(26, data)]
(vendor, atype, length) = struct.unpack('!LBB', data[:6])[0:3]
attribute = self.dict.attributes.get(self._decode_key((vendor, atype)))
try:
if attribute and attribute.type == 'tlv':
self._pkt_decode_tlv_attribute((vendor, atype), data[6:length + 4])
tlvs = [] # tlv is added to the packet inside _pkt_decode_tlv_attribute
else:
tlvs = [((vendor, atype), data[6:length + 4])]
except:
return [(26, data)]
sumlength = 4 + length
while len(data) > sumlength:
try:
atype, length = struct.unpack('!BB', data[sumlength:sumlength+2])[0:2]
except:
return [(26, data)]
tlvs.append(((vendor, atype), data[sumlength+2:sumlength+length]))
sumlength += length
return tlvs
def _pkt_decode_tlv_attribute(self, code, data):
sub_attributes = self.setdefault(code, {})
loc = 0
while loc < len(data):
atype, length = struct.unpack('!BB', data[loc:loc+2])[0:2]
sub_attributes.setdefault(atype, []).append(data[loc+2:loc+length])
loc += length
def DecodePacket(self, packet):
"""Initialize the object from raw packet data. Decode a packet as
received from the network and decode it.
:param packet: raw packet
:type packet: string"""
try:
(self.code, self.id, length, self.authenticator) = \
struct.unpack('!BBH16s', packet[0:20])
except struct.error:
raise PacketError('Packet header is corrupt')
if len(packet) != length:
raise PacketError('Packet has invalid length')
if length > 4096:
raise PacketError(f'Packet length is too long ({length})')
self.clear()
packet = packet[20:]
while packet:
try:
(key, attrlen) = struct.unpack('!BB', packet[0:2])
except struct.error:
raise PacketError('Attribute header is corrupt')
if attrlen < 2:
raise PacketError(f'Attribute length is too small (attrlen)')
value = packet[2:attrlen]
attribute = self.dict.attributes.get(self._decode_key(key))
if key == 26:
for (key, value) in self._pkt_decode_vendor_attribute(value):
self.setdefault(key, []).append(value)
elif key == 80:
# POST: Message Authenticator AVP is present.
self.message_authenticator = True
self.setdefault(key, []).append(value)
elif attribute and attribute.type == 'tlv':
self._pkt_decode_tlv_attribute(key, value)
else:
self.setdefault(key, []).append(value)
packet = packet[attrlen:]
def SaltCrypt(self, value):
"""Salt Encryption
:param value: plaintext value
:type password: unicode string
:return: obfuscated version of the value
:rtype: binary string
"""
if isinstance(value, str):
value = value.encode('utf-8')
if self.authenticator is None:
# self.authenticator = self.CreateAuthenticator()
self.authenticator = 16 * b'\x00'
random_value = 32768 + random_generator.randrange(0, 32767)
result = struct.pack('!H', random_value)
length = struct.pack("B", len(value))
buf = length + value
if len(buf) % 16 != 0:
buf += b'\x00' * (16 - (len(buf) % 16))
last = self.authenticator + result
while buf:
cur_hash = md5_constructor(self.secret + last).digest()
for b, h in zip(buf, cur_hash):
result += bytes([b ^ h])
last = result[-16:]
buf = buf[16:]
return result
class AuthPacket(Packet):
def __init__(self, code=AccessRequest, id=None, secret=b'',
authenticator=None, auth_type='pap', **attributes):
"""Constructor
:param code: packet type code
:type code: integer (8bits)
:param id: packet identification number
:type id: integer (8 bits)
:param secret: secret needed to communicate with a RADIUS server
:type secret: string
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary class
:param packet: raw packet to decode
:type packet: string
"""
Packet.__init__(self, code, id, secret, authenticator, **attributes)
self.auth_type = auth_type
def CreateReply(self, **attributes):
"""Create a new packet as a reply to this one. This method
makes sure the authenticator and secret are copied over
to the new instance.
"""
return AuthPacket(AccessAccept, self.id,
self.secret, self.authenticator, dict=self.dict,
auth_type=self.auth_type, **attributes)
def RequestPacket(self):
"""Create a ready-to-transmit authentication request packet.
Return a RADIUS packet which can be directly transmitted
to a RADIUS server.
:return: raw packet
:rtype: string
"""
if self.authenticator is None:
self.authenticator = self.CreateAuthenticator()
if self.id is None:
self.id = self.CreateID()
if self.message_authenticator:
self._refresh_message_authenticator()
attr = self._pkt_encode_attributes()
if self.auth_type == 'eap-md5':
header = struct.pack(
'!BBH16s', self.code, self.id, (20 + 18 + len(attr)), self.authenticator
)
msg = header \
+ attr \
+ struct.pack('!BB', 80, struct.calcsize('!BB16s')),
digest = hmac.new(self.secret, msg, digestmod=hashlib.md5).digest()
return msg + digest
header = struct.pack('!BBH16s', self.code, self.id,
(20 + len(attr)), self.authenticator)
return header + attr
def PwDecrypt(self, password):
"""Obfuscate a RADIUS password. RADIUS hides passwords in packets by
using an algorithm based on the MD5 hash of the packet authenticator
and RADIUS secret. This function reverses the obfuscation process.
:param password: obfuscated form of password
:type password: binary string
:return: plaintext password
:rtype: unicode string
"""
pw = self.radius_password_pseudo_hash(password).rstrip(b'\x00')
return pw.decode('utf-8')
def PwCrypt(self, password):
"""Obfuscate password.
RADIUS hides passwords in packets by using an algorithm
based on the MD5 hash of the packet authenticator and RADIUS
secret. If no authenticator has been set before calling PwCrypt
one is created automatically. Changing the authenticator after
setting a password that has been encrypted using this function
will not work.
:param password: plaintext password
:type password: unicode string
:return: obfuscated version of the password
:rtype: binary string
"""
if self.authenticator is None:
self.authenticator = self.CreateAuthenticator()
if isinstance(password, str):
password = password.encode('utf-8')
buf = password
if len(password) % 16 != 0:
buf += b'\x00' * (16 - (len(password) % 16))
return self.radius_password_pseudo_hash(buf)
def radius_password_pseudo_hash(self, password):
result = b''
buf = password
last = self.authenticator
while buf:
cur_hash = md5_constructor(self.secret + last).digest()
for b, h in zip(buf, cur_hash):
result += bytes([b ^ h])
(last, buf) = (buf[:16], buf[16:])
return result
def VerifyChapPasswd(self, userpwd):
""" Verify RADIUS ChapPasswd
:param userpwd: plaintext password
:type userpwd: str
:return: is verify ok
:rtype: bool
"""
if not self.authenticator:
self.authenticator = self.CreateAuthenticator()
if isinstance(userpwd, str):
userpwd = userpwd.strip().encode('utf-8')
chap_password = tools.DecodeOctets(self.get(3)[0])
if len(chap_password) != 17:
return False
chapid = chr(chap_password[0]).encode('utf-8')
password = chap_password[1:]
challenge = self.authenticator
if 'CHAP-Challenge' in self:
challenge = self['CHAP-Challenge'][0]
return password == md5_constructor(chapid + userpwd + challenge).digest()
def VerifyAuthRequest(self):
"""Verify request authenticator.
:return: True if verification failed else False
:rtype: boolean
"""
assert self.raw_packet
hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' +
self.raw_packet[20:] + self.secret).digest()
return hash == self.authenticator
class AcctPacket(Packet):
"""RADIUS accounting packets. This class is a specialization
of the generic :obj:`Packet` class for accounting packets.
"""
def __init__(self, code=AccountingRequest, id=None, secret=b'',
authenticator=None, **attributes):
"""Constructor
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary class
:param secret: secret needed to communicate with a RADIUS server
:type secret: string
:param id: packet identification number
:type id: integer (8 bits)
:param code: packet type code
:type code: integer (8bits)
:param packet: raw packet to decode
:type packet: string
"""
Packet.__init__(self, code, id, secret, authenticator, **attributes)
def CreateReply(self, **attributes):
"""Create a new packet as a reply to this one. This method
makes sure the authenticator and secret are copied over
to the new instance.
"""
return AcctPacket(AccountingResponse, self.id,
self.secret, self.authenticator, dict=self.dict,
**attributes)
def VerifyAcctRequest(self):
"""Verify request authenticator.
:return: False if verification failed else True
:rtype: boolean
"""
assert self.raw_packet
hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' +
self.raw_packet[20:] + self.secret).digest()
return hash == self.authenticator
def RequestPacket(self):
"""Create a ready-to-transmit authentication request packet.
Return a RADIUS packet which can be directly transmitted
to a RADIUS server.
:return: raw packet
:rtype: string
"""
if self.id is None:
self.id = self.CreateID()
if self.message_authenticator:
self._refresh_message_authenticator()
attr = self._pkt_encode_attributes()
header = struct.pack('!BBH', self.code, self.id, (20 + len(attr)))
self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' +
attr + self.secret).digest()
ans = header + self.authenticator + attr
return ans
class CoAPacket(Packet):
"""RADIUS CoA packets. This class is a specialization
of the generic :obj:`Packet` class for CoA packets.
"""
def __init__(self, code=CoARequest, id=None, secret=b'',
authenticator=None, **attributes):
"""Constructor
:param dict: RADIUS dictionary
:type dict: pyrad.dictionary.Dictionary class
:param secret: secret needed to communicate with a RADIUS server
:type secret: string
:param id: packet identification number
:type id: integer (8 bits)
:param code: packet type code
:type code: integer (8bits)
:param packet: raw packet to decode
:type packet: string
"""
Packet.__init__(self, code, id, secret, authenticator, **attributes)
def CreateReply(self, **attributes):
"""Create a new packet as a reply to this one. This method
makes sure the authenticator and secret are copied over
to the new instance.
"""
return CoAPacket(CoAACK, self.id,
self.secret, self.authenticator, dict=self.dict,
**attributes)
def VerifyCoARequest(self):
"""Verify request authenticator.
:return: False if verification failed else True
:rtype: boolean
"""
assert self.raw_packet
hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' +
self.raw_packet[20:] + self.secret).digest()
return hash == self.authenticator
def RequestPacket(self):
"""Create a ready-to-transmit CoA request packet.
Return a RADIUS packet which can be directly transmitted
to a RADIUS server.
:return: raw packet
:rtype: string
"""
attr = self._pkt_encode_attributes()
if self.id is None:
self.id = self.CreateID()
header = struct.pack('!BBH', self.code, self.id, (20 + len(attr)))
self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' +
attr + self.secret).digest()
if self.message_authenticator:
self._refresh_message_authenticator()
attr = self._pkt_encode_attributes()
self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' +
attr + self.secret).digest()
return header + self.authenticator + attr
def CreateID():
"""Generate a packet ID.
:return: packet ID
:rtype: 8 bit integer
"""
global CurrentID
CurrentID = (CurrentID + 1) % 256
return CurrentID

View File

@@ -1,5 +0,0 @@
import pyrad
import sys
pyrad # keep pyflakes happy
home = sys.modules["pyrad"].__path__[0]

View File

@@ -1,6 +0,0 @@
# A simple dictionary
ATTRIBUTE User-Name 1 string
ATTRIBUTE User-Password 2 string encrypt=1
ATTRIBUTE CHAP-Password 3 octets
ATTRIBUTE CHAP-Challenge 60 octets

View File

@@ -1,5 +0,0 @@
# A failing dictionary
VALUE Not-Defined Undefined-Value 1

View File

@@ -1,34 +0,0 @@
# A simple dictionary
ATTRIBUTE Test-String 1 string
ATTRIBUTE Test-Octets 2 octets
ATTRIBUTE Test-Integer 3 integer
VALUE Test-Integer Zero 0
VALUE Test-Integer One 1
VALUE Test-Integer Two 2
VALUE Test-Integer Three 3
VALUE Test-Integer Four 4
ATTRIBUTE Test-Tlv 4 tlv
ATTRIBUTE Test-Tlv-Str 4.1 string
ATTRIBUTE Test-Tlv-Int 4.2 integer
VENDOR Simplon 16
BEGIN-VENDOR Simplon
ATTRIBUTE Simplon-Number 1 integer
ATTRIBUTE Simplon-String 2 string
VALUE Simplon-Number Zero 0
VALUE Simplon-Number One 1
VALUE Simplon-Number Two 2
VALUE Simplon-Number Three 3
VALUE Simplon-Number Four 4
ATTRIBUTE Simplon-Tlv 3 tlv
ATTRIBUTE Simplon-Tlv-Str 3.1 string
ATTRIBUTE Simplon-Tlv-Int 3.2 integer
END-VENDOR Simplon

View File

@@ -1,16 +0,0 @@
# A simple dictionary
ATTRIBUTE Test-String 1 string
ATTRIBUTE Test-Octets 2 octets
ATTRIBUTE Test-Integer 3 integer
ATTRIBUTE Test-Ip-Address 4 ipaddr
ATTRIBUTE Test-Ipv6-Address 5 ipv6addr
ATTRIBUTE Test-If-Id 6 ifid
ATTRIBUTE Test-Date 7 date
ATTRIBUTE Test-Abinary 8 abinary
ATTRIBUTE Test-Tlv 9 tlv
ATTRIBUTE Test-Tlv-Str 9.1 string
ATTRIBUTE Test-Tlv-Int 9.2 integer
ATTRIBUTE Test-Integer64 10 integer64
ATTRIBUTE Test-Integer64-Hex 0x0a integer64
ATTRIBUTE Test-Integer64-Oct 0o12 integer64

View File

@@ -1,3 +0,0 @@
# A simple dictionary
ATTRIBUTE Tunnel-Password 2 string encrypt=2

View File

@@ -1,141 +0,0 @@
import fcntl
import os
from pyrad.packet import PacketError
class MockPacket:
reply = object()
def __init__(self, code, verify=False, error=False):
self.code = code
self.data = {}
self.verify = verify
self.error = error
def CreateReply(self, packet=None):
if self.error:
raise PacketError
return self.reply
def VerifyReply(self, reply, rawreply):
return self.verify
def RequestPacket(self):
return "request packet"
def __contains__(self, key):
return key in self.data
has_key = __contains__
def __setitem__(self, key, value):
self.data[key] = [value]
def __getitem__(self, key):
return self.data[key]
class MockSocket:
def __init__(self, domain, type, data=None):
self.domain = domain
self.type = type
self.closed = False
self.options = []
self.address = None
self.output = []
if data is not None:
(self.read_end, self.write_end) = os.pipe()
fcntl.fcntl(self.write_end, fcntl.F_SETFL, os.O_NONBLOCK)
os.write(self.write_end, data)
self.data = data
else:
self.read_end = 1
self.write_end = None
def fileno(self):
return self.read_end
def bind(self, address):
self.address = address
def recv(self, buffer):
return self.data[:buffer]
def sendto(self, data, target):
self.output.append((data, target))
def setsockopt(self, level, opt, value):
self.options.append((level, opt, value))
def close(self):
self.closed = True
class MockFinished(Exception):
pass
class MockPoll:
results = []
def __init__(self):
self.registry = {}
def register(self, fd, options):
self.registry[fd] = options
def unregister(self, fd):
try:
del self.registry[fd]
except KeyError:
pass
def poll(self, timeout=None):
for result in self.results:
yield result
raise MockFinished
def origkey(klass):
return "_originals_" + klass.__name__
def MockClassMethod(klass, name, myfunc=None):
def func(self, *args, **kwargs):
if not hasattr(self, "called"):
self.called = []
self.called.append((name, args, kwargs))
key = origkey(klass)
if not hasattr(klass, key):
setattr(klass, key, {})
getattr(klass, key)[name] = getattr(klass, name)
if myfunc is None:
setattr(klass, name, func)
else:
setattr(klass, name, myfunc)
def UnmockClassMethods(klass):
key = origkey(klass)
if not hasattr(klass, key):
return
for (name, func) in getattr(klass, key).items():
setattr(klass, name, func)
delattr(klass, key)
class MockFd:
data = object()
source = object()
def __init__(self, fd=0):
self.fd = fd
def fileno(self):
return self.fd
def recvfrom(self, size):
self.size = size
return (self.data, self.source)

View File

@@ -1,56 +0,0 @@
import operator
import unittest
from pyrad.bidict import BiDict
class BiDictTests(unittest.TestCase):
def setUp(self):
self.bidict = BiDict()
def testStartEmpty(self):
self.assertEqual(len(self.bidict), 0)
self.assertEqual(len(self.bidict.forward), 0)
self.assertEqual(len(self.bidict.backward), 0)
def testLength(self):
self.assertEqual(len(self.bidict), 0)
self.bidict.add("from", "to")
self.assertEqual(len(self.bidict), 1)
del self.bidict["from"]
self.assertEqual(len(self.bidict), 0)
def testDeletion(self):
self.assertRaises(KeyError, operator.delitem, self.bidict, "missing")
self.bidict.add("missing", "present")
del self.bidict["missing"]
def testBackwardDeletion(self):
self.assertRaises(KeyError, operator.delitem, self.bidict, "missing")
self.bidict.add("missing", "present")
del self.bidict["present"]
self.assertEqual(self.bidict.has_forward("missing"), False)
def testForwardAccess(self):
self.bidict.add("shake", "vanilla")
self.bidict.add("pie", "custard")
self.assertEqual(self.bidict.has_forward("shake"), True)
self.assertEqual(self.bidict.get_forward("shake"), "vanilla")
self.assertEqual(self.bidict.has_forward("pie"), True)
self.assertEqual(self.bidict.get_forward("pie"), "custard")
self.assertEqual(self.bidict.has_forward("missing"), False)
self.assertRaises(KeyError, self.bidict.get_forward, "missing")
def testBackwardAccess(self):
self.bidict.add("shake", "vanilla")
self.bidict.add("pie", "custard")
self.assertEqual(self.bidict.has_backward("vanilla"), True)
self.assertEqual(self.bidict.get_backward("vanilla"), "shake")
self.assertEqual(self.bidict.has_backward("missing"), False)
self.assertRaises(KeyError, self.bidict.get_backward, "missing")
def testItemAccessor(self):
self.bidict.add("shake", "vanilla")
self.bidict.add("pie", "custard")
self.assertRaises(KeyError, operator.getitem, self.bidict, "missing")
self.assertEquals(self.bidict["shake"], "vanilla")
self.assertEquals(self.bidict["pie"], "custard")

View File

@@ -1,184 +0,0 @@
import select
import socket
import unittest
from pyrad.client import Client
from pyrad.client import Timeout
from pyrad.packet import AuthPacket
from pyrad.packet import AcctPacket
from pyrad.packet import AccessRequest
from pyrad.packet import AccountingRequest
from pyrad.tests.mock import MockPacket
from pyrad.tests.mock import MockPoll
from pyrad.tests.mock import MockSocket
BIND_IP = "127.0.0.1"
BIND_PORT = 53535
class ConstructionTests(unittest.TestCase):
def setUp(self):
self.server = object()
def testSimpleConstruction(self):
client = Client(self.server)
self.failUnless(client.server is self.server)
self.assertEqual(client.authport, 1812)
self.assertEqual(client.acctport, 1813)
self.assertEqual(client.secret, b'')
self.assertEqual(client.retries, 3)
self.assertEqual(client.timeout, 5)
self.failUnless(client.dict is None)
def testParameterOrder(self):
marker = object()
client = Client(self.server, 123, 456, 789, "secret", marker)
self.failUnless(client.server is self.server)
self.assertEqual(client.authport, 123)
self.assertEqual(client.acctport, 456)
self.assertEqual(client.coaport, 789)
self.assertEqual(client.secret, "secret")
self.failUnless(client.dict is marker)
def testNamedParameters(self):
marker = object()
client = Client(server=self.server, authport=123, acctport=456,
secret="secret", dict=marker)
self.failUnless(client.server is self.server)
self.assertEqual(client.authport, 123)
self.assertEqual(client.acctport, 456)
self.assertEqual(client.secret, "secret")
self.failUnless(client.dict is marker)
class SocketTests(unittest.TestCase):
def setUp(self):
self.server = object()
self.client = Client(self.server)
self.orgsocket = socket.socket
socket.socket = MockSocket
def tearDown(self):
socket.socket = self.orgsocket
def testReopen(self):
self.client._socket_open()
sock = self.client._socket
self.client._socket_open()
self.failUnless(sock is self.client._socket)
def testBind(self):
self.client.bind((BIND_IP, BIND_PORT))
self.assertEqual(self.client._socket.address, (BIND_IP, BIND_PORT))
self.assertEqual(self.client._socket.options,
[(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)])
def testBindClosesSocket(self):
s = MockSocket(socket.AF_INET, socket.SOCK_DGRAM)
self.client._socket = s
self.client._poll = MockPoll()
self.client.bind((BIND_IP, BIND_PORT))
self.assertEqual(s.closed, True)
def testSendPacket(self):
def MockSend(self, pkt, port):
self._mock_pkt = pkt
self._mock_port = port
_send_packet = Client._send_packet
Client._send_packet = MockSend
self.client.SendPacket(AuthPacket())
self.assertEqual(self.client._mock_port, self.client.authport)
self.client.SendPacket(AcctPacket())
self.assertEqual(self.client._mock_port, self.client.acctport)
Client._send_packet = _send_packet
def testNoRetries(self):
self.client.retries = 0
self.assertRaises(Timeout, self.client._send_packet, None, None)
def testSingleRetry(self):
self.client.retries = 1
self.client.timeout = 0
packet = MockPacket(AccessRequest)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
self.assertEqual(self.client._socket.output,
[("request packet", (self.server, 432))])
def testDoubleRetry(self):
self.client.retries = 2
self.client.timeout = 0
packet = MockPacket(AccessRequest)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
self.assertEqual(
self.client._socket.output,
[("request packet", (self.server, 432)),
("request packet", (self.server, 432))])
def testAuthDelay(self):
self.client.retries = 2
self.client.timeout = 1
packet = MockPacket(AccessRequest)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
self.failIf("Acct-Delay-Time" in packet)
def testSingleAccountDelay(self):
self.client.retries = 2
self.client.timeout = 1
packet = MockPacket(AccountingRequest)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
self.assertEqual(packet["Acct-Delay-Time"], [1])
def testDoubleAccountDelay(self):
self.client.retries = 3
self.client.timeout = 1
packet = MockPacket(AccountingRequest)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
self.assertEqual(packet["Acct-Delay-Time"], [2])
def testIgnorePacketError(self):
self.client.retries = 1
self.client.timeout = 1
self.client._socket = MockSocket(1, 2, b"valid reply")
packet = MockPacket(AccountingRequest, verify=True, error=True)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
def testValidReply(self):
self.client.retries = 1
self.client.timeout = 1
self.client._socket = MockSocket(1, 2, b"valid reply")
self.client._poll = MockPoll()
MockPoll.results = [(1, select.POLLIN)]
packet = MockPacket(AccountingRequest, verify=True)
reply = self.client._send_packet(packet, 432)
self.failUnless(reply is packet.reply)
def testInvalidReply(self):
self.client.retries = 1
self.client.timeout = 1
self.client._socket = MockSocket(1, 2, b"invalid reply")
MockPoll.results = [(1, select.POLLIN)]
packet = MockPacket(AccountingRequest, verify=False)
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
class OtherTests(unittest.TestCase):
def setUp(self):
self.server = object()
self.client = Client(self.server, secret=b'zeer geheim')
def testCreateAuthPacket(self):
packet = self.client.CreateAuthPacket(id=15)
self.failUnless(isinstance(packet, AuthPacket))
self.failUnless(packet.dict is self.client.dict)
self.assertEqual(packet.id, 15)
self.assertEqual(packet.secret, b'zeer geheim')
def testCreateAcctPacket(self):
packet = self.client.CreateAcctPacket(id=15)
self.failUnless(isinstance(packet, AcctPacket))
self.failUnless(packet.dict is self.client.dict)
self.assertEqual(packet.id, 15)
self.assertEqual(packet.secret, b'zeer geheim')

View File

@@ -1,332 +0,0 @@
import unittest
import operator
import os
from io import StringIO
from pyrad.tests import home
from pyrad.dictionary import Attribute
from pyrad.dictionary import Dictionary
from pyrad.dictionary import ParseError
from pyrad.tools import DecodeAttr
from pyrad.dictfile import DictFile
class AttributeTests(unittest.TestCase):
def testInvalidDataType(self):
self.assertRaises(ValueError, Attribute, 'name', 'code', 'datatype')
def testConstructionParameters(self):
attr = Attribute('name', 'code', 'integer', False, 'vendor')
self.assertEqual(attr.name, 'name')
self.assertEqual(attr.code, 'code')
self.assertEqual(attr.type, 'integer')
self.assertEqual(attr.is_sub_attribute, False)
self.assertEqual(attr.vendor, 'vendor')
self.assertEqual(len(attr.values), 0)
self.assertEqual(len(attr.sub_attributes), 0)
def testNamedConstructionParameters(self):
attr = Attribute(name='name', code='code', datatype='integer',
vendor='vendor')
self.assertEqual(attr.name, 'name')
self.assertEqual(attr.code, 'code')
self.assertEqual(attr.type, 'integer')
self.assertEqual(attr.vendor, 'vendor')
self.assertEqual(len(attr.values), 0)
def testValues(self):
attr = Attribute('name', 'code', 'integer', False, 'vendor',
dict(pie='custard', shake='vanilla'))
self.assertEqual(len(attr.values), 2)
self.assertEqual(attr.values['shake'], 'vanilla')
class DictionaryInterfaceTests(unittest.TestCase):
def testEmptyDictionary(self):
dict = Dictionary()
self.assertEqual(len(dict), 0)
def testContainment(self):
dict = Dictionary()
self.assertEqual('test' in dict, False)
dict.attributes['test'] = 'dummy'
self.assertEqual('test' in dict, True)
def testReadonlyContainer(self):
dict = Dictionary()
self.assertRaises(TypeError,
operator.setitem, dict, 'test', 'dummy')
self.assertRaises(AttributeError,
operator.attrgetter('clear'), dict)
self.assertRaises(AttributeError,
operator.attrgetter('update'), dict)
class DictionaryParsingTests(unittest.TestCase):
simple_dict_values = [
('Test-String', 1, 'string'),
('Test-Octets', 2, 'octets'),
('Test-Integer', 0x03, 'integer'),
('Test-Ip-Address', 4, 'ipaddr'),
('Test-Ipv6-Address', 5, 'ipv6addr'),
('Test-If-Id', 6, 'ifid'),
('Test-Date', 7, 'date'),
('Test-Abinary', 8, 'abinary'),
('Test-Tlv', 9, 'tlv'),
('Test-Tlv-Str', 1, 'string'),
('Test-Tlv-Int', 2, 'integer'),
('Test-Integer64', 10, 'integer64'),
('Test-Integer64-Hex', 10, 'integer64'),
('Test-Integer64-Oct', 10, 'integer64'),
]
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'simple'))
def testParseEmptyDictionary(self):
dict = Dictionary(StringIO(''))
self.assertEqual(len(dict), 0)
def testParseMultipleDictionaries(self):
dict = Dictionary(StringIO(''))
self.assertEqual(len(dict), 0)
one = StringIO('ATTRIBUTE Test-First 1 string')
two = StringIO('ATTRIBUTE Test-Second 2 string')
dict = Dictionary(StringIO(''), one, two)
self.assertEqual(len(dict), 2)
def testParseSimpleDictionary(self):
self.assertEqual(len(self.dict), len(self.simple_dict_values))
for (attr, code, type) in self.simple_dict_values:
attr = self.dict[attr]
self.assertEqual(attr.code, code)
self.assertEqual(attr.type, type)
def testAttributeTooFewColumnsError(self):
try:
self.dict.read_dictionary(
StringIO('ATTRIBUTE Oops-Too-Few-Columns'))
except ParseError as e:
self.assertEqual('attribute' in str(e), True)
else:
self.fail()
def testAttributeUnknownTypeError(self):
try:
self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 dummy'))
except ParseError as e:
self.assertEqual('dummy' in str(e), True)
else:
self.fail()
def testAttributeUnknownVendorError(self):
try:
self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 Simplon'))
except ParseError as e:
self.assertEqual('Simplon' in str(e), True)
else:
self.fail()
def testAttributeOptions(self):
self.dict.read_dictionary(StringIO(
'ATTRIBUTE Option-Type 1 string has_tag,encrypt=1'))
self.assertEqual(self.dict['Option-Type'].has_tag, True)
self.assertEqual(self.dict['Option-Type'].encrypt, 1)
def testAttributeEncryptionError(self):
try:
self.dict.read_dictionary(StringIO(
'ATTRIBUTE Test-Type 1 string encrypt=4'))
except ParseError as e:
self.assertEqual('encrypt' in str(e), True)
else:
self.fail()
def testValueTooFewColumnsError(self):
try:
self.dict.read_dictionary(StringIO('VALUE Oops-Too-Few-Columns'))
except ParseError as e:
self.assertEqual('value' in str(e), True)
else:
self.fail()
def testValueForUnknownAttributeError(self):
try:
self.dict.read_dictionary(StringIO(
'VALUE Test-Attribute Test-Text 1'))
except ParseError as e:
self.assertEqual('unknown attribute' in str(e), True)
else:
self.fail()
def testIntegerValueParsing(self):
self.assertEqual(len(self.dict['Test-Integer'].values), 0)
self.dict.read_dictionary(StringIO('VALUE Test-Integer Value-Six 5'))
self.assertEqual(len(self.dict['Test-Integer'].values), 1)
self.assertEqual(
DecodeAttr('integer', self.dict['Test-Integer'].values['Value-Six']),
5)
def testInteger64ValueParsing(self):
self.assertEqual(len(self.dict['Test-Integer64'].values), 0)
self.dict.read_dictionary(StringIO('VALUE Test-Integer64 Value-Six 5'))
self.assertEqual(len(self.dict['Test-Integer64'].values), 1)
self.assertEqual(
DecodeAttr('integer64', self.dict['Test-Integer64'].values['Value-Six']),
5)
def testStringValueParsing(self):
self.assertEqual(len(self.dict['Test-String'].values), 0)
self.dict.read_dictionary(StringIO('VALUE Test-String Value-Custard custardpie'))
self.assertEqual(len(self.dict['Test-String'].values), 1)
self.assertEqual(
DecodeAttr('string', self.dict['Test-String'].values['Value-Custard']),
'custardpie')
def testTlvParsing(self):
self.assertEqual(len(self.dict['Test-Tlv'].sub_attributes), 2)
self.assertEqual(self.dict['Test-Tlv'].sub_attributes,
{1: 'Test-Tlv-Str', 2: 'Test-Tlv-Int'})
def testSubTlvParsing(self):
for (attr, _, _) in self.simple_dict_values:
if attr.startswith('Test-Tlv-'):
self.assertEqual(self.dict[attr].is_sub_attribute, True)
self.assertEqual(self.dict[attr].parent, self.dict['Test-Tlv'])
else:
self.assertEqual(self.dict[attr].is_sub_attribute, False)
self.assertEqual(self.dict[attr].parent, None)
# tlv with vendor
full_dict = Dictionary(os.path.join(self.path, 'full'))
self.assertEqual(full_dict['Simplon-Tlv-Str'].is_sub_attribute, True)
self.assertEqual(full_dict['Simplon-Tlv-Str'].parent, full_dict['Simplon-Tlv'])
self.assertEqual(full_dict['Simplon-Tlv-Int'].is_sub_attribute, True)
self.assertEqual(full_dict['Simplon-Tlv-Int'].parent, full_dict['Simplon-Tlv'])
def testVenderTooFewColumnsError(self):
try:
self.dict.read_dictionary(StringIO('VENDOR Simplon'))
except ParseError as e:
self.assertEqual('vendor' in str(e), True)
else:
self.fail()
def testVendorParsing(self):
self.assertRaises(ParseError, self.dict.read_dictionary,
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
self.dict.read_dictionary(StringIO('VENDOR Simplon 42'))
self.assertEqual(self.dict.vendors['Simplon'], 42)
self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1))
def testVendorOptionError(self):
self.assertRaises(ParseError, self.dict.read_dictionary,
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
try:
self.dict.read_dictionary(StringIO('VENDOR Simplon 42 badoption'))
except ParseError as e:
self.assertEqual('option' in str(e), True)
else:
self.fail()
def testVendorFormatError(self):
self.assertRaises(ParseError, self.dict.read_dictionary,
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
try:
self.dict.read_dictionary(StringIO('VENDOR Simplon 42 format=5,4'))
except ParseError as e:
self.assertEqual('format' in str(e), True)
else:
self.fail()
def testVendorFormatSyntaxError(self):
self.assertRaises(ParseError, self.dict.read_dictionary,
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
try:
self.dict.read_dictionary(StringIO('VENDOR Simplon 42 format=a,1'))
except ParseError as e:
self.assertEqual('Syntax' in str(e), True)
else:
self.fail()
def testBeginVendorTooFewColumns(self):
try:
self.dict.read_dictionary(StringIO('BEGIN-VENDOR'))
except ParseError as e:
self.assertEqual('begin-vendor' in str(e), True)
else:
self.fail()
def testBeginVendorUnknownVendor(self):
try:
self.dict.read_dictionary(StringIO('BEGIN-VENDOR Simplon'))
except ParseError as e:
self.assertEqual('Simplon' in str(e), True)
else:
self.fail()
def testBeginVendorParsing(self):
self.dict.read_dictionary(StringIO(
'VENDOR Simplon 42\n'
'BEGIN-VENDOR Simplon\n'
'ATTRIBUTE Test-Type 1 integer'))
self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1))
def testEndVendorUnknownVendor(self):
try:
self.dict.read_dictionary(StringIO('END-VENDOR'))
except ParseError as e:
self.assertEqual('end-vendor' in str(e), True)
else:
self.fail()
def testEndVendorUnbalanced(self):
try:
self.dict.read_dictionary(StringIO(
'VENDOR Simplon 42\n'
'BEGIN-VENDOR Simplon\n'
'END-VENDOR Oops\n'))
except ParseError as e:
self.assertEqual('Oops' in str(e), True)
else:
self.fail()
def testEndVendorParsing(self):
self.dict.read_dictionary(StringIO(
'VENDOR Simplon 42\n'
'BEGIN-VENDOR Simplon\n'
'END-VENDOR Simplon\n'
'ATTRIBUTE Test-Type 1 integer'))
self.assertEqual(self.dict.attrindex['Test-Type'], 1)
def testInclude(self):
try:
self.dict.read_dictionary(StringIO(
'$INCLUDE this_file_does_not_exist\n'
'VENDOR Simplon 42\n'
'BEGIN-VENDOR Simplon\n'
'END-VENDOR Simplon\n'
'ATTRIBUTE Test-Type 1 integer'))
except IOError as e:
self.assertEqual('this_file_does_not_exist' in str(e), True)
else:
self.fail()
def testDictFilePostParse(self):
f = DictFile(StringIO(
'VENDOR Simplon 42\n'))
for _ in f:
pass
self.assertEqual(f.file(), '')
self.assertEqual(f.line(), -1)
def testDictFileParseError(self):
tmpdict = Dictionary()
try:
tmpdict.read_dictionary(os.path.join(self.path, 'dictfiletest'))
except ParseError as e:
self.assertEqual('dictfiletest' in str(e), True)
else:
self.fail()

View File

@@ -1,87 +0,0 @@
import unittest
from pyrad.host import Host
from pyrad.packet import Packet
from pyrad.packet import AuthPacket
from pyrad.packet import AcctPacket
class ConstructionTests(unittest.TestCase):
def testSimpleConstruction(self):
host = Host()
self.assertEqual(host.authport, 1812)
self.assertEqual(host.acctport, 1813)
def testParameterOrder(self):
host = Host(123, 456, 789, 101)
self.assertEqual(host.authport, 123)
self.assertEqual(host.acctport, 456)
self.assertEqual(host.coaport, 789)
self.assertEqual(host.dict, 101)
def testNamedParameters(self):
host = Host(authport=123, acctport=456, coaport=789, dict=101)
self.assertEqual(host.authport, 123)
self.assertEqual(host.acctport, 456)
self.assertEqual(host.coaport, 789)
self.assertEqual(host.dict, 101)
class PacketCreationTests(unittest.TestCase):
def setUp(self):
self.host = Host()
def testCreatePacket(self):
packet = self.host.CreatePacket(id=15)
self.failUnless(isinstance(packet, Packet))
self.failUnless(packet.dict is self.host.dict)
self.assertEqual(packet.id, 15)
def testCreateAuthPacket(self):
packet = self.host.CreateAuthPacket(id=15)
self.failUnless(isinstance(packet, AuthPacket))
self.failUnless(packet.dict is self.host.dict)
self.assertEqual(packet.id, 15)
def testCreateAcctPacket(self):
packet = self.host.CreateAcctPacket(id=15)
self.failUnless(isinstance(packet, AcctPacket))
self.failUnless(packet.dict is self.host.dict)
self.assertEqual(packet.id, 15)
class MockPacket:
packet = object()
replypacket = object()
source = object()
def Packet(self):
return self.packet
def ReplyPacket(self):
return self.replypacket
class MockFd:
data = None
target = None
def sendto(self, data, target):
self.data = data
self.target = target
class PacketSendTest(unittest.TestCase):
def setUp(self):
self.host = Host()
self.fd = MockFd()
self.packet = MockPacket()
def testSendPacket(self):
self.host.SendPacket(self.fd, self.packet)
self.failUnless(self.fd.data is self.packet.packet)
self.failUnless(self.fd.target is self.packet.source)
def testSendReplyPacket(self):
self.host.SendReplyPacket(self.fd, self.packet)
self.failUnless(self.fd.data is self.packet.replypacket)
self.failUnless(self.fd.target is self.packet.source)

View File

@@ -1,535 +0,0 @@
import os
import unittest
from collections import OrderedDict
from pyrad import packet
from pyrad.client import Client
from pyrad.tests import home
from pyrad.dictionary import Dictionary
import hashlib
md5_constructor = hashlib.md5
class UtilityTests(unittest.TestCase):
def testGenerateID(self):
id = packet.CreateID()
self.assertTrue(isinstance(id, int))
newid = packet.CreateID()
self.assertNotEqual(id, newid)
class PacketConstructionTests(unittest.TestCase):
klass = packet.Packet
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'simple'))
def testBasicConstructor(self):
pkt = self.klass()
self.assertTrue(isinstance(pkt.code, int))
self.assertTrue(isinstance(pkt.id, int))
self.assertTrue(isinstance(pkt.secret, bytes))
def testNamedConstructor(self):
pkt = self.klass(code=26, id=38, secret=b'secret',
authenticator=b'authenticator',
dict='fakedict')
self.assertEqual(pkt.code, 26)
self.assertEqual(pkt.id, 38)
self.assertEqual(pkt.secret, b'secret')
self.assertEqual(pkt.authenticator, b'authenticator')
self.assertEqual(pkt.dict, 'fakedict')
def testConstructWithDictionary(self):
pkt = self.klass(dict=self.dict)
self.assertTrue(pkt.dict is self.dict)
def testConstructorIgnoredParameters(self):
marker = []
pkt = self.klass(fd=marker)
self.assertFalse(getattr(pkt, 'fd', None) is marker)
def testSecretMustBeBytestring(self):
self.assertRaises(TypeError, self.klass, secret='secret')
def testConstructorWithAttributes(self):
pkt = self.klass(**{'Test-String': 'this works', 'dict': self.dict})
self.assertEqual(pkt['Test-String'], ['this works'])
def testConstructorWithTlvAttribute(self):
pkt = self.klass(**{
'Test-Tlv-Str': 'this works',
'Test-Tlv-Int': 10,
'dict': self.dict
})
self.assertEqual(
pkt['Test-Tlv'],
{'Test-Tlv-Str': ['this works'], 'Test-Tlv-Int': [10]}
)
class PacketTests(unittest.TestCase):
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'full'))
self.packet = packet.Packet(
id=0, secret=b'secret',
authenticator=b'01234567890ABCDEF', dict=self.dict)
def testCreateReply(self):
reply = self.packet.CreateReply(**{'Test-Integer': 10})
self.assertEqual(reply.id, self.packet.id)
self.assertEqual(reply.secret, self.packet.secret)
self.assertEqual(reply.authenticator, self.packet.authenticator)
self.assertEqual(reply['Test-Integer'], [10])
def testAttributeAccess(self):
self.packet['Test-Integer'] = 10
self.assertEqual(self.packet['Test-Integer'], [10])
self.assertEqual(self.packet[3], [b'\x00\x00\x00\x0a'])
self.packet['Test-String'] = 'dummy'
self.assertEqual(self.packet['Test-String'], ['dummy'])
self.assertEqual(self.packet[1], [b'dummy'])
def testAttributeValueAccess(self):
self.packet['Test-Integer'] = 'Three'
self.assertEqual(self.packet['Test-Integer'], ['Three'])
self.assertEqual(self.packet[3], [b'\x00\x00\x00\x03'])
def testVendorAttributeAccess(self):
self.packet['Simplon-Number'] = 10
self.assertEqual(self.packet['Simplon-Number'], [10])
self.assertEqual(self.packet[(16, 1)], [b'\x00\x00\x00\x0a'])
self.packet['Simplon-Number'] = 'Four'
self.assertEqual(self.packet['Simplon-Number'], ['Four'])
self.assertEqual(self.packet[(16, 1)], [b'\x00\x00\x00\x04'])
def testRawAttributeAccess(self):
marker = [b'']
self.packet[1] = marker
self.assertTrue(self.packet[1] is marker)
self.packet[(16, 1)] = marker
self.assertTrue(self.packet[(16, 1)] is marker)
def testHasKey(self):
self.assertEqual('Test-String' in self.packet, False)
self.packet['Test-String'] = 'dummy'
self.assertEqual('Test-String' in self.packet, True)
self.assertEqual(1 in self.packet, True)
def testHasKeyWithUnknownKey(self):
self.assertEqual('Unknown-Attribute' in self.packet, False)
def testDelItem(self):
self.packet['Test-String'] = 'dummy'
del self.packet['Test-String']
self.assertEqual('Test-String' in self.packet, False)
self.packet['Test-String'] = 'dummy'
del self.packet[1]
self.assertEqual('Test-String' in self.packet, False)
def testKeys(self):
self.assertEqual(self.packet.keys(), [])
self.packet['Test-String'] = 'dummy'
self.assertEqual(self.packet.keys(), ['Test-String'])
self.packet['Test-Integer'] = 10
self.assertEqual(self.packet.keys(), ['Test-String', 'Test-Integer'])
OrderedDict.__setitem__(self.packet, 12345, None)
self.assertEqual(self.packet.keys(),
['Test-String', 'Test-Integer', 12345])
def testCreateAuthenticator(self):
a = packet.Packet.CreateAuthenticator()
self.assertTrue(isinstance(a, bytes))
self.assertEqual(len(a), 16)
b = packet.Packet.CreateAuthenticator()
self.assertNotEqual(a, b)
def testGenerateID(self):
id = self.packet.CreateID()
self.assertTrue(isinstance(id, int))
newid = self.packet.CreateID()
self.assertNotEqual(id, newid)
def testReplyPacket(self):
reply = self.packet.ReplyPacket()
self.assertEqual(
reply,
(b'\x00\x00\x00\x14\xb0\x5e\x4b\xfb\xcc\x1c'
b'\x8c\x8e\xc4\x72\xac\xea\x87\x45\x63\xa7'))
def testVerifyReply(self):
reply = self.packet.CreateReply()
reply.id += 1
with self.assertRaises(packet.PacketError):
self.packet.VerifyReply(reply.ReplyPacket())
reply.id = self.packet.id
reply.secret = b'different'
with self.assertRaises(packet.PacketError):
self.packet.VerifyReply(reply.ReplyPacket())
reply.secret = self.packet.secret
reply.authenticator = b'X' * 16
with self.assertRaises(packet.PacketError):
self.packet.VerifyReply(reply.ReplyPacket())
def testPktEncodeAttribute(self):
encode = self.packet._pkt_encode_attribute
# Encode a normal attribute
self.assertEqual(
encode(1, b'value'),
b'\x01\x07value')
# Encode a vendor attribute
self.assertEqual(
encode((1, 2), b'value'),
b'\x1a\x0d\x00\x00\x00\x01\x02\x07value')
def testPktEncodeTlvAttribute(self):
encode = self.packet._pkt_encode_tlv
# Encode a normal tlv attribute
self.assertEqual(
encode(4, {1: [b'value'], 2: [b'\x00\x00\x00\x02']}),
b'\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02')
# Encode a normal tlv attribute with several sub attribute instances
self.assertEqual(
encode(4, {1: [b'value', b'other'], 2: [b'\x00\x00\x00\x02']}),
b'\x04\x16\x01\x07value\x02\x06\x00\x00\x00\x02\x01\x07other')
# Encode a vendor tlv attribute
self.assertEqual(
encode((16, 3), {1: [b'value'], 2: [b'\x00\x00\x00\x02']}),
b'\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02')
def testPktEncodeLongTlvAttribute(self):
encode = self.packet._pkt_encode_tlv
long_str = b'a' * 245
# Encode a long tlv attribute - check it is split between AVPs
self.assertEqual(
encode(4, {1: [b'value', long_str], 2: [b'\x00\x00\x00\x02']}),
b'\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02\x04\xf9\x01\xf7' + long_str)
# Encode a long vendor tlv attribute
first_avp = b'\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02'
second_avp = b'\x1a\xff\x00\x00\x00\x10\x03\xf9\x01\xf7' + long_str
self.assertEqual(
encode((16, 3), {1: [b'value', long_str], 2: [b'\x00\x00\x00\x02']}),
first_avp + second_avp)
def testpkt_encode_attributes(self):
self.packet[1] = [b'value']
self.assertEqual(self.packet._pkt_encode_attributes(),
b'\x01\x07value')
self.packet.clear()
self.packet[(16, 2)] = [b'value']
self.assertEqual(self.packet._pkt_encode_attributes(),
b'\x1a\x0d\x00\x00\x00\x10\x02\x07value')
self.packet.clear()
self.packet[1] = [b'one', b'two', b'three']
self.assertEqual(self.packet._pkt_encode_attributes(),
b'\x01\x05one\x01\x05two\x01\x07three')
self.packet.clear()
self.packet[1] = [b'value']
self.packet[(16, 2)] = [b'value']
self.assertEqual(self.packet._pkt_encode_attributes(),
b'\x01\x07value\x1a\x0d\x00\x00\x00\x10\x02\x07value')
def testPktDecodeVendorAttribute(self):
decode = self.packet._pkt_decode_vendor_attribute
# Non-RFC2865 recommended form
self.assertEqual(decode(b''), [(26, b'')])
self.assertEqual(decode(b'12345'), [(26, b'12345')])
# Almost RFC2865 recommended form: bad length value
self.assertEqual(decode(b'\x00\x00\x00\x01\x02\x06value'),
[(26, b'\x00\x00\x00\x01\x02\x06value')])
# Proper RFC2865 recommended form
self.assertEqual(decode(b'\x00\x00\x00\x10\x02\x07value'),
[((16, 2), b'value')])
def testPktDecodeTlvAttribute(self):
decode = self.packet._pkt_decode_tlv_attribute
decode(4, b'\x01\x07value')
self.assertEqual(self.packet[4], {1: [b'value']})
# add another instance of the same sub attribute
decode(4, b'\x01\x07other')
self.assertEqual(self.packet[4], {1: [b'value', b'other']})
# add a different sub attribute
decode(4, b'\x02\x07\x00\x00\x00\x01')
self.assertEqual(self.packet[4], {
1: [b'value', b'other'],
2: [b'\x00\x00\x00\x01']
})
def testDecodePacketWithEmptyPacket(self):
try:
self.packet.DecodePacket(b'')
except packet.PacketError as e:
self.assertTrue('header is corrupt' in str(e))
else:
self.fail()
def testDecodePacketWithInvalidLength(self):
try:
self.packet.DecodePacket(b'\x00\x00\x00\x001234567890123456')
except packet.PacketError as e:
self.assertTrue('invalid length' in str(e))
else:
self.fail()
def testDecodePacketWithTooBigPacket(self):
try:
self.packet.DecodePacket(b'\x00\x00\x24\x00' + (0x2400 - 4) * b'X')
except packet.PacketError as e:
self.assertTrue('too long' in str(e))
else:
self.fail()
def testDecodePacketWithPartialAttributes(self):
try:
self.packet.DecodePacket(
b'\x01\x02\x00\x151234567890123456\x00')
except packet.PacketError as e:
self.assertTrue('header is corrupt' in str(e))
else:
self.fail()
def testDecodePacketWithoutAttributes(self):
self.packet.DecodePacket(b'\x01\x02\x00\x141234567890123456')
self.assertEqual(self.packet.code, 1)
self.assertEqual(self.packet.id, 2)
self.assertEqual(self.packet.authenticator, b'1234567890123456')
self.assertEqual(self.packet.keys(), [])
def testDecodePacketWithBadAttribute(self):
try:
self.packet.DecodePacket(
b'\x01\x02\x00\x161234567890123456\x00\x01')
except packet.PacketError as e:
self.assertTrue('too small' in str(e))
else:
self.fail()
def testDecodePacketWithEmptyAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x161234567890123456\x01\x02')
self.assertEqual(self.packet[1], [b''])
def testDecodePacketWithAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x1b1234567890123456\x01\x07value')
self.assertEqual(self.packet[1], [b'value'])
def testDecodePacketWithTlvAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x1d1234567890123456\x04\x09\x01\x07value')
self.assertEqual(self.packet[4], {1: [b'value']})
def testDecodePacketWithVendorTlvAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x231234567890123456\x1a\x0f\x00\x00\x00\x10\x03\x09\x01\x07value')
self.assertEqual(self.packet[(16, 3)], {1: [b'value']})
def testDecodePacketWithTlvAttributeWith2SubAttributes(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x231234567890123456\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x09')
self.assertEqual(self.packet[4], {1: [b'value'], 2: [b'\x00\x00\x00\x09']})
def testDecodePacketWithSplitTlvAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x251234567890123456\x04\x09\x01\x07value\x04\x09\x02\x06\x00\x00\x00\x09')
self.assertEqual(self.packet[4], {1: [b'value'], 2: [b'\x00\x00\x00\x09']})
def testDecodePacketWithMultiValuedAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two')
self.assertEqual(self.packet[1], [b'one', b'two'])
def testDecodePacketWithTwoAttributes(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two')
self.assertEqual(self.packet[1], [b'one', b'two'])
def testDecodePacketWithVendorAttribute(self):
self.packet.DecodePacket(
b'\x01\x02\x00\x1b1234567890123456\x1a\x07value')
self.assertEqual(self.packet[26], [b'value'])
def testEncodeKeyValues(self):
self.assertEqual(self.packet._encode_key_values(1, '1234'), (1, '1234'))
def testEncodeKey(self):
self.assertEqual(self.packet._encode_key(1), 1)
def testAddAttribute(self):
self.packet.AddAttribute('Test-String', '1')
self.assertEqual(self.packet['Test-String'], ['1'])
self.packet.AddAttribute('Test-String', '1')
self.assertEqual(self.packet['Test-String'], ['1', '1'])
self.packet.AddAttribute('Test-String', ['2', '3'])
self.assertEqual(self.packet['Test-String'], ['1', '1', '2', '3'])
class AuthPacketConstructionTests(PacketConstructionTests):
klass = packet.AuthPacket
def testConstructorDefaults(self):
pkt = self.klass()
self.assertEqual(pkt.code, packet.AccessRequest)
class AuthPacketTests(unittest.TestCase):
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'full'))
self.packet = packet.AuthPacket(
id=0, secret=b'secret',
authenticator=b'01234567890ABCDEF', dict=self.dict)
def testCreateReply(self):
reply = self.packet.CreateReply(**{'Test-Integer': 10})
self.assertEqual(reply.code, packet.AccessAccept)
self.assertEqual(reply.id, self.packet.id)
self.assertEqual(reply.secret, self.packet.secret)
self.assertEqual(reply.authenticator, self.packet.authenticator)
self.assertEqual(reply['Test-Integer'], [10])
def testRequestPacket(self):
self.assertEqual(self.packet.RequestPacket(),
b'\x01\x00\x00\x1401234567890ABCDE')
def testRequestPacketCreatesAuthenticator(self):
self.packet.authenticator = None
self.packet.RequestPacket()
self.assertTrue(self.packet.authenticator is not None)
def testRequestPacketCreatesID(self):
self.packet.id = None
self.packet.RequestPacket()
self.assertTrue(self.packet.id is not None)
def testPwCryptEmptyPassword(self):
self.assertEqual(self.packet.PwCrypt(''), b'')
def testPwCryptPassword(self):
self.assertEqual(self.packet.PwCrypt('Simplon'),
b'\xd3U;\xb23\r\x11\xba\x07\xe3\xa8*\xa8x\x14\x01')
def testPwCryptSetsAuthenticator(self):
self.packet.authenticator = None
self.packet.PwCrypt('')
self.assertTrue(self.packet.authenticator is not None)
def testPwDecryptEmptyPassword(self):
self.assertEqual(self.packet.PwDecrypt(b''), '')
def testPwDecryptPassword(self):
self.assertEqual(
self.packet.PwDecrypt(b'\xd3U;\xb23\r\x11\xba\x07\xe3\xa8*\xa8x\x14\x01'),
'Simplon')
class AuthPacketChapTests(unittest.TestCase):
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'chap'))
self.client = Client(server='localhost', secret=b'secret',
dict=self.dict)
def testVerifyChapPasswd(self):
chap_id = b'9'
chap_challenge = b'987654321'
chap_password = chap_id + md5_constructor(
chap_id + b'test_password' + chap_challenge).digest()
pkt = self.client.CreateAuthPacket(
code=packet.AccessChallenge,
authenticator=b'ABCDEFG',
User_Name='test_name',
CHAP_Challenge=chap_challenge,
CHAP_Password=chap_password
)
self.assertEqual(pkt['CHAP-Challenge'][0], chap_challenge)
self.assertEqual(pkt['CHAP-Password'][0], chap_password)
self.assertEqual(pkt.VerifyChapPasswd('test_password'), True)
class AuthPacketSaltTests(unittest.TestCase):
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'tunnelPassword'))
self.packet = packet.Packet(id=0, secret=b'secret',
dict=self.dict)
def testSaltCrypt(self):
self.packet['Tunnel-Password:1'] = 'test'
# TODO: need to get a correct reference values
# self.assertEqual(self.packet['Tunnel-Password'], b'')
class AcctPacketConstructionTests(PacketConstructionTests):
klass = packet.AcctPacket
def testConstructorDefaults(self):
pkt = self.klass()
self.assertEqual(pkt.code, packet.AccountingRequest)
def testConstructorRawPacket(self):
raw = (b'\x00\x00\x00\x14\xb0\x5e\x4b\xfb\xcc\x1c'
b'\x8c\x8e\xc4\x72\xac\xea\x87\x45\x63\xa7')
pkt = self.klass(packet=raw)
self.assertEqual(pkt.raw_packet, raw)
class AcctPacketTests(unittest.TestCase):
def setUp(self):
self.path = os.path.join(home, 'tests', 'data')
self.dict = Dictionary(os.path.join(self.path, 'full'))
self.packet = packet.AcctPacket(
id=0, secret=b'secret',
authenticator=b'01234567890ABCDEF', dict=self.dict)
def testCreateReply(self):
reply = self.packet.CreateReply(**{'Test-Integer': 10})
self.assertEqual(reply.code, packet.AccountingResponse)
self.assertEqual(reply.id, self.packet.id)
self.assertEqual(reply.secret, self.packet.secret)
self.assertEqual(reply.authenticator, self.packet.authenticator)
self.assertEqual(reply['Test-Integer'], [10])
def testVerifyAcctRequest(self):
rawpacket = self.packet.RequestPacket()
pkt = packet.AcctPacket(secret=b'secret', packet=rawpacket)
self.assertEqual(pkt.VerifyAcctRequest(), True)
pkt.secret = b'different'
self.assertEqual(pkt.VerifyAcctRequest(), False)
pkt.secret = b'secret'
pkt.raw_packet = b'X' + pkt.raw_packet[1:]
self.assertEqual(pkt.VerifyAcctRequest(), False)
def testRequestPacket(self):
self.assertEqual(
self.packet.RequestPacket(),
b'\x04\x00\x00\x14\x95\xdf\x90\xccbn\xfb\x15G!\x13\xea\xfa>6\x0f')
def testRequestPacketSetsId(self):
self.packet.id = None
self.packet.RequestPacket()
self.assertTrue(self.packet.id is not None)

View File

@@ -1,99 +0,0 @@
import select
import socket
import unittest
from pyrad.proxy import Proxy
from pyrad.packet import AccessAccept
from pyrad.packet import AccessRequest
from pyrad.server import ServerPacketError
from pyrad.server import Server
from pyrad.tests.mock import MockFd
from pyrad.tests.mock import MockPoll
from pyrad.tests.mock import MockSocket
from pyrad.tests.mock import MockClassMethod
from pyrad.tests.mock import UnmockClassMethods
class TrivialObject:
"""dummy object"""
class SocketTests(unittest.TestCase):
def setUp(self):
self.orgsocket = socket.socket
socket.socket = MockSocket
self.proxy = Proxy()
self.proxy._fdmap = {}
def tearDown(self):
socket.socket = self.orgsocket
def testProxyFd(self):
self.proxy._poll = MockPoll()
self.proxy._prepare_sockets()
self.failUnless(isinstance(self.proxy._proxyfd, MockSocket))
self.assertEqual(list(self.proxy._fdmap.keys()), [1])
self.assertEqual(
self.proxy._poll.registry,
{1: select.POLLIN | select.POLLPRI | select.POLLERR})
class ProxyPacketHandlingTests(unittest.TestCase):
def setUp(self):
self.proxy = Proxy()
self.proxy.hosts['host'] = TrivialObject()
self.proxy.hosts['host'].secret = 'supersecret'
self.packet = TrivialObject()
self.packet.code = AccessAccept
self.packet.source = ('host', 'port')
def testHandleProxyPacketUnknownHost(self):
self.packet.source = ('stranger', 'port')
try:
self.proxy._handle_proxy_packet(self.packet)
except ServerPacketError as e:
self.failUnless('unknown host' in str(e))
else:
self.fail()
def testHandleProxyPacketSetsSecret(self):
self.proxy._handle_proxy_packet(self.packet)
self.assertEqual(self.packet.secret, 'supersecret')
def testHandleProxyPacketHandlesWrongPacket(self):
self.packet.code = AccessRequest
try:
self.proxy._handle_proxy_packet(self.packet)
except ServerPacketError as e:
self.failUnless('non-response' in str(e))
else:
self.fail()
class OtherTests(unittest.TestCase):
def setUp(self):
self.proxy = Proxy()
self.proxy._proxyfd = MockFd()
def tearDown(self):
UnmockClassMethods(Proxy)
UnmockClassMethods(Server)
def testProcessInputNonProxyPort(self):
fd = MockFd(fd=111)
MockClassMethod(Server, '_process_input')
self.proxy._process_input(fd)
self.assertEqual(
self.proxy.called,
[('_process_input', (fd,), {})])
def testProcessInput(self):
MockClassMethod(Proxy, '_grab_packet')
MockClassMethod(Proxy, '_handle_proxy_packet')
self.proxy._process_input(self.proxy._proxyfd)
self.assertEqual(
[x[0] for x in self.proxy.called],
['_grab_packet', '_handle_proxy_packet'])
if not hasattr(select, 'poll'):
del SocketTests

View File

@@ -1,329 +0,0 @@
import select
import socket
import unittest
from pyrad.packet import PacketError
from pyrad.server import RemoteHost
from pyrad.server import Server
from pyrad.server import ServerPacketError
from pyrad.tests.mock import MockFinished
from pyrad.tests.mock import MockFd
from pyrad.tests.mock import MockPoll
from pyrad.tests.mock import MockSocket
from pyrad.tests.mock import MockClassMethod
from pyrad.tests.mock import UnmockClassMethods
from pyrad.packet import AccessRequest
from pyrad.packet import AccountingRequest
class TrivialObject:
"""dummy objec"""
class RemoteHostTests(unittest.TestCase):
def testSimpleConstruction(self):
host = RemoteHost('address', 'secret', 'name', 'authport', 'acctport', 'coaport')
self.assertEqual(host.address, 'address')
self.assertEqual(host.secret, 'secret')
self.assertEqual(host.name, 'name')
self.assertEqual(host.authport, 'authport')
self.assertEqual(host.acctport, 'acctport')
self.assertEqual(host.coaport, 'coaport')
def testNamedConstruction(self):
host = RemoteHost(
address='address', secret='secret', name='name',
authport='authport', acctport='acctport', coaport='coaport')
self.assertEqual(host.address, 'address')
self.assertEqual(host.secret, 'secret')
self.assertEqual(host.name, 'name')
self.assertEqual(host.authport, 'authport')
self.assertEqual(host.acctport, 'acctport')
self.assertEqual(host.coaport, 'coaport')
class ServerConstructiontests(unittest.TestCase):
def testSimpleConstruction(self):
server = Server()
self.assertEqual(server.authfds, [])
self.assertEqual(server.acctfds, [])
self.assertEqual(server.authport, 1812)
self.assertEqual(server.acctport, 1813)
self.assertEqual(server.coaport, 3799)
self.assertEqual(server.hosts, {})
def testParameterOrder(self):
server = Server([], 'authport', 'acctport', 'coaport', 'hosts', 'dict')
self.assertEqual(server.authfds, [])
self.assertEqual(server.acctfds, [])
self.assertEqual(server.authport, 'authport')
self.assertEqual(server.acctport, 'acctport')
self.assertEqual(server.coaport, 'coaport')
self.assertEqual(server.dict, 'dict')
def testBindDuringConstruction(self):
def BindToAddress(self, addr):
self.bound.append(addr)
bta = Server.BindToAddress
Server.BindToAddress = BindToAddress
Server.bound = []
server = Server(['one', 'two', 'three'])
self.assertEqual(server.bound, ['one', 'two', 'three'])
del Server.bound
Server.BindToAddress = bta
class SocketTests(unittest.TestCase):
def setUp(self):
self.orgsocket = socket.socket
socket.socket = MockSocket
self.server = Server()
def tearDown(self):
socket.socket = self.orgsocket
def testBind(self):
self.server.BindToAddress('192.168.13.13')
self.assertEqual(len(self.server.authfds), 1)
self.assertEqual(self.server.authfds[0].address,
('192.168.13.13', 1812))
self.assertEqual(len(self.server.acctfds), 1)
self.assertEqual(self.server.acctfds[0].address,
('192.168.13.13', 1813))
def testBindv6(self):
self.server.BindToAddress('2001:db8:123::1')
self.assertEqual(len(self.server.authfds), 1)
self.assertEqual(self.server.authfds[0].address,
('2001:db8:123::1', 1812))
self.assertEqual(len(self.server.acctfds), 1)
self.assertEqual(self.server.acctfds[0].address,
('2001:db8:123::1', 1813))
def testGrabPacket(self):
def gen(data):
res = TrivialObject()
res.data = data
return res
fd = MockFd()
fd.source = object()
pkt = self.server._grab_packet(gen, fd)
self.failUnless(isinstance(pkt, TrivialObject))
self.failUnless(pkt.fd is fd)
self.failUnless(pkt.source is fd.source)
self.failUnless(pkt.data is fd.data)
def testPrepareSocketNoFds(self):
self.server._poll = MockPoll()
self.server._prepare_sockets()
self.assertEqual(self.server._poll.registry, {})
self.assertEqual(self.server._realauthfds, [])
self.assertEqual(self.server._realacctfds, [])
def testPrepareSocketAuthFds(self):
self.server._poll = MockPoll()
self.server._fdmap = {}
self.server.authfds = [MockFd(12), MockFd(14)]
self.server._prepare_sockets()
self.assertEqual(list(self.server._fdmap.keys()), [12, 14])
self.assertEqual(
self.server._poll.registry,
{12: select.POLLIN | select.POLLPRI | select.POLLERR,
14: select.POLLIN | select.POLLPRI | select.POLLERR})
def testPrepareSocketAcctFds(self):
self.server._poll = MockPoll()
self.server._fdmap = {}
self.server.acctfds = [MockFd(12), MockFd(14)]
self.server._prepare_sockets()
self.assertEqual(list(self.server._fdmap.keys()), [12, 14])
self.assertEqual(
self.server._poll.registry,
{12: select.POLLIN | select.POLLPRI | select.POLLERR,
14: select.POLLIN | select.POLLPRI | select.POLLERR})
class AuthPacketHandlingTests(unittest.TestCase):
def setUp(self):
self.server = Server()
self.server.hosts['host'] = TrivialObject()
self.server.hosts['host'].secret = 'supersecret'
self.packet = TrivialObject()
self.packet.code = AccessRequest
self.packet.source = ('host', 'port')
def testHandleAuthPacketUnknownHost(self):
self.packet.source = ('stranger', 'port')
try:
self.server._handle_auth_packet(self.packet)
except ServerPacketError as e:
self.failUnless('unknown host' in str(e))
else:
self.fail()
def testHandleAuthPacketWrongPort(self):
self.packet.code = AccountingRequest
try:
self.server._handle_auth_packet(self.packet)
except ServerPacketError as e:
self.failUnless('port' in str(e))
else:
self.fail()
def testHandleAuthPacket(self):
def HandleAuthPacket(self, pkt):
self.handled = pkt
hap = Server.HandleAuthPacket
Server.HandleAuthPacket = HandleAuthPacket
self.server._handle_auth_packet(self.packet)
self.failUnless(self.server.handled is self.packet)
Server.HandleAuthPacket = hap
class AcctPacketHandlingTests(unittest.TestCase):
def setUp(self):
self.server = Server()
self.server.hosts['host'] = TrivialObject()
self.server.hosts['host'].secret = 'supersecret'
self.packet = TrivialObject()
self.packet.code = AccountingRequest
self.packet.source = ('host', 'port')
def testHandleAcctPacketUnknownHost(self):
self.packet.source = ('stranger', 'port')
try:
self.server._handle_acct_packet(self.packet)
except ServerPacketError as e:
self.failUnless('unknown host' in str(e))
else:
self.fail()
def testHandleAcctPacketWrongPort(self):
self.packet.code = AccessRequest
try:
self.server._handle_acct_packet(self.packet)
except ServerPacketError as e:
self.failUnless('port' in str(e))
else:
self.fail()
def testHandleAcctPacket(self):
def HandleAcctPacket(self, pkt):
self.handled = pkt
hap = Server.HandleAcctPacket
Server.HandleAcctPacket = HandleAcctPacket
self.server._handle_acct_packet(self.packet)
self.failUnless(self.server.handled is self.packet)
Server.HandleAcctPacket = hap
class OtherTests(unittest.TestCase):
def setUp(self):
self.server = Server()
def tearDown(self):
UnmockClassMethods(Server)
def testCreateReplyPacket(self):
class TrivialPacket:
source = object()
def CreateReply(self, **kw):
reply = TrivialObject()
reply.kw = kw
return reply
reply = self.server.CreateReplyPacket(
TrivialPacket(),
one='one', two='two')
self.failUnless(isinstance(reply, TrivialObject))
self.failUnless(reply.source is TrivialPacket.source)
self.assertEqual(reply.kw, dict(one='one', two='two'))
def testAuthProcessInput(self):
fd = MockFd(1)
self.server._realauthfds = [1]
MockClassMethod(Server, '_grab_packet')
MockClassMethod(Server, '_handle_auth_packet')
self.server._process_input(fd)
self.assertEqual(
[x[0] for x in self.server.called],
['_grab_packet', '_handle_auth_packet'])
self.assertEqual(self.server.called[0][1][1], fd)
def testAcctProcessInput(self):
fd = MockFd(1)
self.server._realauthfds = []
self.server._realacctfds = [1]
MockClassMethod(Server, '_grab_packet')
MockClassMethod(Server, '_handle_acct_packet')
self.server._process_input(fd)
self.assertEqual(
[x[0] for x in self.server.called],
['_grab_packet', '_handle_acct_packet'])
self.assertEqual(self.server.called[0][1][1], fd)
class ServerRunTests(unittest.TestCase):
def setUp(self):
self.server = Server()
self.origpoll = select.poll
select.poll = MockPoll
def tearDown(self):
MockPoll.results = []
select.poll = self.origpoll
UnmockClassMethods(Server)
def testRunInitializes(self):
MockClassMethod(Server, '_prepare_sockets')
self.assertRaises(MockFinished, self.server.Run)
self.assertEqual(self.server.called, [('_prepare_sockets', (), {})])
self.failUnless(isinstance(self.server._fdmap, dict))
self.failUnless(isinstance(self.server._poll, MockPoll))
def testRunIgnoresPollErrors(self):
self.server.authfds = [MockFd()]
MockPoll.results = [(0, select.POLLERR)]
self.assertRaises(MockFinished, self.server.Run)
def testRunIgnoresServerPacketErrors(self):
def RaisePacketError(self, fd):
raise ServerPacketError
MockClassMethod(Server, '_process_input', RaisePacketError)
self.server.authfds = [MockFd()]
MockPoll.results = [(0, select.POLLIN)]
self.assertRaises(MockFinished, self.server.Run)
def testRunIgnoresPacketErrors(self):
def RaisePacketError(self, fd):
raise PacketError
MockClassMethod(Server, '_process_input', RaisePacketError)
self.server.authfds = [MockFd()]
MockPoll.results = [(0, select.POLLIN)]
self.assertRaises(MockFinished, self.server.Run)
def testRunRunsProcessInput(self):
MockClassMethod(Server, '_process_input')
self.server.authfds = fd = [MockFd()]
MockPoll.results = [(0, select.POLLIN)]
self.assertRaises(MockFinished, self.server.Run)
self.assertEqual(self.server.called, [('_process_input', (fd[0],), {})])
if not hasattr(select, 'poll'):
del SocketTests
del ServerRunTests

View File

@@ -1,119 +0,0 @@
from ipaddress import AddressValueError
from pyrad import tools
import unittest
class EncodingTests(unittest.TestCase):
def testStringEncoding(self):
self.assertRaises(ValueError, tools.EncodeString, 'x' * 254)
self.assertEqual(
tools.EncodeString('1234567890'),
b'1234567890')
def testInvalidStringEncodingRaisesTypeError(self):
self.assertRaises(TypeError, tools.EncodeString, 1)
def testAddressEncoding(self):
self.assertRaises(AddressValueError, tools.EncodeAddress, 'TEST123')
self.assertEqual(
tools.EncodeAddress('192.168.0.255'),
b'\xc0\xa8\x00\xff')
def testInvalidAddressEncodingRaisesTypeError(self):
self.assertRaises(TypeError, tools.EncodeAddress, 1)
def testIntegerEncoding(self):
self.assertEqual(tools.EncodeInteger(0x01020304), b'\x01\x02\x03\x04')
def testInteger64Encoding(self):
self.assertEqual(
tools.EncodeInteger64(0xFFFFFFFFFFFFFFFF), b'\xff' * 8
)
def testUnsignedIntegerEncoding(self):
self.assertEqual(tools.EncodeInteger(0xFFFFFFFF), b'\xff\xff\xff\xff')
def testInvalidIntegerEncodingRaisesTypeError(self):
self.assertRaises(TypeError, tools.EncodeInteger, 'ONE')
def testDateEncoding(self):
self.assertEqual(tools.EncodeDate(0x01020304), b'\x01\x02\x03\x04')
def testInvalidDataEncodingRaisesTypeError(self):
self.assertRaises(TypeError, tools.EncodeDate, '1')
def testEncodeAscendBinary(self):
self.assertEqual(
tools.EncodeAscendBinary('family=ipv4 action=discard direction=in dst=10.10.255.254/32'),
b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
def testStringDecoding(self):
self.assertEqual(
tools.DecodeString(b'1234567890'),
'1234567890')
def testAddressDecoding(self):
self.assertEqual(
tools.DecodeAddress(b'\xc0\xa8\x00\xff'),
'192.168.0.255')
def testIntegerDecoding(self):
self.assertEqual(
tools.DecodeInteger(b'\x01\x02\x03\x04'),
0x01020304)
def testInteger64Decoding(self):
self.assertEqual(
tools.DecodeInteger64(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF
)
def testDateDecoding(self):
self.assertEqual(
tools.DecodeDate(b'\x01\x02\x03\x04'),
0x01020304)
def testUnknownTypeEncoding(self):
self.assertRaises(ValueError, tools.EncodeAttr, 'unknown', None)
def testUnknownTypeDecoding(self):
self.assertRaises(ValueError, tools.DecodeAttr, 'unknown', None)
def testEncodeFunction(self):
self.assertEqual(
tools.EncodeAttr('string', 'string'),
b'string')
self.assertEqual(
tools.EncodeAttr('octets', b'string'),
b'string')
self.assertEqual(
tools.EncodeAttr('ipaddr', '192.168.0.255'),
b'\xc0\xa8\x00\xff')
self.assertEqual(
tools.EncodeAttr('integer', 0x01020304),
b'\x01\x02\x03\x04')
self.assertEqual(
tools.EncodeAttr('date', 0x01020304),
b'\x01\x02\x03\x04')
self.assertEqual(
tools.EncodeAttr('integer64', 0xFFFFFFFFFFFFFFFF),
b'\xff'*8)
def testDecodeFunction(self):
self.assertEqual(
tools.DecodeAttr('string', b'string'),
'string')
self.assertEqual(
tools.EncodeAttr('octets', b'string'),
b'string')
self.assertEqual(
tools.DecodeAttr('ipaddr', b'\xc0\xa8\x00\xff'),
'192.168.0.255')
self.assertEqual(
tools.DecodeAttr('integer', b'\x01\x02\x03\x04'),
0x01020304)
self.assertEqual(
tools.DecodeAttr('integer64', b'\xff'*8),
0xFFFFFFFFFFFFFFFF)
self.assertEqual(
tools.DecodeAttr('date', b'\x01\x02\x03\x04'),
0x01020304)

View File

@@ -1,226 +0,0 @@
# tools.py
#
# Utility functions
import binascii
import ipaddress
import struct
def EncodeString(string):
if len(string) > 253:
raise ValueError('Can only encode strings of <= 253 characters')
if isinstance(string, str):
return string.encode('utf-8')
return string
def EncodeOctets(string):
if len(string) > 253:
raise ValueError('Can only encode strings of <= 253 characters')
if string.startswith(b'0x'):
hexstring = string.split(b'0x')[1]
return binascii.unhexlify(hexstring)
else:
return string
def EncodeAddress(addr):
if not isinstance(addr, str):
raise TypeError('Address has to be a string')
return ipaddress.IPv4Address(addr).packed
def EncodeIPv6Prefix(addr):
if not isinstance(addr, str):
raise TypeError('IPv6 Prefix has to be a string')
ip = ipaddress.IPv6Network(addr)
return struct.pack('2B', *[0, ip.prefixlen]) + ip.network_address.packed
def EncodeIPv6Address(addr):
if not isinstance(addr, str):
raise TypeError('IPv6 Address has to be a string')
return ipaddress.IPv6Address(addr).packed
def EncodeAscendBinary(string):
"""
Format: List of type=value pairs sperated by spaces.
Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'
Type:
family ipv4(default) or ipv6
action discard(default) or accept
direction in(default) or out
src source prefix (default ignore)
dst destination prefix (default ignore)
proto protocol number / next-header number (default ignore)
sport source port (default ignore)
dport destination port (default ignore)
sportq source port qualifier (default 0)
dportq destination port qualifier (default 0)
Source/Destination Port Qualifier:
0 no compare
1 less than
2 equal to
3 greater than
4 not equal to
"""
terms = {
'family': b'\x01',
'action': b'\x00',
'direction': b'\x01',
'src': b'\x00\x00\x00\x00',
'dst': b'\x00\x00\x00\x00',
'srcl': b'\x00',
'dstl': b'\x00',
'proto': b'\x00',
'sport': b'\x00\x00',
'dport': b'\x00\x00',
'sportq': b'\x00',
'dportq': b'\x00'
}
for t in string.split(' '):
key, value = t.split('=')
if key == 'family' and value == 'ipv6':
terms[key] = b'\x03'
if terms['src'] == b'\x00\x00\x00\x00':
terms['src'] = 16 * b'\x00'
if terms['dst'] == b'\x00\x00\x00\x00':
terms['dst'] = 16 * b'\x00'
elif key == 'action' and value == 'accept':
terms[key] = b'\x01'
elif key == 'direction' and value == 'out':
terms[key] = b'\x00'
elif key in ('src', 'dst'):
ip = ipaddress.ip_network(value)
terms[key] = ip.network_address.packed
terms[key+'l'] = struct.pack('B', ip.prefixlen)
elif key in ('sport', 'dport'):
terms[key] = struct.pack('!H', int(value))
elif key in ('sportq', 'dportq', 'proto'):
terms[key] = struct.pack('B', int(value))
trailer = 8 * b'\x00'
result = b''.join((
terms['family'], terms['action'], terms['direction'], b'\x00',
terms['src'], terms['dst'], terms['srcl'], terms['dstl'], terms['proto'], b'\x00',
terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], b'\x00\x00', trailer))
return result
def EncodeInteger(num, format='!I'):
try:
num = int(num)
except ValueError:
raise TypeError('Can not encode non-integer as integer')
return struct.pack(format, num)
def EncodeInteger64(num, format='!Q'):
try:
num = int(num)
except ValueError:
raise TypeError('Can not encode non-integer as integer64')
return struct.pack(format, num)
def EncodeDate(num):
if not isinstance(num, int):
raise TypeError('Can not encode non-integer as date')
return struct.pack('!I', num)
def DecodeString(string):
try:
return string.decode('utf-8')
except:
return string
def DecodeOctets(string):
return string
def DecodeAddress(addr):
return '.'.join((str(a) for a in struct.unpack('BBBB', addr)))
def DecodeIPv6Prefix(addr):
addr = addr + b'\x00' * (18-len(addr))
prefix = addr[:2]
addr = addr[2:]
return str(ipaddress.ip_network((prefix, addr)))
def DecodeIPv6Address(addr):
addr = addr + b'\x00' * (16-len(addr))
return str(ipaddress.IPv6Address(addr))
def DecodeAscendBinary(string):
return string
def DecodeInteger(num, format='!I'):
return (struct.unpack(format, num))[0]
def DecodeInteger64(num, format='!Q'):
return (struct.unpack(format, num))[0]
def DecodeDate(num):
return (struct.unpack('!I', num))[0]
ENCODE_MAP = {
'string': EncodeString,
'octets': EncodeOctets,
'integer': EncodeInteger,
'ipaddr': EncodeAddress,
'ipv6prefix': EncodeIPv6Prefix,
'ipv6addr': EncodeIPv6Address,
'abinary': EncodeAscendBinary,
'signed': lambda value: EncodeInteger(value, '!i'),
'short': lambda value: EncodeInteger(value, '!H'),
'byte': lambda value: EncodeInteger(value, '!B'),
'date': EncodeDate,
'integer64': EncodeInteger64,
}
def EncodeAttr(datatype, value):
try:
return ENCODE_MAP[datatype](value)
except KeyError:
raise ValueError(f'Unknown attribute type {datatype}')
DECODE_MAP = {
'string': DecodeString,
'octets': DecodeOctets,
'integer': DecodeInteger,
'ipaddr': DecodeAddress,
'ipv6prefix': DecodeIPv6Prefix,
'ipv6addr': DecodeIPv6Address,
'abinary': DecodeAscendBinary,
'signed': lambda value: DecodeInteger(value, '!i'),
'short': lambda value: DecodeInteger(value, '!H'),
'byte': lambda value: DecodeInteger(value, '!B'),
'date': DecodeDate,
'integer64': DecodeInteger64,
}
def DecodeAttr(datatype, value):
try:
return DECODE_MAP[datatype](value)
except KeyError:
raise ValueError(f'Unknown attribute type {datatype}')

View File

@@ -1,57 +0,0 @@
from pyrad.packet import Packet, random_generator
def salt_encrypt(packet: Packet, value: bytes) -> bytes:
length = struct.pack('B', len(value))
buf = length + value
buf += b'\x00' * (16 - (len(buf) % 16))
# First bit if the random value must be 1
random_value = 32768 + random_generator.randrange(0, 32767)
result = struct.pack('!H', random_value)
last = packet.authenticator + result
while buf:
cur_hash = md5_constructor(packet.secret + last).digest()
for b, h in zip(buf, cur_hash):
result += bytes([b ^ h])
last = result[-16:]
buf = buf[16:]
return result
def validate_chap_password(packet: Packet, password: bytes) -> bool:
# TODO:
challange = packet.get('CHAP-Challenge', packet.authenticator)
return chap_password == hashlib.md5(chapid + password + challenge).digest()
def validate_pap_password(packet: Packet, password: bytes) -> bool:
obf_pass = password_encode(packet, password)
return packet['Password'] == obf_pass
def password_encode(packet: Packet, password: bytes) -> bytes:
password += b'\x00' * (16 - (len(password) % 16))
return obfuscation_algorithm(packet, password)
def password_decode(packet: Packet, password: bytes) -> bytes:
decoded = obfuscation_algorithm(packet, password)
return decoded.rstrip(b'\x00')
def obfuscation_algorithm(packet: Packet, password: bytes) -> bytes:
result = b''
buf = password
last = packet.authenticator
while buf:
cur_hash = md5_constructor(packet.secret + last)
for b, h in zip(buf, cur_hash):
result += bytes([b ^ h])
(last, buf) = (buf[:16], buf[16:])
return result

View File

@@ -1,5 +1,20 @@
let
pkgs = import <nixpkgs> {};
python = pkgs.python36;
in
import ./default.nix { inherit pkgs python; }
pkgs.mkShell {
buildInputs = with pkgs; [
git
nixfmt
pkgs.python3
pkgs.python3Packages.black
pkgs.python3Packages.pytest
pkgs.python3Packages.pytest-black
pkgs.python3Packages.pytest-flake8
pkgs.python3Packages.pytest-mypy
pkgs.python3Packages.pytest-pylint
];
shellHook = ''
export PYTHONPATH=${./src}:$PYTHONPATH
'';
}

View File

@@ -38,9 +38,9 @@ This package contains four modules:
__docformat__ = 'epytext en'
__author__ = 'Christian Giese <developer@gicnet.de>'
__author__ = 'Istvan Ruzman <istvan@ruzman.eu>'
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
__copyright__ = 'Copyright 2002-2020 Wichert Akkerman and Christian Giese. All rights reserved.'
__version__ = '2.3'
__copyright__ = 'Copyright 2020 Istvan Ruzman'
__version__ = '0.1.0'
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'dictfile', 'new_client', 'new_host', 'new_packet']
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils']

View File

@@ -3,7 +3,7 @@
# Bidirectional map
class BiDict():
class BiDict:
def __init__(self):
self.forward = {}
self.backward = {}

View File

@@ -1,63 +1,94 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Implementation of a simple but extensible RADIUS Client"""
from typing import Optional, Union
from ipaddress import IPv4Address, IPv6Address
import select
import socket
import time
from pyrad3 import new_packet as P
from pyrad3 import new_host
import pyrad3.packet as P
from pyrad3 import host
SUPPORTED_SEND_TYPES = [
P.AccessRequest,
P.AccountingRequest,
P.CoARequest,
P.Code.AccessRequest,
P.Code.AccountingRequest,
P.Code.CoARequest,
]
PACKET_TYPE_PORT_MAPPING = {
P.AccessRequest: 'authport',
P.AccountingRequest: 'acctport',
P.CoARequest: 'coaport',
P.Code.AccessRequest: "authport",
P.Code.AccountingRequest: "acctport",
P.Code.CoARequest: "coaport",
}
class Timeout(Exception):
pass
"""Exception for wait timeouts"""
class UnsupportedPacketType(Exception):
pass
"""Exception for received packets"""
class Client(new_host.Host):
def __init__(self, server, secret, radius_dictionary, **kwargs):
class Client(host.Host):
"""A simple and extensible RADIUS Client."""
def __init__(
self,
server: Union[str, IPv4Address, IPv6Address],
secret: bytes,
radius_dictionary: dict,
interface: Optional[str],
**kwargs,
):
super().__init__(secret, radius_dictionary, **kwargs)
self.server = server
self._socket = None
self._poll = None
self.interface = interface
self._socket: Optional[socket.socket] = None
self._poll: Optional[select.poll] = None
def bind(self, addr):
"""Bind the Address to some socket"""
self._socket_close()
self._socket_open()
self._socket.bind(addr)
def _socket_open(self):
"""Open a client socket"""
if self._socket is not None:
return
try:
family = socket.getaddrinfo(self.server, 'www')[0][0]
family = socket.getaddrinfo(self.server, "www")[0][0]
except socket.gaierror:
family = socket.AF_INET
self._socket = socket.socket(family, socket.SOCK_DGRAM)
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if self.interface is not None:
# This will fail on non-Linux systems - that's ok, because we don't
# have any implementation yet for non-Linux systems.
# Better to fail loudly than to do something unexpected silently.
self._socket.setsockopt(
socket.SOL_SOCKET,
socket.SO_BINDTODEVICE,
self.interface,
len(self.interface),
)
self._poll = select.poll()
self._poll.register(self._socket, select.POLLIN)
def _socket_close(self):
"""Close the Client socket"""
if self._socket is not None:
self._poll.unregister(self._socket)
self._socket.close()
self._socket = None
def _select_port(self, packet):
def _select_port(self, packet: P.Packet):
"""Select the RADIUS Port depending on the RADIUS Packet type"""
try:
port_type = PACKET_TYPE_PORT_MAPPING[packet.code]
return getattr(self, port_type)
@@ -65,7 +96,11 @@ class Client(new_host.Host):
pass
raise UnsupportedPacketType(f"The packet type {packet.code} by Client")
def _send_packet(self, packet):
def _send_packet(self, packet: P.Packet):
"""Send a RADIUS Packet and wait for the reply"""
assert self._socket is not None
assert self._poll is not None
port = self._select_port(packet)
raw_packet = packet.serialize()

View File

@@ -13,9 +13,7 @@ from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
class DatagramProtocolClient(asyncio.Protocol):
def __init__(self, server, port, logger,
client, retries=3, timeout=30):
def __init__(self, server, port, logger, client, retries=3, timeout=30):
self.transport = None
self.port = port
self.server = server
@@ -42,22 +40,31 @@ class DatagramProtocolClient(asyncio.Protocol):
# noinspection PyShadowingBuiltins
for id, req in self.pending_requests.items():
secs = (req['send_date'] - now).seconds
secs = (req["send_date"] - now).seconds
if secs > self.timeout:
if req['retries'] == self.retries:
self.logger.debug('[%s:%d] For request %d execute all retries',
self.server, self.port, id)
req['future'].set_exception(
TimeoutError('Timeout on Reply')
if req["retries"] == self.retries:
self.logger.debug(
"[%s:%d] For request %d execute all retries",
self.server,
self.port,
id,
)
req["future"].set_exception(
TimeoutError("Timeout on Reply")
)
req2delete.append(id)
else:
# Send again packet
req['send_date'] = now
req['retries'] += 1
self.logger.debug('[%s:%d] For request %d execute retry %d',
self.server, self.port, id, req['retries'])
self.transport.sendto(req['packet'].RequestPacket())
req["send_date"] = now
req["retries"] += 1
self.logger.debug(
"[%s:%d] For request %d execute retry %d",
self.server,
self.port,
id,
req["retries"],
)
self.transport.sendto(req["packet"].RequestPacket())
elif next_weak_up > secs:
next_weak_up = secs
@@ -72,15 +79,15 @@ class DatagramProtocolClient(asyncio.Protocol):
def send_packet(self, packet, future):
if packet.id in self.pending_requests:
raise Exception(f'Packet with id {packet.id} already present')
raise Exception(f"Packet with id {packet.id} already present")
# Store packet on pending requests map
self.pending_requests[packet.id] = {
'packet': packet,
'creation_date': datetime.now(),
'retries': 0,
'future': future,
'send_date': datetime.now()
"packet": packet,
"creation_date": datetime.now(),
"retries": 0,
"future": future,
"send_date": datetime.now(),
}
# In queue packet raw on socket buffer
@@ -88,10 +95,11 @@ class DatagramProtocolClient(asyncio.Protocol):
def connection_made(self, transport):
self.transport = transport
socket = transport.get_extra_info('socket')
socket = transport.get_extra_info("socket")
self.logger.info(
'[%s:%d] Transport created with binding in %s:%d',
self.server, self.port,
"[%s:%d] Transport created with binding in %s:%d",
self.server,
self.port,
socket.getsockname()[0],
socket.getsockname()[1],
)
@@ -99,38 +107,47 @@ class DatagramProtocolClient(asyncio.Protocol):
pre_loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop=self.client.loop)
# Start asynchronous timer handler
self.timeout_future = asyncio.ensure_future(
self.__timeout_handler__()
)
self.timeout_future = asyncio.ensure_future(self.__timeout_handler__())
asyncio.set_event_loop(loop=pre_loop)
def error_received(self, exc):
self.logger.error('[%s:%d] Error received: %s', self.server, self.port, exc)
self.logger.error(
"[%s:%d] Error received: %s", self.server, self.port, exc
)
def connection_lost(self, exc):
if exc:
self.logger.warn('[%s:%d] Connection lost: %s', self.server, self.port, str(exc))
self.logger.warn(
"[%s:%d] Connection lost: %s", self.server, self.port, str(exc)
)
else:
self.logger.info('[%s:%d] Transport closed', self.server, self.port)
self.logger.info("[%s:%d] Transport closed", self.server, self.port)
# noinspection PyUnusedLocal
def datagram_received(self, data, addr):
try:
req = self.pending_requests[data[0]]
reply = req.VerifyPacket(data)
req['future'].set_result(reply)
req["future"].set_result(reply)
# Remove request for map
del self.pending_requests[reply.id]
except KeyError:
self.logger.warn('[%s:%d] Ignore invalid reply: %s',
self.server, self.port, data)
self.logger.warn(
"[%s:%d] Ignore invalid reply: %s", self.server, self.port, data
)
except PacketError as exc:
self.logger.error('[%s:%d] Error on decode or verify packet: %s',
self.server, self.port, exc)
self.logger.error(
"[%s:%d] Error on decode or verify packet: %s",
self.server,
self.port,
exc,
)
async def close_transport(self):
if self.transport:
self.logger.debug('[%s:%d] Closing transport...', self.server, self.port)
self.logger.debug(
"[%s:%d] Closing transport...", self.server, self.port
)
self.transport.close()
self.transport = None
if self.timeout_future:
@@ -143,7 +160,9 @@ class DatagramProtocolClient(asyncio.Protocol):
return self.packet_id
def __str__(self):
return f'DatagramProtocolClient(server?={self.server}, port={self.port})'
return (
f"DatagramProtocolClient(server?={self.server}, port={self.port})"
)
# Used as protocol_factory
def __call__(self):
@@ -161,11 +180,21 @@ class ClientAsync:
:ivar timeout: number of seconds to wait for an answer
:type timeout: integer
"""
# noinspection PyShadowingBuiltins
def __init__(self, server, auth_port=1812, acct_port=1813,
coa_port=3799, secret=b'', dict=None,
loop=None, retries=3, timeout=30,
logger_name='pyrad'):
def __init__(
self,
server,
auth_port=1812,
acct_port=1813,
coa_port=3799,
secret=b"",
dict=None,
loop=None,
retries=3,
timeout=30,
logger_name="pyrad",
):
"""Constructor.
@@ -205,23 +234,30 @@ class ClientAsync:
self.protocol_coa = None
self.coa_port = coa_port
async def initialize_transports(self, enable_acct=False,
enable_auth=False, enable_coa=False,
local_addr=None, local_auth_port=None,
local_acct_port=None, local_coa_port=None):
async def initialize_transports(
self,
enable_acct=False,
enable_auth=False,
enable_coa=False,
local_addr=None,
local_auth_port=None,
local_acct_port=None,
local_coa_port=None,
):
task_list = []
if not enable_acct and not enable_auth and not enable_coa:
raise Exception('No transports selected')
raise Exception("No transports selected")
if enable_acct and not self.protocol_acct:
self.protocol_acct = DatagramProtocolClient(
self.server,
self.acct_port,
self.logger, self,
self.logger,
self,
retries=self.retries,
timeout=self.timeout
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_acct_port:
@@ -231,7 +267,7 @@ class ClientAsync:
self.protocol_acct,
reuse_port=True,
remote_addr=(self.server, self.acct_port),
local_addr=bind_addr
local_addr=bind_addr,
)
task_list.append(acct_connect)
@@ -239,9 +275,10 @@ class ClientAsync:
self.protocol_auth = DatagramProtocolClient(
self.server,
self.auth_port,
self.logger, self,
self.logger,
self,
retries=self.retries,
timeout=self.timeout
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_auth_port:
@@ -251,7 +288,7 @@ class ClientAsync:
self.protocol_auth,
reuse_port=True,
remote_addr=(self.server, self.auth_port),
local_addr=bind_addr
local_addr=bind_addr,
)
task_list.append(auth_connect)
@@ -259,9 +296,10 @@ class ClientAsync:
self.protocol_coa = DatagramProtocolClient(
self.server,
self.coa_port,
self.logger, self,
self.logger,
self,
retries=self.retries,
timeout=self.timeout
timeout=self.timeout,
)
bind_addr = None
if local_addr and local_coa_port:
@@ -271,22 +309,18 @@ class ClientAsync:
self.protocol_coa,
reuse_port=True,
remote_addr=(self.server, self.coa_port),
local_addr=bind_addr
local_addr=bind_addr,
)
task_list.append(coa_connect)
await asyncio.ensure_future(
asyncio.gather(
*task_list,
return_exceptions=False,
),
loop=self.loop
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
)
# noinspection SpellCheckingInspection
async def deinitialize_transports(self, deinit_coa=True,
deinit_auth=True,
deinit_acct=True):
async def deinitialize_transports(
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
):
if self.protocol_coa and deinit_coa:
await self.protocol_coa.close_transport()
del self.protocol_coa
@@ -312,11 +346,14 @@ class ClientAsync:
:rtype: pyrad.packet.Packet
"""
if not self.protocol_auth:
raise Exception('Transport not initialized')
raise Exception("Transport not initialized")
return AuthPacket(dict=self.dict,
id=self.protocol_auth.create_id(),
secret=self.secret, **args)
return AuthPacket(
dict=self.dict,
id=self.protocol_auth.create_id(),
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
def CreateAcctPacket(self, **args):
@@ -330,11 +367,14 @@ class ClientAsync:
:rtype: pyrad.packet.Packet
"""
if not self.protocol_acct:
raise Exception('Transport not initialized')
raise Exception("Transport not initialized")
return AcctPacket(id=self.protocol_acct.create_id(),
dict=self.dict,
secret=self.secret, **args)
return AcctPacket(
id=self.protocol_acct.create_id(),
dict=self.dict,
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
def CreateCoAPacket(self, **args):
@@ -349,20 +389,22 @@ class ClientAsync:
"""
if not self.protocol_acct:
raise Exception('Transport not initialized')
raise Exception("Transport not initialized")
return CoAPacket(id=self.protocol_coa.create_id(),
dict=self.dict,
secret=self.secret, **args)
return CoAPacket(
id=self.protocol_coa.create_id(),
dict=self.dict,
secret=self.secret,
**args,
)
# noinspection PyPep8Naming
# noinspection PyShadowingBuiltins
def CreatePacket(self, id, **args):
if not id:
raise Exception('Missing mandatory packet id')
raise Exception("Missing mandatory packet id")
return Packet(id=id, dict=self.dict,
secret=self.secret, **args)
return Packet(id=id, dict=self.dict, secret=self.secret, **args)
# noinspection PyPep8Naming
def SendPacket(self, pkt):
@@ -378,23 +420,23 @@ class ClientAsync:
if isinstance(pkt, AuthPacket):
if not self.protocol_auth:
raise Exception('Transport not initialized')
raise Exception("Transport not initialized")
self.protocol_auth.send_packet(pkt, ans)
elif isinstance(pkt, AcctPacket):
if not self.protocol_acct:
raise Exception('Transport not initialized')
raise Exception("Transport not initialized")
self.protocol_acct.send_packet(pkt, ans)
elif isinstance(pkt, CoAPacket):
if not self.protocol_coa:
raise Exception('Transport not initialized')
raise Exception("Transport not initialized")
self.protocol_coa.send_packet(pkt, ans)
else:
raise Exception('Unsupported packet')
raise Exception("Unsupported packet")
return ans

483
src/pyrad3/dictionary.py Normal file
View File

@@ -0,0 +1,483 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""RADIUS Dictionary.
Classes and Types to parse and represent a RADIUS dictionary.
"""
from enum import IntEnum, Enum, auto
from dataclasses import dataclass
from os.path import dirname, isabs, join, normpath
from typing import Dict, Generator, IO, List, Optional, Sequence, Tuple, Union
import logging
LOG = logging.getLogger(__name__)
INTEGER_TYPES = {
"byte": (0, 255),
"short": (0, 2 ** 16 - 1),
"signed": (-(2 ** 31), 2 ** 31 - 1),
"integer": (0, 2 ** 32 - 1),
"integer64": (0, 2 ** 64 - 1),
}
class Datatype(Enum):
"""Possible Datatypes for ATTRIBUTES"""
string = auto()
octets = auto()
date = auto()
abinary = auto()
byte = auto()
short = auto()
integer = auto()
signed = auto()
integer64 = auto()
ipaddr = auto()
ipv4prefix = auto()
ipv6addr = auto()
ipv6prefix = auto()
comboip = auto()
ifid = auto()
ether = auto()
concat = auto()
tlv = auto()
extended = auto()
longextended = auto()
evs = auto()
class ParseError(Exception):
"""RADIUS Dictionary Parser Error"""
def __init__(
self,
filename: str,
msg: str = None,
line: Optional[int] = None,
**data,
):
super().__init__()
self.msg = msg
self.file = filename
self.line = line
self.data = data
def __str__(self):
line = f"({self.line}" if self.line is not None else ""
return f"{self.file}{line}: ParseError: {self.msg}"
class Encrypt(IntEnum):
"""Enum for different RADIUS Encryption types."""
NoEncrpytion = 0
RadiusCrypt = 1
SaltCrypt = 2
AscendCrypt = 3
@dataclass
class Attribute: # pylint: disable=too-many-instance-attributes
"""RADIUS Attribute definition"""
name: str
code: int
datatype: Datatype
has_tag: bool = False
encrypt: Encrypt = Encrypt(0)
is_sub_attr: bool = False
# vendor = Dictionary
values: Dict[Union[int, str], Union[int, str]] = None
@dataclass
class Vendor:
"""Representation of a vendor"""
name: str
code: int
tlength: int
llength: int
continuation: bool
attrs: Dict[Union[int, Tuple[int, ...]], Attribute]
def dict_parser(
filename: str, rad_dict: IO
) -> Generator[Tuple[int, List[str]], None, None]:
"""Tokenstream of RADIUS Dictionary files
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
and "FILE_CLOSED" tokens will be emitted.
"""
yield (-1, ["FILE_OPENED", filename])
for line_num, line in enumerate(rad_dict.readlines()):
tokens = line.split("#", 1)[0].strip().split()
if tokens:
first_tok = tokens[0] = tokens[0].upper()
if first_tok == "$INCLUDE":
try:
inner_filename = tokens[1]
except IndexError:
raise ParseError(
filename, "$INCLUDE is missing a filename", line_num,
)
if not isabs(tokens[1]):
path = dirname(filename)
inner_filename = normpath(join(path, inner_filename))
yield from dict_loader(inner_filename)
yield (line_num, tokens)
yield (-1, ["FILE_CLOSED"])
def dict_loader(filename: str) -> Generator[Tuple[int, List[str]], None, None]:
"""Tokenstream of RADIUS Dictionary files
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
and "FILE_CLOSED" tokens will be emitted.
"""
with open(filename, "r") as rad_dict:
yield from dict_parser(filename, rad_dict)
def _parse_number(num: str) -> int:
"""Parse a number from (Free)RADIUS dictionaries
Numbers can be either decimal, octal, or hexadecimal.
"""
if num.startswith("0x"):
return int(num, 16)
if num.startswith("0o"):
return int(num, 8)
return int(num)
def _parse_attribute_code(attr_code: str) -> List[int]:
"""Parse attribute codes from (Free)RADIUS dictionaries
Codes can be either decimal, octal, or hexadecimal.
TLV typed can
"""
codes = []
for code in attr_code.split("."):
codes.append(_parse_number(code))
return codes
class Dictionary:
"""(Free)RADIUS Dictionary.
#TODO: Better documentation
This dictionary can "contain" multiple dictionaries.
"""
# there must be some nicer way to unittest this...
def __init__(self, dictionary: str, __dictio: Optional[IO] = None):
self.vendor: Dict[int, Vendor] = {}
self.vendor_lookup_id_by_name: Dict[str, int] = {}
self.attrindex: Dict[Union[int, str], Attribute] = {}
self.rfc_vendor = Vendor("RFC", 0, 1, 1, False, {})
self.cur_vendor = self.rfc_vendor
if __dictio is not None:
loader = dict_parser(dictionary, __dictio)
else:
loader = dict_loader(dictionary)
self.read_dictionary(loader)
def read_dictionary(
self, reader: Generator[Tuple[int, List[str]], None, None]
):
"""Read and parse a (Free)RADIUS dictionary."""
self.filestack = []
for line_num, tokens in reader:
key = tokens[0]
if key == "ATTRIBUTE":
self._parse_attribute(tokens, line_num)
elif key == "VALUE":
self._parse_value(tokens, line_num)
elif key == "FILE_OPENED":
LOG.info("Parsing file: %s", tokens[1])
if tokens[1] in self.filestack:
raise ParseError(
self.filestack[-1], "Include recursion detected"
)
self.filestack.append(tokens[1])
elif key == "FILE_CLOSED":
filename = self.filestack.pop()
LOG.info("Finished parsing file: %s", filename)
elif key == "VENDOR":
self._parse_vendor(tokens, line_num)
elif key == "BEGIN-VENDOR":
self._parse_begin_vendor(tokens, line_num)
elif key == "END-VENDOR":
self._parse_end_vendor(tokens, line_num)
elif key == "BEGIN-TLV":
raise NotImplementedError(
"BEGIN-TLV is deprecated and not supported by pyrad3"
)
elif key == "END-TLV":
raise NotImplementedError(
"END-TLV is deprecated and not supported by pyrad3"
)
else:
raise ParseError(
self.filestack[-1], f"Invalid Token key {key}", line_num
)
def _parse_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the vendor definition"""
filename = self.filestack[-1]
if len(tokens) not in {3, 4}:
raise ParseError(
filename, "Incorrect number of tokens for vendor statement"
)
vendor_name = tokens[1]
vendor_id = int(tokens[2], 0)
continuation = False
# Parse optional vendor specification
try:
vendor_format = tokens[3].split("=")
if vendor_format[0] != "format":
raise ParseError(
filename,
f"Unknown option {vendor_format[0]} for vendor definition",
line_num,
)
try:
vendor_format = vendor_format[1].split(",")
t_len, l_len = (int(a) for a in vendor_format[:2])
if t_len not in {1, 2, 4}:
raise ParseError(
filename,
f'Invalid type length definition "{t_len}" for vendor {vendor_name}',
line_num,
)
if l_len not in {0, 1, 2}:
raise ParseError(
filename,
f'Invalid length definition "{l_len}" for vendor {vendor_name}',
line_num,
)
try:
if vendor_format[2] == "c":
if not vendor_name.upper() == "WIMAX":
# Not sure why, but FreeRADIUS has this limit,
# so we just do the same cause they know better than me
raise ParseError(
filename,
"continuation-bit is only supported for WiMAX",
line_num,
)
continuation = True
except IndexError:
pass
except ValueError:
raise ParseError(
filename,
f"Syntax error in specification for vendor {vendor_name}",
line_num,
)
except IndexError:
# no format definition
t_len, l_len = 1, 1
vendor = Vendor(vendor_name, vendor_id, t_len, l_len, continuation, {})
self.vendor_lookup_id_by_name[vendor_name] = vendor_id
self.vendor[vendor_id] = vendor
def _parse_begin_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the BEGIN-VENDOR line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if self.cur_vendor != self.rfc_vendor:
raise ParseError(
filename,
"vendor-begin sections are not allowed to be nested",
line_num,
)
if len(tokens) != 2:
raise ParseError(
filename,
"Incorrect number of tokens for begin-vendor statement",
line_num,
)
try:
vendor_id = self.vendor_lookup_id_by_name[tokens[1]]
self.cur_vendor = self.vendor[vendor_id]
except KeyError:
raise ParseError(
filename,
f"Unknown vendor {tokens[1]} in begin-vendor statement",
line_num,
)
def _parse_end_vendor(self, tokens: Sequence[str], line_num: int):
"""Parse the END-VENDOR line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if len(tokens) != 2:
raise ParseError(
filename,
"Incorrect number of tokens for end-vendor statement",
line_num,
)
if self.cur_vendor.name != tokens[1]:
raise ParseError(
filename,
f"Closing non-opened vendor {tokens[1]} in end-vendor statement",
line_num,
)
self.cur_vendor = self.rfc_vendor
def _parse_attribute_flags(
self, tokens: Sequence[str], line_num: int
) -> Tuple[bool, Encrypt]:
"""Parse Attribute flags of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
has_tag = False
encrypt = Encrypt.NoEncrpytion
try:
flags = [flag.split("=") for flag in tokens[4].split(",")]
except IndexError:
return False, Encrypt.NoEncrpytion
for flag in flags:
flag_len = len(flag)
if flag == 1:
value = None
elif flag_len == 2:
value = flag[1]
else:
raise ParseError(
filename, f"Incorrect attribute flag {flag}", line_num
)
key = flag[0]
if key == "has_tag":
has_tag = True
elif key == "encrypt":
try:
encrypt = Encrypt(int(value)) # type: ignore
except (ValueError, TypeError):
raise ParseError(
filename,
f"Illegal attribute encryption {value}",
line_num,
)
else:
raise ParseError(
filename, "Unknown attribute flag {key}", line_num
)
return has_tag, encrypt
def _parse_attribute(self, tokens: Sequence[str], line_num: int):
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if not len(tokens) in {4, 5}:
raise ParseError(
filename,
"Incorrect number of tokens for attribute definition",
line_num,
)
has_tag, encrypt = self._parse_attribute_flags(tokens, line_num)
name, code, datatype = tokens[1:4]
if datatype == "concat" and self.cur_vendor != self.rfc_vendor:
raise ParseError(
filename,
'vendor attributes are not allowed to have the datatype "concat"',
line_num,
)
try:
codes = _parse_attribute_code(code)
except ValueError:
raise ParseError(
filename, f'invalid attribute code {code}""', line_num
)
for code in codes:
tlength = self.cur_vendor.tlength
if 2 ** (8 * tlength) <= code:
raise ParseError(
filename,
f"attribute code is too big, must be smaller than 2**{tlength}",
line_num,
)
if code < 0:
raise ParseError(
filename,
"negative attribute codes are not allowed",
line_num,
)
# TODO: Do we some explicit handling of tlvs?
# if len(codes) > 1:
# self._parse_attribute_tlv(codes, line_num)
# else:
# pass
base_datatype = datatype.split("[")[0].replace("-", "")
try:
attribute_type = Datatype[base_datatype]
except KeyError:
raise ParseError(filename, f"Illegal type: {datatype}", line_num)
attribute = Attribute(
name,
codes[-1],
attribute_type,
has_tag,
encrypt,
len(codes) > 1,
{},
)
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
self.cur_vendor.attrs[attrcode] = attribute
if self.cur_vendor != self.rfc_vendor:
codes = tuple([26] + codes)
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
self.attrindex[attrcode] = attribute
self.attrindex[name] = attribute
def _parse_value(self, tokens: Sequence[str], line_num: int):
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
filename = self.filestack[-1]
if len(tokens) != 4:
raise ParseError(
filename,
"Incorrect number of tokens for VALUE definition",
line_num,
)
(attr_name, key, value) = tokens[1:]
value = _parse_number(value)
attribute = self.attrindex[attr_name]
try:
datatype = str(attribute.datatype).split(".")[1]
lmin, lmax = INTEGER_TYPES[datatype]
if value < lmin or value > lmax:
raise ParseError(
filename,
f"VALUE {key}({value}) is not in the limit of type {datatype}",
line_num,
)
except KeyError:
raise ParseError(
filename,
f"only attributes with integer typed datatypes can have"
f"value definitions {attribute.datatype}",
line_num,
)
attribute.values[value] = key
attribute.values[key] = value

46
src/pyrad3/host.py Normal file
View File

@@ -0,0 +1,46 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Interface Class for RADIUS Clients and Servers"""
from pyrad3.dictionary import Dictionary
from pyrad3 import packet
class Host: # pylint: disable=too-many-arguments
"""Interface Class for RADIUS Clients and Servers"""
def __init__(
self,
secret: bytes,
radius_dict: Dictionary,
authport: int = 1812,
acctport: int = 1813,
coaport: int = 3799,
timeout: float = 30,
retries: int = 3,
):
self.secret = secret
self.dictionary = radius_dict
self.authport = authport
self.acctport = acctport
self.coaport = coaport
self.timeout = timeout
self.retries = retries
def create_packet(self, **kwargs):
"""Create a generic RADIUS Packet"""
return packet.Packet(self, **kwargs)
def create_auth_packet(self, **kwargs):
"""Create an Authentictaion packet (request per default)"""
return packet.AuthPacket(self, **kwargs)
def create_acct_packet(self, **kwargs):
"""Create an Accounting packet (request per default)"""
return packet.AcctPacket(self, **kwargs)
def create_coa_packet(self, **kwargs):
"""Create an CoA packet (requset per default)"""
return packet.CoAPacket(self, **kwargs)

305
src/pyrad3/packet.py Normal file
View File

@@ -0,0 +1,305 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
from collections import OrderedDict
from enum import IntEnum
from secrets import token_bytes
from typing import Any, Dict, Optional, Sequence
import hashlib
import hmac
from pyrad3.host import Host
from pyrad3.utils import (
PacketError,
Attribute,
parse_header,
parse_attributes,
calculate_authenticator,
validate_pap_password,
validate_chap_password,
)
HMAC = hmac.new
# Packet codes
class Code(IntEnum):
AccessRequest = 1
AccessAccept = 2
AccessReject = 3
AccountingRequest = 4
AccountingResponse = 5
AccessChallenge = 11
StatusServer = 12
StatusClient = 13
DisconnectRequest = 40
DisconnectACK = 41
DisconnectNAK = 42
CoARequest = 43
CoAACK = 44
CoANAK = 45
class AuthError(Exception):
pass
class Packet(OrderedDict):
def __init__(
self,
host: Host,
code: Code,
radius_id: int,
*,
request: "Packet" = None,
**attributes
):
super().__init__(**attributes)
self.code = code
self.id = radius_id
self.host = host
self.request = request
self.ordered_attributes: Sequence[Attribute] = []
self.raw_packet: Optional[bytearray] = None
self.authenticator: Optional[bytes] = None
@staticmethod
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
# Can we do better than Any with type hinting?
attrs: Dict[str, Any] = {}
for attr in ordered_attrs:
try:
attrs[attr.name].append(attr.value)
except KeyError:
attrs[attr.name] = [attr.value]
parsed_packet = Packet(host, code, radius_id, **attrs)
parsed_packet.authenticator = authenticator
parsed_packet.raw_packet = raw_packet
parsed_packet.ordered_attributes = ordered_attrs
return parsed_packet
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
self.verify_reply(raw_packet)
reply = Packet.from_raw(self.host, raw_packet)
reply.request = self
try:
if not reply.validate_message_authenticator():
raise PacketError("Packet has a wrong message authenticator")
except KeyError:
if "EAP-Message" in reply:
raise PacketError("Packet is missing a message authenticator")
return reply
def send(self):
"""Send the packet to the Client/Server.
"""
self.host._send_packet(self)
def verify_reply(self, raw_reply: bytes):
"""Verify the reply to this packet.
"""
if self.id != raw_reply[1]:
raise PacketError("Response has a wrong id")
# self.authenticator MUST be set, this packet got send so by definitation
# self.authenticator will not be non, but bytes
radius_hash = calculate_authenticator(
self.host.secret,
self.authenticator, # type: ignore
raw_reply,
)
if radius_hash != raw_reply[4:20]:
raise PacketError("Reply Packet has a wrong authenticator")
def validate_message_authenticator(self):
message_authenticator = self["Message-Authenticator"]
if isinstance(list, message_authenticator):
# There are multiple Message Authenticators, but a packet MUST NOT have
# more than one
return False
ma_attribute = self.find_first_attribute("Message-Authenticator")
generated = self._generate_message_authenticator(ma_attribute)
return message_authenticator == generated
def _generate_message_authenticator(self, ma_attr: Attribute):
assert self.authenticator is not None
assert self.request is not None
assert self.request.authenticator is not None
assert self.raw_packet is not None
# The message authenticator must be treated as 16 * \00
start_pos = ma_attr.pos + 2
end_pos = start_pos + 16
original_ma: bytes = ma_attr.value
self.raw_packet[start_pos:end_pos] = 16 * b"\00"
hmac_builder = HMAC(self.host.secret, digestmod=hashlib.md5)
hmac_builder.update(self.raw_packet)
if self.code in (Code.AccessRequest, Code.StatusServer):
hmac_builder.update(self.authenticator)
elif self.code in (
Code.AccessAccept,
Code.AccessChallenge,
Code.AccessReject,
):
hmac_builder.update(self.request.authenticator)
else:
hmac_builder.update(16 * b"\00")
hmac_builder.update(self.raw_packet[20:])
self.raw_packet[start_pos:end_pos] = original_ma
return hmac_builder.digest()
def add_message_authenticator(self):
self._encode_packet()
self._generate_message_authenticator(self)
try:
# quick lookup before we iterate over the whole packet
_ = self["Message-Authenticator"]
attr = self.find_first_attribute("Message-Authenticator")
except KeyError:
self["Message-Authenticator"] = 16 * b"\00"
attr = self.ordered_attributes[-1]
generated = self._generate_message_authenticator(attr)
self[attr.pos + 2 :] = generated
def refresh_message_authenticator(self):
self.add_message_authenticator()
def find_first_attribute(self, attr_type_name: str) -> Attribute:
for attr in self.ordered_attributes:
if attr.type == attr_type_name:
return attr.type
raise KeyError
def _encode_packet(self):
self.raw_packet = None
class AuthPacket(Packet):
def __init__(
self,
host: Host,
radius_id: int,
auth_type,
*,
code: Code = Code.AccessRequest,
request: Optional[Packet] = None,
**attributes
):
super().__init__(host, code, radius_id, request=request, **attributes)
self.auth_type = auth_type
if code == Code.AccessRequest:
self.authenticator = token_bytes(16)
def create_accept(self, **attributes):
return AuthPacket(
self.host,
self.id,
self.auth_type,
request=self,
code=Code.AccessAccept,
**attributes
)
def create_reject(self, **attributes):
return AuthPacket(
self.host,
self.id,
self.auth_type,
request=self,
code=Code.AccessReject,
**attributes
)
def create_challange(self, **attributes):
return AuthPacket(
self.host,
self.id,
self.auth_type,
request=self,
code=Code.AccessChallenge,
**attributes
)
def validate_password(self, password: bytes) -> bool:
try:
return self.validate_pap(password)
except KeyError:
pass
# Will throw KeyError if no chap password exists
return self.validate_chap(password)
def validate_pap(self, password: bytes) -> bool:
packet_password = self["User-Password"]
return validate_pap_password(
self.host.secret,
self.authenticator, # type: ignore
packet_password,
password,
)
def validate_chap(self, password: bytes) -> bool:
packet_password = self["Chap-Password"]
chap_id = packet_password[:1]
chap_password = packet_password[1:]
try:
challenge = self["Chap-Challenge"]
except KeyError:
challenge = self.authenticator
return validate_chap_password(
chap_id, challenge, chap_password, password, # type: ignore
)
class AcctPacket(Packet):
def __init__(
self,
host: Host,
radius_id: int,
*,
code: Code = Code.AccountingRequest,
request: Optional[Packet] = None,
**attributes
):
super().__init__(host, code, radius_id, request=request, **attributes)
def create_response(self, **attributes):
return AcctPacket(
self.host,
self.id,
code=Code.AccountingResponse,
request=self,
**attributes
)
class CoAPacket(Packet):
def __init__(
self,
host: Host,
radius_id: int,
*,
code: Code = Code.CoARequest,
request: Optional[Packet] = None,
**attributes
):
super().__init__(host, code, radius_id, request=request, **attributes)
def create_ack(self, **attributes):
return CoAPacket(
self.host, self.id, code=Code.CoAACK, request=self, **attributes
)
def create_nack(self, **attributes):
return CoAPacket(
self.host, self.id, code=Code.CoANAK, request=self, **attributes
)

View File

@@ -26,7 +26,8 @@ class Proxy(Server):
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
self._poll.register(
self._proxyfd.fileno(),
(select.POLLIN | select.POLLPRI | select.POLLERR))
(select.POLLIN | select.POLLPRI | select.POLLERR),
)
def _handle_proxy_packet(self, pkt):
"""Process a packet received on the reply socket.
@@ -38,12 +39,15 @@ class Proxy(Server):
:type pkt: Packet class instance
"""
if pkt.source[0] not in self.hosts:
raise ServerPacketError('Received packet from unknown host')
raise ServerPacketError("Received packet from unknown host")
pkt.secret = self.hosts[pkt.source[0]].secret
if pkt.code not in [packet.AccessAccept, packet.AccessReject,
packet.AccountingResponse]:
raise ServerPacketError('Received non-response on proxy socket')
if pkt.code not in [
packet.AccessAccept,
packet.AccessReject,
packet.AccountingResponse,
]:
raise ServerPacketError("Received non-response on proxy socket")
def _process_input(self, fd):
"""Process available data.
@@ -62,7 +66,8 @@ class Proxy(Server):
"""
if fd.fileno() == self._proxyfd.fileno():
pkt = self._grab_packet(
lambda data, s=self: s.CreatePacket(packet=data), fd)
lambda data, s=self: s.CreatePacket(packet=data), fd
)
self._handle_proxy_packet(pkt)
else:
Server._process_input(self, fd)

View File

@@ -9,13 +9,15 @@ from pyrad import host
from pyrad import packet
LOGGER = logging.getLogger('pyrad')
LOGGER = logging.getLogger("pyrad")
class RemoteHost:
"""Remote RADIUS capable host we can talk to."""
def __init__(self, address, secret, name, authport=1812, acctport=1813, coaport=3799):
def __init__(
self, address, secret, name, authport=1812, acctport=1813, coaport=3799
):
"""Constructor.
:param address: IP address
@@ -62,10 +64,21 @@ class Server(host.Host):
:cvar MaxPacketSize: maximum size of a RADIUS packet
:type MaxPacketSize: integer
"""
MaxPacketSize = 4096
def __init__(self, addresses=[], authport=1812, acctport=1813, coaport=3799,
hosts=None, dict=None, auth_enabled=True, acct_enabled=True, coa_enabled=False):
def __init__(
self,
addresses=[],
authport=1812,
acctport=1813,
coaport=3799,
hosts=None,
dict=None,
auth_enabled=True,
acct_enabled=True,
coa_enabled=False,
):
"""Constructor.
:param addresses: IP addresses to listen on
@@ -114,7 +127,7 @@ class Server(host.Host):
"""
results = set()
try:
tmp = socket.getaddrinfo(addr, 'www')
tmp = socket.getaddrinfo(addr, "www")
except socket.gaierror:
return []
@@ -199,10 +212,10 @@ class Server(host.Host):
"""
if pkt.source[0] in self.hosts:
pkt.secret = self.hosts[pkt.source[0]].secret
elif '0.0.0.0' in self.hosts:
pkt.secret = self.hosts['0.0.0.0'].secret
elif "0.0.0.0" in self.hosts:
pkt.secret = self.hosts["0.0.0.0"].secret
else:
raise ServerPacketError('Received packet from unknown host')
raise ServerPacketError("Received packet from unknown host")
def _handle_auth_packet(self, pkt):
"""Process a packet received on the authentication port.
@@ -216,7 +229,8 @@ class Server(host.Host):
self._add_secret(pkt)
if pkt.code != packet.AccessRequest:
raise ServerPacketError(
'Received non-authentication packet on authentication port')
"Received non-authentication packet on authentication port"
)
self.HandleAuthPacket(pkt)
def _handle_acct_packet(self, pkt):
@@ -229,10 +243,13 @@ class Server(host.Host):
:type pkt: Packet class instance
"""
self._add_secret(pkt)
if pkt.code not in [packet.AccountingRequest,
packet.AccountingResponse]:
if pkt.code not in [
packet.AccountingRequest,
packet.AccountingResponse,
]:
raise ServerPacketError(
'Received non-accounting packet on accounting port')
"Received non-accounting packet on accounting port"
)
self.HandleAcctPacket(pkt)
def _handle_coa_packet(self, pkt):
@@ -251,7 +268,7 @@ class Server(host.Host):
elif pkt.code == packet.DisconnectRequest:
self.HandleDisconnectPacket(pkt)
else:
raise ServerPacketError('Received non-coa packet on coa port')
raise ServerPacketError("Received non-coa packet on coa port")
def _grab_packet(self, pktgen, fd):
"""Read a packet from a network connection.
@@ -273,7 +290,9 @@ class Server(host.Host):
"""
for fd in self.authfds + self.acctfds + self.coafds:
self._fdmap[fd.fileno()] = fd
self._poll.register(fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR)
self._poll.register(
fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR
)
if self.auth_enabled:
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
if self.acct_enabled:
@@ -307,16 +326,22 @@ class Server(host.Host):
:type fd: socket class instance
"""
if self.auth_enabled and fd.fileno() in self._realauthfds:
pkt = self._grab_packet(lambda data, s=self: s.CreateAuthPacket(packet=data), fd)
pkt = self._grab_packet(
lambda data, s=self: s.CreateAuthPacket(packet=data), fd
)
self._handle_auth_packet(pkt)
elif self.acct_enabled and fd.fileno() in self._realacctfds:
pkt = self._grab_packet(lambda data, s=self: s.CreateAcctPacket(packet=data), fd)
pkt = self._grab_packet(
lambda data, s=self: s.CreateAcctPacket(packet=data), fd
)
self._handle_acct_packet(pkt)
elif self.coa_enabled:
pkt = self._grab_packet(lambda data, s=self: s.CreateCoAPacket(packet=data), fd)
pkt = self._grab_packet(
lambda data, s=self: s.CreateCoAPacket(packet=data), fd
)
self._handle_coa_packet(pkt)
else:
raise ServerPacketError('Received packet for unknown handler')
raise ServerPacketError("Received packet for unknown handler")
def Run(self):
"""Main loop.
@@ -335,8 +360,8 @@ class Server(host.Host):
fdo = self._fdmap[fd]
self._process_input(fdo)
except ServerPacketError as err:
LOGGER.info('Dropping packet: %s', err)
LOGGER.info("Dropping packet: %s", err)
except packet.PacketError as err:
LOGGER.info('Received a broken packet: %s', err)
LOGGER.info("Received a broken packet: %s", err)
else:
LOGGER.error('Unexpected event in server main loop')
LOGGER.error("Unexpected event in server main loop")

View File

@@ -9,26 +9,37 @@ from abc import abstractmethod, ABCMeta
from enum import Enum
from datetime import datetime
from pyrad.packet import (
Packet, AccessAccept, AccessReject,
AccountingRequest, AccountingResponse,
DisconnectACK, DisconnectNAK, DisconnectRequest, CoARequest,
CoAACK, CoANAK, AccessRequest, AuthPacket, AcctPacket, CoAPacket,
PacketError
Packet,
AccessAccept,
AccessReject,
AccountingRequest,
AccountingResponse,
DisconnectACK,
DisconnectNAK,
DisconnectRequest,
CoARequest,
CoAACK,
CoANAK,
AccessRequest,
AuthPacket,
AcctPacket,
CoAPacket,
PacketError,
)
from pyrad.server import ServerPacketError
class ServerType(Enum):
Auth = 'Authentication'
Acct = 'Accounting'
Coa = 'Coa'
Auth = "Authentication"
Acct = "Accounting"
Coa = "Coa"
class DatagramProtocolServer(asyncio.Protocol):
def __init__(self, ip, port, logger, server, server_type, hosts,
request_callback):
def __init__(
self, ip, port, logger, server, server_type, hosts, request_callback
):
self.transport = None
self.ip = ip
self.port = port
@@ -40,94 +51,149 @@ class DatagramProtocolServer(asyncio.Protocol):
def connection_made(self, transport):
self.transport = transport
self.logger.info('[%s:%d] Transport created', self.ip, self.port)
self.logger.info("[%s:%d] Transport created", self.ip, self.port)
def connection_lost(self, exc):
if exc:
self.logger.warn('[%s:%d] Connection lost: %s', self.ip, self.port, str(exc))
self.logger.warn(
"[%s:%d] Connection lost: %s", self.ip, self.port, str(exc)
)
else:
self.logger.info('[%s:%d] Transport closed', self.ip, self.port)
self.logger.info("[%s:%d] Transport closed", self.ip, self.port)
def send_response(self, reply, addr):
self.transport.sendto(reply.ReplyPacket(), addr)
def datagram_received(self, data, addr):
self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr)
self.logger.debug(
"[%s:%d] Received %d bytes from %s",
self.ip,
self.port,
len(data),
addr,
)
receive_date = datetime.utcnow()
if addr[0] in self.hosts:
remote_host = self.hosts[addr[0]]
elif '0.0.0.0' in self.hosts:
remote_host = self.hosts['0.0.0.0']
elif "0.0.0.0" in self.hosts:
remote_host = self.hosts["0.0.0.0"]
else:
self.logger.warn('[%s:%d] Drop package from unknown source %s', self.ip, self.port, addr)
self.logger.warn(
"[%s:%d] Drop package from unknown source %s",
self.ip,
self.port,
addr,
)
return
try:
self.logger.debug('[%s:%d] Received from %s packet: %s', self.ip, self.port, addr, data.hex())
self.logger.debug(
"[%s:%d] Received from %s packet: %s",
self.ip,
self.port,
addr,
data.hex(),
)
req = Packet(packet=data, dict=self.server.dict)
except Exception as exc:
self.logger.error('[%s:%d] Error on decode packet: %s', self.ip, self.port, exc)
self.logger.error(
"[%s:%d] Error on decode packet: %s", self.ip, self.port, exc
)
return
try:
if req.code in (AccountingResponse, AccessAccept, AccessReject, CoANAK, CoAACK, DisconnectNAK, DisconnectACK):
raise ServerPacketError(f'Invalid response packet {req.code}')
if req.code in (
AccountingResponse,
AccessAccept,
AccessReject,
CoANAK,
CoAACK,
DisconnectNAK,
DisconnectACK,
):
raise ServerPacketError(f"Invalid response packet {req.code}")
elif self.server_type == ServerType.Auth:
if req.code != AccessRequest:
raise ServerPacketError('Received non-auth packet on auth port')
req = AuthPacket(secret=remote_host.secret,
dict=self.server.dict,
packet=data)
raise ServerPacketError(
"Received non-auth packet on auth port"
)
req = AuthPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyAuthRequest():
raise PacketError('Packet verification failed')
raise PacketError("Packet verification failed")
elif self.server_type == ServerType.Coa:
if req.code != DisconnectRequest and req.code != CoARequest:
raise ServerPacketError('Received non-coa packet on coa port')
req = CoAPacket(secret=remote_host.secret,
dict=self.server.dict,
packet=data)
raise ServerPacketError(
"Received non-coa packet on coa port"
)
req = CoAPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyCoARequest():
raise PacketError('Packet verification failed')
raise PacketError("Packet verification failed")
elif self.server_type == ServerType.Acct:
if req.code != AccountingRequest:
raise ServerPacketError('Received non-acct packet on acct port')
req = AcctPacket(secret=remote_host.secret,
dict=self.server.dict,
packet=data)
raise ServerPacketError(
"Received non-acct packet on acct port"
)
req = AcctPacket(
secret=remote_host.secret,
dict=self.server.dict,
packet=data,
)
if self.server.enable_pkt_verify:
if req.VerifyAcctRequest():
raise PacketError('Packet verification failed')
raise PacketError("Packet verification failed")
# Call request callback
self.request_callback(self, req, addr)
except Exception as exc:
if self.server.debug:
self.logger.exception('[%s:%d] Error for packet from %s', self.ip, self.port, addr)
self.logger.exception(
"[%s:%d] Error for packet from %s", self.ip, self.port, addr
)
else:
self.logger.error('[%s:%d] Error for packet from %s: %s', self.ip, self.port, addr, exc)
self.logger.error(
"[%s:%d] Error for packet from %s: %s",
self.ip,
self.port,
addr,
exc,
)
process_date = datetime.utcnow()
self.logger.debug('[%s:%d] Request from %s processed in %d ms', self.ip, self.port, addr, (process_date-receive_date).microseconds/1000)
self.logger.debug(
"[%s:%d] Request from %s processed in %d ms",
self.ip,
self.port,
addr,
(process_date - receive_date).microseconds / 1000,
)
def error_received(self, exc):
self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc)
self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc)
async def close_transport(self):
if self.transport:
self.logger.debug('[%s:%d] Close transport...', self.ip, self.port)
self.logger.debug("[%s:%d] Close transport...", self.ip, self.port)
self.transport.close()
self.transport = None
def __str__(self):
return f'DatagramProtocolServer(ip={self.ip}, port={self.port})'
return f"DatagramProtocolServer(ip={self.ip}, port={self.port})"
# Used as protocol_factory
def __call__(self):
@@ -135,12 +201,18 @@ class DatagramProtocolServer(asyncio.Protocol):
class ServerAsync(metaclass=ABCMeta):
def __init__(self, auth_port=1812, acct_port=1813,
coa_port=3799, hosts=None, dictionary=None,
loop=None, logger_name='pyrad',
enable_pkt_verify=False,
debug=False):
def __init__(
self,
auth_port=1812,
acct_port=1813,
coa_port=3799,
hosts=None,
dictionary=None,
loop=None,
logger_name="pyrad",
enable_pkt_verify=False,
debug=False,
):
if not loop:
self.loop = asyncio.get_event_loop()
@@ -174,20 +246,35 @@ class ServerAsync(metaclass=ABCMeta):
self.handle_acct_packet(protocol, req, addr)
elif protocol.server_type == ServerType.Auth:
self.handle_auth_packet(protocol, req, addr)
elif protocol.server_type == ServerType.Coa and \
req.code == CoARequest:
elif (
protocol.server_type == ServerType.Coa
and req.code == CoARequest
):
self.handle_coa_packet(protocol, req, addr)
elif protocol.server_type == ServerType.Coa and \
req.code == DisconnectRequest:
elif (
protocol.server_type == ServerType.Coa
and req.code == DisconnectRequest
):
self.handle_disconnect_packet(protocol, req, addr)
else:
self.logger.error('[%s:%s] Unexpected request found', protocol.ip, protocol.port)
self.logger.error(
"[%s:%s] Unexpected request found",
protocol.ip,
protocol.port,
)
except Exception as exc:
if self.debug:
self.logger.exception('[%s:%s] Unexpected error', protocol.ip, protocol.port)
self.logger.exception(
"[%s:%s] Unexpected error", protocol.ip, protocol.port
)
else:
self.logger.error('[%s:%s] Unexpected error: %s', protocol.ip, protocol.port, exc)
self.logger.error(
"[%s:%s] Unexpected error: %s",
protocol.ip,
protocol.port,
exc,
)
def __is_present_proto__(self, ip, port):
if port == self.auth_port:
@@ -217,87 +304,92 @@ class ServerAsync(metaclass=ABCMeta):
reply = pkt.CreateReply(**attributes)
return reply
async def initialize_transports(self, enable_acct=False,
enable_auth=False, enable_coa=False,
addresses=None):
async def initialize_transports(
self,
enable_acct=False,
enable_auth=False,
enable_coa=False,
addresses=None,
):
task_list = []
if not enable_acct and not enable_auth and not enable_coa:
raise Exception('No transports selected')
raise Exception("No transports selected")
if not addresses or len(addresses) == 0:
addresses = ['127.0.0.1']
addresses = ["127.0.0.1"]
# noinspection SpellCheckingInspection
for addr in addresses:
if enable_acct and not self.__is_present_proto__(addr, self.acct_port):
if enable_acct and not self.__is_present_proto__(
addr, self.acct_port
):
protocol_acct = DatagramProtocolServer(
addr,
self.acct_port,
self.logger, self,
self.logger,
self,
ServerType.Acct,
self.hosts,
self.__request_handler__
self.__request_handler__,
)
bind_addr = (addr, self.acct_port)
acct_connect = self.loop.create_datagram_endpoint(
protocol_acct,
reuse_port=True,
local_addr=bind_addr
protocol_acct, reuse_port=True, local_addr=bind_addr
)
self.acct_protocols.append(protocol_acct)
task_list.append(acct_connect)
if enable_auth and not self.__is_present_proto__(addr, self.auth_port):
if enable_auth and not self.__is_present_proto__(
addr, self.auth_port
):
protocol_auth = DatagramProtocolServer(
addr,
self.auth_port,
self.logger, self,
self.logger,
self,
ServerType.Auth,
self.hosts,
self.__request_handler__
self.__request_handler__,
)
bind_addr = (addr, self.auth_port)
auth_connect = self.loop.create_datagram_endpoint(
protocol_auth,
reuse_port=True,
local_addr=bind_addr
protocol_auth, reuse_port=True, local_addr=bind_addr
)
self.auth_protocols.append(protocol_auth)
task_list.append(auth_connect)
if enable_coa and not self.__is_present_proto__(addr, self.coa_port):
if enable_coa and not self.__is_present_proto__(
addr, self.coa_port
):
protocol_coa = DatagramProtocolServer(
addr,
self.coa_port,
self.logger, self,
self.logger,
self,
ServerType.Coa,
self.hosts,
self.__request_handler__
self.__request_handler__,
)
bind_addr = (addr, self.coa_port)
coa_connect = self.loop.create_datagram_endpoint(
protocol_coa,
reuse_port=True,
local_addr=bind_addr
protocol_coa, reuse_port=True, local_addr=bind_addr
)
self.coa_protocols.append(protocol_coa)
task_list.append(coa_connect)
await asyncio.ensure_future(
asyncio.gather(
*task_list,
return_exceptions=False,
),
loop=self.loop
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
)
# noinspection SpellCheckingInspection
async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True):
async def deinitialize_transports(
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
):
if deinit_coa:
for proto in self.coa_protocols:

243
src/pyrad3/tools.py Normal file
View File

@@ -0,0 +1,243 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Collections of functions to en- and decode RADIUS Attributes"""
from typing import Union
from ipaddress import IPv4Address, IPv6Address, IPv6Network, ip_network, ip_address
import struct
def encode_string(string: str) -> bytes:
"""Encode a RADIUS value of type string"""
if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters")
if isinstance(string, str):
return string.encode("utf-8")
return string
def encode_octets(string: bytes) -> bytes:
"""Encode a RADIUS value of type octet"""
if len(string) > 253:
raise ValueError("Can only encode strings of <= 253 characters")
return string
def encode_address(addr: Union[str, IPv4Address]) -> bytes:
"""Encode a RADIUS value of type ipaddr"""
return IPv4Address(addr).packed
def encode_ipv6_prefix(addr: Union[str, IPv6Network]) -> bytes:
"""Encode a RADIUS value of type ipv6prefix"""
address = IPv6Network(addr)
return struct.pack("2B", *[0, address.prefixlen]) + address.network_address.packed
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type ipv6addr"""
return IPv6Address(addr).packed
def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes:
"""Encode a RADIUS value of type combo-ip"""
return ip_address(addr).packed
def encode_ascend_binary(string: str) -> bytes:
"""
struct_format: List of type=value pairs sperated by spaces.
Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'
Type:
family ipv4(default) or ipv6
action discard(default) or accept
direction in(default) or out
src source prefix (default ignore)
dst destination prefix (default ignore)
proto protocol number / next-header number (default ignore)
sport source port (default ignore)
dport destination port (default ignore)
sportq source port qualifier (default 0)
dportq destination port qualifier (default 0)
Source/Destination Port Qualifier:
0 no compare
1 less than
2 equal to
3 greater than
4 not equal to
"""
terms = {
"family": b"\x01",
"action": b"\x00",
"direction": b"\x01",
"src": b"\x00\x00\x00\x00",
"dst": b"\x00\x00\x00\x00",
"srcl": b"\x00",
"dstl": b"\x00",
"proto": b"\x00",
"sport": b"\x00\x00",
"dport": b"\x00\x00",
"sportq": b"\x00",
"dportq": b"\x00",
}
for term in string.split(" "):
key, value = term.split("=")
if key == "family" and value == "ipv6":
terms[key] = b"\x03"
if terms["src"] == b"\x00\x00\x00\x00":
terms["src"] = 16 * b"\x00"
if terms["dst"] == b"\x00\x00\x00\x00":
terms["dst"] = 16 * b"\x00"
elif key == "action" and value == "accept":
terms[key] = b"\x01"
elif key == "direction" and value == "out":
terms[key] = b"\x00"
elif key in ("src", "dst"):
address = ip_network(value)
terms[key] = address.network_address.packed
terms[key + "l"] = struct.pack("B", address.prefixlen)
elif key in ("sport", "dport"):
terms[key] = struct.pack("!H", int(value))
elif key in ("sportq", "dportq", "proto"):
terms[key] = struct.pack("B", int(value))
trailer = 8 * b"\x00"
result = b"".join(
(
terms["family"],
terms["action"],
terms["direction"],
b"\x00",
terms["src"],
terms["dst"],
terms["srcl"],
terms["dstl"],
terms["proto"],
b"\x00",
terms["sport"],
terms["dport"],
terms["sportq"],
terms["dportq"],
b"\x00\x00",
trailer,
)
)
return result
def encode_integer(num: Union[int, str], struct_format="!I") -> bytes:
"""Encode a RADIUS value of some type integer"""
return struct.pack(struct_format, int(num))
def encode_date(num: Union[int, str]) -> bytes:
"""Encode a RADIUS value of type date"""
return struct.pack("!I", int(num))
def decode_string(string: bytes) -> Union[str, bytes]:
"""Decode a RADIUS value of type string"""
try:
return string.decode("utf-8")
except UnicodeDecodeError:
return string
def decode_octets(string: bytes) -> bytes:
"""Decode a RADIUS value of type octet"""
return string
def decode_address(addr: bytes) -> IPv4Address:
"""Decode a RADIUS value of type ipaddr"""
return IPv4Address(addr)
def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
"""Decode a RADIUS value of type ipv6prefix"""
addr = addr + b"\x00" * (18 - len(addr))
prefix = addr[:2]
addr = addr[2:]
return IPv6Network((prefix, addr))
def decode_ipv6_address(addr: bytes) -> IPv6Address:
"""Decode a RADIUS value of type ipv6addr"""
addr = addr + b"\x00" * (16 - len(addr))
return IPv6Address(addr)
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
"""Decode a RADIUS value of type combo-ip"""
return ip_address(addr).packed
def decode_ascend_binary(string):
"""Decode a RADIUS value of type abinary"""
raise NotImplementedError
def decode_integer(num: bytes, struct_format="!I") -> int:
"""Decode a RADIUS value of some integer type"""
return (struct.unpack(struct_format, num))[0]
def decode_date(num):
"""Decode a RADIUS value of type date"""
return (struct.unpack("!I", num))[0]
ENCODE_MAP = {
"string": encode_string,
"octets": encode_octets,
"integer": encode_integer,
"ipaddr": encode_address,
"ipv6prefix": encode_ipv6_prefix,
"ipv6addr": encode_ipv6_address,
"abinary": encode_ascend_binary,
"signed": lambda value: encode_integer(value, "!i"),
"short": lambda value: encode_integer(value, "!H"),
"byte": lambda value: encode_integer(value, "!B"),
"integer64": lambda value: encode_integer(value, '!Q'),
"date": encode_date,
}
def encode_attr(datatype, value):
"""Encode a RADIUS attribute"""
try:
return ENCODE_MAP[datatype](value)
except KeyError:
raise ValueError(f"Unknown attribute type {datatype}")
DECODE_MAP = {
"string": decode_string,
"octets": decode_octets,
"integer": decode_integer,
"ipaddr": decode_address,
"ipv6prefix": decode_ipv6_prefix,
"ipv6addr": decode_ipv6_address,
"abinary": decode_ascend_binary,
"signed": lambda value: decode_integer(value, "!i"),
"short": lambda value: decode_integer(value, "!H"),
"byte": lambda value: decode_integer(value, "!B"),
"integer64": lambda value: decode_integer(value, "!Q"),
"date": decode_date,
}
def decode_attr(datatype, value):
"""Decode a RADIUS attribute"""
try:
return DECODE_MAP[datatype](value)
except KeyError:
raise ValueError(f"Unknown attribute type {datatype}")

234
src/pyrad3/utils.py Normal file
View File

@@ -0,0 +1,234 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
"""Collection of functions to deal with RADIUS packet en- and decoding."""
from collections import namedtuple
from typing import List, Union
import hashlib
import secrets
import struct
from pyrad3.dictionary import Dictionary
RANDOM_GENERATOR = secrets.SystemRandom()
MD5 = hashlib.md5
class PacketError(Exception):
"""Exception for Invalid Packets"""
Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"])
Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"])
def parse_header(raw_packet: bytes) -> Header:
"""Parse the Header of a RADIUS Packet."""
try:
header = struct.unpack("!BBH16s", raw_packet)
except struct.error:
raise PacketError("Packet header is corrupt")
length = header[3]
if len(raw_packet) != length:
raise PacketError(
f"RADIUS Packet ({len(raw_packet)}) has an invalid length ({length})"
)
if length > 4096:
raise PacketError(f"Packet length is too big ({length})")
return Header(*header)
def parse_attributes(
rad_dict: Dictionary, raw_packet: bytes
) -> List[Attribute]:
"""Parse the Attributes in a RADIUS Packet.
This function skips the Header. The Header must be parsed and verified
separately.
"""
attributes = []
packet = raw_packet[20:]
while packet:
try:
(key, length) = struct.unpack("!BB", packet[0:2])
except struct.error:
raise PacketError("Attribute header is corrupt")
if length < 2:
raise PacketError(f"Attribute length ({length}) is too small")
value = packet[2:length]
offset = len(raw_packet) - len(packet) + length
if key == 26:
try:
attributes.extend(
parse_vendor_attributes(rad_dict, offset, value)
)
except (PacketError, IndexError):
attributes.append(
Attribute(
name="Unknown-Vendor-Attribute",
pos=offset,
type="octets",
length=int(packet[1]),
value=packet[2:],
)
)
else:
key = parse_key(rad_dict, key)
attributes.extend(parse_value(rad_dict, key, offset, value))
packet = packet[length:]
return attributes
def parse_vendor_attributes(
rad_dict: Dictionary, offset: int, vendor_value: bytes
) -> List[Attribute]:
"""Parse A Vendor Attribute"""
if len(vendor_value) < 4:
raise PacketError
vendor_id = int.from_bytes(vendor_value[:4], "big")
vendor_dict = rad_dict.vendor_by_id[vendor_id]
vendor_name = vendor_dict.name
attributes = []
vendor_tlv = vendor_value[4:]
while vendor_tlv:
try:
(key, length) = struct.unpack("!BB", vendor_tlv[0:2])
except struct.error:
attribute = [
Attribute(
name=f"Unknown-{vendor_name}-Attribute",
pos=offset - len(vendor_value),
type="octets",
length=len(vendor_value) - 4,
value=vendor_value,
)
]
else:
offset = offset - len(vendor_tlv) + length
key = parse_key(vendor_dict, key)
attribute = parse_value(vendor_dict, key, offset, vendor_tlv)
attributes.extend(attribute)
vendor_tlv = vendor_tlv[length:]
return attributes
def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]:
"""Parse the key in the Dictionary Context"""
try:
return rad_dict.attrs[key_id].name
except KeyError:
return key_id
def parse_value(
rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
) -> List[Attribute]:
"""Parse the Value in the given Key/Dictionary Context"""
return []
def calculate_authenticator(
secret: bytes, authenticator: bytes, raw_packet: bytes
) -> bytes:
"""Calculate the Authenticator for the RADIUS Packet"""
return MD5(
raw_packet[0:4] + authenticator + raw_packet[20:] + secret
).digest()
def validate_pap_password(
secret: bytes,
authenticator: bytes,
obfuscated_password: bytes,
plaintext_password: bytes,
) -> bool:
"""Check if the plaintext and the RADIUS passwords match
This function does not "decrypt" the received password.
"""
obf_pass = password_encode(secret, authenticator, plaintext_password)
return obfuscated_password == obf_pass
def password_encode(
secret: bytes, authenticator: bytes, password: bytes
) -> bytes:
"""Obfuscate the plaintext Password for RADIUS"""
password += b"\x00" * (16 - (len(password) % 16))
return obfuscation_algorithm(secret, authenticator, password)
def password_decode(
secret: bytes, authenticator: bytes, obfuscated_password: bytes
) -> str:
"""Reverse the RADIUS obfuscation on a given password
The password password is padded with \\x00 to a 16 byte boundary. The padding will
be removed by this function.
If the original password had some trailing \\x00 it will get lost. Therefore it is
not recommended to use (trailing) \\x00 in passwords.
"""
deobfuscated = obfuscation_algorithm(
secret, authenticator, obfuscated_password
)
return deobfuscated.rstrip(b"\x00").decode("utf-8")
def obfuscation_algorithm(
secret: bytes, authenticator: bytes, password: bytes
) -> bytes:
"""Obfuscate the plaintext password.
This function does not deal with the padding (which the
RADIUS Protocol requires.)
The User has to pad the password themself, or better use
the `password_encode` or `password_decode` function.
"""
result = b""
buf = password
last = authenticator
while buf:
cur_hash = MD5(secret + last).digest()
for cbuf, chash in zip(buf, cur_hash):
result += bytes([cbuf ^ chash])
(last, buf) = (buf[:16], buf[16:])
return result
def validate_chap_password(
chap_id: bytes,
challenge: bytes,
chap_password: bytes,
plaintext_password: bytes,
) -> bool:
"""Validate the CHAP password against the given plaintext password"""
return (
chap_password == MD5(chap_id + plaintext_password + challenge).digest()
)
def salt_encrypt(secret: bytes, authenticator: bytes, value: bytes) -> bytes:
"""Salt Encrypt the given value"""
# The highest bit MUST be 1
random_value = RANDOM_GENERATOR.randrange(32768, 65535)
salt = struct.pack("!H", random_value)
salted_auth = authenticator + salt
return obfuscation_algorithm(secret, salted_auth, value)
def salt_decrypt(
secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes
) -> bytes:
"""Decrypt the given value"""
salted_auth = authenticator + salt
return obfuscation_algorithm(secret, salted_auth, encrypted_value)

View File

@@ -0,0 +1 @@
$INCLUDE ./other_mutual_recursive

View File

@@ -0,0 +1 @@
$INCLUDE mutual_recursive

View File

@@ -0,0 +1 @@
$INCLUDE ./self_recursive

190
tests/test_dictionary.py Normal file
View File

@@ -0,0 +1,190 @@
# Copyright 2020 Istvan Ruzman
# SPDX-License-Identifier: MIT OR Apache-2.0
from io import StringIO
import pytest
from pyrad3.dictionary import Dictionary, ParseError
@pytest.mark.parametrize(
"filename", ["dictionaries/self_recursive", "dictionaries/mutual_recursive"]
)
def test_dictionary_recursion(filename):
with pytest.raises(ParseError):
Dictionary("tests/" + filename)
@pytest.mark.parametrize(
"line",
[
"$INCLUDE",
"BEGIN-VENDOR",
"END-VENDOR",
"VENDOR",
"VENDOR NAME",
"ATTRIBUTE",
"ATTRIBUTE NAME",
"VALUE",
"VALUE ATTRNAME",
"VALUE ATTRNAME VALUENAME",
],
)
def test_lines_missing_tokens(line):
dictionary = StringIO(line)
with pytest.raises(ParseError):
Dictionary("", dictionary)
@pytest.mark.parametrize(
"vendor",
[
"VENDOR test 1234",
"VENDOR test 1234 format=1,1",
"VENDOR test 1234 format=2,2",
"VENDOR test 1234 format=1,2",
"VENDOR test 1234 format=4,2",
"VENDOR test 1234 format=4,0",
"VENDOR WiMAX 1234 format=1,1,c",
],
)
def test_valid_vendor_definitions(vendor):
dictionary = StringIO(vendor)
Dictionary("", dictionary)
@pytest.mark.parametrize(
"vendor",
[
"VENDOR test 1234 1,1",
"VENDOR test 1234 format=3,1",
"VENDOR test 1234 format=2",
"VENDOR test 1234 format=1,2,c",
"VENDOR test 1234 format=1,9",
"VENDOR test 1234 format=4,4 suffix",
"VENDOR test 1234 format=a,b suffix",
],
)
def test_invalid_vendor_definitions(vendor):
dictionary = StringIO(vendor)
with pytest.raises(ParseError):
Dictionary("", dictionary)
@pytest.mark.parametrize(
"number",
[
"ATTRIBUTE NAME 0x01 byte",
"ATTRIBUTE NAME 0x0001 byte",
"ATTRIBUTE NAME 0o123 byte",
"ATTRIBUTE NAME 5 byte",
],
)
def test_valid_attribute_numbers(number):
dictionary = StringIO(number)
Dictionary("", dictionary)
@pytest.mark.parametrize(
"number",
[
"ATTRIBUTE NAME 1234 byte",
"ATTRIBUTE NAME ABCD byte",
"ATTRIBUTE NAME -1 byte",
],
)
def test_invalid_attribute_numbers(number):
dictionary = StringIO(number)
with pytest.raises(ParseError):
Dictionary("", dictionary)
@pytest.mark.parametrize("type_length", [1, 2, 4])
def test_attribute_number_limits(type_length):
too_big = 2 ** (8 * type_length)
max_value = too_big - 1
dictionary = StringIO(
f"VENDOR TEST 1234 format={type_length},1\n"
"BEGIN-VENDOR TEST\n"
f"ATTRIBUTE TEST {max_value} byte\n"
"END-VENDOR TEST\n"
)
Dictionary("", dictionary)
dictionary = StringIO(
f"VENDOR TEST 1234 format={type_length},1\n"
"BEGIN-VENDOR TEST\n"
f"ATTRIBUTE TEST {too_big} byte\n"
"END-VENDOR TEST\n"
)
with pytest.raises(ParseError):
Dictionary("", dictionary)
@pytest.mark.parametrize(
"value_definition",
[
"VALUE TEST-ATTRIBUTE TEST-VALUE 1",
"VALUE TEST-ATTRIBUTE TEST-VALUE 0x1",
"VALUE TEST-ATTRIBUTE TEST-VALUE 0o1",
],
)
def test_value_definition(value_definition):
dictionary = StringIO(
"\n".join(["ATTRIBUTE TEST-ATTRIBUTE 1 byte", value_definition])
)
Dictionary("", dictionary)
@pytest.mark.parametrize(
"value_num, attr_type",
[
(0, "byte"),
(255, "byte"),
(0, "short"),
(2 ** 16 - 1, "short"),
(0, "integer"),
(2 ** 32 - 1, "integer"),
((-(2 ** 31)), "signed"),
(2 ** 31 - 1, "signed"),
(0, "integer64"),
(2 ** 64 - 1, "integer64"),
],
)
def test_value_number_within_limit(value_num, attr_type):
dictionary = StringIO(
"\n".join(
[
f"ATTRIBUTE TEST-ATTRIBUTE 1 {attr_type}",
f"VALUE TEST-ATTRIBUTE TEST-VALUE {value_num}",
]
)
)
Dictionary("", dictionary)
@pytest.mark.parametrize(
"value_num, attr_type",
[
(-1, "byte"),
(256, "byte"),
(-1, "short"),
(2 ** 16, "short"),
(-1, "integer"),
(2 ** 32, "integer"),
(2 ** 31, "signed"),
((-(2 ** 31)) - 1, "signed"),
(-1, "integer64"),
(2 ** 64, "integer64"),
],
)
def test_value_number_out_of_limit(value_num, attr_type):
dictionary = StringIO(
"\n".join(
[
f"ATTRIBUTE TEST-ATTRIBUTE 1 {attr_type}",
f"VALUE TEST-ATTRIBUTE TEST-VALUE {value_num}",
]
)
)
with pytest.raises(ParseError):
Dictionary("", dictionary)