safe progress
This commit is contained in:
1
.flake8
1
.flake8
@@ -1,3 +1,2 @@
|
|||||||
[flake8]
|
[flake8]
|
||||||
max-complexity = 10
|
|
||||||
max-line-length = 100
|
max-line-length = 100
|
||||||
|
|||||||
@@ -2,10 +2,13 @@
|
|||||||
|
|
||||||
python.pkgs.buildPythonPackage rec {
|
python.pkgs.buildPythonPackage rec {
|
||||||
pname = "pyrad";
|
pname = "pyrad";
|
||||||
version = "3.0-alpha";
|
version = "1.0-alpha";
|
||||||
|
|
||||||
buildInputs = with python.pkgs; [ netaddr six ];
|
buildInputs = with python.pkgs; [ ];
|
||||||
|
|
||||||
checkInputs = with python.pkgs; [
|
checkInputs = with python.pkgs; [
|
||||||
|
black
|
||||||
|
pytest
|
||||||
|
# pytest-cov
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,41 +15,47 @@ def send_accounting_packet(srv, req):
|
|||||||
try:
|
try:
|
||||||
srv.SendPacket(req)
|
srv.SendPacket(req)
|
||||||
except pyrad.client.Timeout:
|
except pyrad.client.Timeout:
|
||||||
print('RADIUS server does not reply')
|
print("RADIUS server does not reply")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except socket.error as error:
|
except socket.error as error:
|
||||||
print('Network error: ' + error[1])
|
print("Network error: " + error[1])
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def main(path_to_dictionary):
|
def main(path_to_dictionary):
|
||||||
srv = Client(server='127.0.0.1',
|
srv = Client(
|
||||||
secret=b'Kah3choteereethiejeimaeziecumi',
|
server="127.0.0.1",
|
||||||
dict=Dictionary(path_to_dictionary))
|
secret=b"Kah3choteereethiejeimaeziecumi",
|
||||||
|
dict=Dictionary(path_to_dictionary),
|
||||||
|
)
|
||||||
|
|
||||||
req = srv.CreateAcctPacket(**{
|
req = srv.CreateAcctPacket(
|
||||||
'User-Name': 'wichert',
|
**{
|
||||||
'NAS-IP-Address': '192.168.1.10',
|
"User-Name": "wichert",
|
||||||
'NAS-Port': 0,
|
"NAS-IP-Address": "192.168.1.10",
|
||||||
'NAS-Identifier': 'trillian',
|
"NAS-Port": 0,
|
||||||
'Called-Station-Id': '00-04-5F-00-0F-D1',
|
"NAS-Identifier": "trillian",
|
||||||
'Calling-Station-Id': '00-01-24-80-B3-9C',
|
"Called-Station-Id": "00-04-5F-00-0F-D1",
|
||||||
'Framed-IP-Address': '10.0.0.100',
|
"Calling-Station-Id": "00-01-24-80-B3-9C",
|
||||||
})
|
"Framed-IP-Address": "10.0.0.100",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
print('Sending accounting start packet')
|
print("Sending accounting start packet")
|
||||||
req['Acct-Status-Type'] = 'Start'
|
req["Acct-Status-Type"] = "Start"
|
||||||
send_accounting_packet(srv, req)
|
send_accounting_packet(srv, req)
|
||||||
|
|
||||||
print('Sending accounting stop packet')
|
print("Sending accounting stop packet")
|
||||||
req['Acct-Status-Type'] = 'Stop'
|
req["Acct-Status-Type"] = "Stop"
|
||||||
req['Acct-Input-Octets'] = random.randrange(2**10, 2**30)
|
req["Acct-Input-Octets"] = random.randrange(2 ** 10, 2 ** 30)
|
||||||
req['Acct-Output-Octets'] = random.randrange(2**10, 2**30)
|
req["Acct-Output-Octets"] = random.randrange(2 ** 10, 2 ** 30)
|
||||||
req['Acct-Session-Time'] = random.randrange(120, 3600)
|
req["Acct-Session-Time"] = random.randrange(120, 3600)
|
||||||
req['Acct-Terminate-Cause'] = random.choice(['User-Request', 'Idle-Timeout'])
|
req["Acct-Terminate-Cause"] = random.choice(
|
||||||
|
["User-Request", "Idle-Timeout"]
|
||||||
|
)
|
||||||
send_accounting_packet(srv, req)
|
send_accounting_packet(srv, req)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary)
|
main(dictionary)
|
||||||
|
|||||||
@@ -11,43 +11,46 @@ from pyrad.dictionary import Dictionary
|
|||||||
|
|
||||||
|
|
||||||
def main(path_to_dictionary):
|
def main(path_to_dictionary):
|
||||||
srv = Client(server='127.0.0.1',
|
srv = Client(
|
||||||
secret=b'Kah3choteereethiejeimaeziecumi',
|
server="127.0.0.1",
|
||||||
dict=Dictionary(path_to_dictionary))
|
secret=b"Kah3choteereethiejeimaeziecumi",
|
||||||
|
dict=Dictionary(path_to_dictionary),
|
||||||
|
)
|
||||||
|
|
||||||
req = srv.CreateAuthPacket(
|
req = srv.CreateAuthPacket(
|
||||||
code=pyrad.packet.AccessRequest,
|
code=pyrad.packet.AccessRequest,
|
||||||
**{
|
**{
|
||||||
'User-Name': 'wichert',
|
"User-Name": "wichert",
|
||||||
'NAS-IP-Address': '192.168.1.10',
|
"NAS-IP-Address": "192.168.1.10",
|
||||||
'NAS-Port': 0,
|
"NAS-Port": 0,
|
||||||
'Service-Type': 'Login-User',
|
"Service-Type": "Login-User",
|
||||||
'NAS-Identifier': 'trillian',
|
"NAS-Identifier": "trillian",
|
||||||
'Called-Station-Id': '00-04-5F-00-0F-D1',
|
"Called-Station-Id": "00-04-5F-00-0F-D1",
|
||||||
'Calling-Station-Id': '00-01-24-80-B3-9C',
|
"Calling-Station-Id": "00-01-24-80-B3-9C",
|
||||||
'Framed-IP-Address': '10.0.0.100',
|
"Framed-IP-Address": "10.0.0.100",
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print('Sending authentication request')
|
print("Sending authentication request")
|
||||||
reply = srv.SendPacket(req)
|
reply = srv.SendPacket(req)
|
||||||
except pyrad.client.Timeout:
|
except pyrad.client.Timeout:
|
||||||
print('RADIUS server does not reply')
|
print("RADIUS server does not reply")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except socket.error as error:
|
except socket.error as error:
|
||||||
print('Network error: ' + error[1])
|
print("Network error: " + error[1])
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if reply.code == pyrad.packet.AccessAccept:
|
if reply.code == pyrad.packet.AccessAccept:
|
||||||
print('Access accepted')
|
print("Access accepted")
|
||||||
else:
|
else:
|
||||||
print('Access denied')
|
print("Access denied")
|
||||||
|
|
||||||
print('Attributes returned by server:')
|
print("Attributes returned by server:")
|
||||||
for key, value in reply.items():
|
for key, value in reply.items():
|
||||||
print(f'{key} {value}')
|
print(f"{key} {value}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary)
|
main(dictionary)
|
||||||
|
|||||||
@@ -10,49 +10,58 @@ from pyrad.client_async import ClientAsync
|
|||||||
from pyrad.dictionary import Dictionary
|
from pyrad.dictionary import Dictionary
|
||||||
from pyrad.packet import AccessAccept
|
from pyrad.packet import AccessAccept
|
||||||
|
|
||||||
logging.basicConfig(level='DEBUG',
|
logging.basicConfig(
|
||||||
format='%(asctime)s [%(levelname)-8s] %(message)s')
|
level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_request(client, user):
|
def create_request(client, user):
|
||||||
return client.CreateAuthPacket(**{
|
return client.CreateAuthPacket(
|
||||||
'User-Name': user,
|
**{
|
||||||
'NAS-IP-Address': '192.168.1.10',
|
"User-Name": user,
|
||||||
'NAS-Port': 0,
|
"NAS-IP-Address": "192.168.1.10",
|
||||||
'Service-Type': 'Login-User',
|
"NAS-Port": 0,
|
||||||
'NAS-Identifier': 'trillian',
|
"Service-Type": "Login-User",
|
||||||
'Called-Station-Id': '00-04-5F-00-0F-D1',
|
"NAS-Identifier": "trillian",
|
||||||
'Calling-Station-Id': '00-01-24-80-B3-9C',
|
"Called-Station-Id": "00-04-5F-00-0F-D1",
|
||||||
'Framed-IP-Address': '10.0.0.100',
|
"Calling-Station-Id": "00-01-24-80-B3-9C",
|
||||||
})
|
"Framed-IP-Address": "10.0.0.100",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_reply(reply):
|
def print_reply(reply):
|
||||||
if reply.code == AccessAccept:
|
if reply.code == AccessAccept:
|
||||||
print('Access accepted')
|
print("Access accepted")
|
||||||
else:
|
else:
|
||||||
print('Access denied')
|
print("Access denied")
|
||||||
|
|
||||||
print('Attributes returned by server:')
|
print("Attributes returned by server:")
|
||||||
for key, value in reply.items():
|
for key, value in reply.items():
|
||||||
print(f'{key}: {value}')
|
print(f"{key}: {value}")
|
||||||
|
|
||||||
|
|
||||||
def initialize_transport(loop, client):
|
def initialize_transport(loop, client):
|
||||||
loop.run_until_complete(
|
loop.run_until_complete(
|
||||||
asyncio.ensure_future(
|
asyncio.ensure_future(
|
||||||
client.initialize_transports(enable_auth=True,
|
client.initialize_transports(
|
||||||
local_addr='127.0.0.1',
|
enable_auth=True,
|
||||||
local_auth_port=8000,
|
local_addr="127.0.0.1",
|
||||||
enable_acct=True,
|
local_auth_port=8000,
|
||||||
enable_coa=True)))
|
enable_acct=True,
|
||||||
|
enable_coa=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(path_to_dictionary):
|
def main(path_to_dictionary):
|
||||||
client = ClientAsync(server='localhost',
|
client = ClientAsync(
|
||||||
secret=b'Kah3choteereethiejeimaeziecumi',
|
server="localhost",
|
||||||
timeout=4,
|
secret=b"Kah3choteereethiejeimaeziecumi",
|
||||||
dict=Dictionary(path_to_dictionary))
|
timeout=4,
|
||||||
|
dict=Dictionary(path_to_dictionary),
|
||||||
|
)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
@@ -62,41 +71,41 @@ def main(path_to_dictionary):
|
|||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
for i in range(255):
|
for i in range(255):
|
||||||
req = create_request(client, f'user{i}')
|
req = create_request(client, f"user{i}")
|
||||||
future = client.SendPacket(req)
|
future = client.SendPacket(req)
|
||||||
requests.append(future)
|
requests.append(future)
|
||||||
|
|
||||||
# Send auth requests asynchronously to the server
|
# Send auth requests asynchronously to the server
|
||||||
loop.run_until_complete(asyncio.ensure_future(
|
loop.run_until_complete(
|
||||||
asyncio.gather(
|
asyncio.ensure_future(
|
||||||
*requests,
|
asyncio.gather(*requests, return_exceptions=True)
|
||||||
return_exceptions=True
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
))
|
|
||||||
|
|
||||||
for future in requests:
|
for future in requests:
|
||||||
if future.exception():
|
if future.exception():
|
||||||
print('EXCEPTION ', future.exception())
|
print("EXCEPTION ", future.exception())
|
||||||
else:
|
else:
|
||||||
reply = future.result()
|
reply = future.result()
|
||||||
print_reply(reply)
|
print_reply(reply)
|
||||||
|
|
||||||
# Close transports
|
# Close transports
|
||||||
loop.run_until_complete(asyncio.ensure_future(
|
loop.run_until_complete(
|
||||||
client.deinitialize_transports()))
|
asyncio.ensure_future(client.deinitialize_transports())
|
||||||
print('END')
|
)
|
||||||
|
print("END")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('Error: ', exc)
|
print("Error: ", exc)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# Close transports
|
# Close transports
|
||||||
loop.run_until_complete(asyncio.ensure_future(
|
loop.run_until_complete(
|
||||||
client.deinitialize_transports()))
|
asyncio.ensure_future(client.deinitialize_transports())
|
||||||
|
)
|
||||||
|
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary)
|
main(dictionary)
|
||||||
|
|||||||
@@ -14,21 +14,21 @@ from pyrad.server import Server, RemoteHost
|
|||||||
|
|
||||||
|
|
||||||
def print_attributes(packet):
|
def print_attributes(packet):
|
||||||
print('Attributes')
|
print("Attributes")
|
||||||
for key, value in packet.items():
|
for key, value in packet.items():
|
||||||
print(f'{key}: {value}')
|
print(f"{key}: {value}")
|
||||||
|
|
||||||
|
|
||||||
class FakeCoA(Server):
|
class FakeCoA(Server):
|
||||||
def HandleCoaPacket(self, packet):
|
def HandleCoaPacket(self, packet):
|
||||||
'''Accounting packet handler.
|
"""Accounting packet handler.
|
||||||
Function that is called when a valid
|
Function that is called when a valid
|
||||||
accounting packet has been received.
|
accounting packet has been received.
|
||||||
|
|
||||||
:param packet: packet to process
|
:param packet: packet to process
|
||||||
:type packet: Packet class instance
|
:type packet: Packet class instance
|
||||||
'''
|
"""
|
||||||
print('Received a coa request %d' % packet.code)
|
print("Received a coa request %d" % packet.code)
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
@@ -38,7 +38,7 @@ class FakeCoA(Server):
|
|||||||
self.SendReplyPacket(packet.fd, reply)
|
self.SendReplyPacket(packet.fd, reply)
|
||||||
|
|
||||||
def HandleDisconnectPacket(self, packet):
|
def HandleDisconnectPacket(self, packet):
|
||||||
print('Received a disconnect request %d' % packet.code)
|
print("Received a disconnect request %d" % packet.code)
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
@@ -52,27 +52,27 @@ def main(path_to_dictionary, coa_port):
|
|||||||
# create server/coa only and read dictionary
|
# create server/coa only and read dictionary
|
||||||
# bind and listen only on 127.0.0.1:argv[1]
|
# bind and listen only on 127.0.0.1:argv[1]
|
||||||
coa = FakeCoA(
|
coa = FakeCoA(
|
||||||
addresses=['127.0.0.1'],
|
addresses=["127.0.0.1"],
|
||||||
dict=Dictionary(path_to_dictionary),
|
dict=Dictionary(path_to_dictionary),
|
||||||
coaport=coa_port,
|
coaport=coa_port,
|
||||||
auth_enabled=False,
|
auth_enabled=False,
|
||||||
acct_enabled=False,
|
acct_enabled=False,
|
||||||
coa_enabled=True)
|
coa_enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
# add peers (address, secret, name)
|
# add peers (address, secret, name)
|
||||||
coa.hosts['127.0.0.1'] = RemoteHost(
|
coa.hosts["127.0.0.1"] = RemoteHost(
|
||||||
'127.0.0.1',
|
"127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost"
|
||||||
b'Kah3choteereethiejeimaeziecumi',
|
)
|
||||||
'localhost')
|
|
||||||
|
|
||||||
# start
|
# start
|
||||||
coa.Run()
|
coa.Run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) != 2:
|
||||||
print('usage: client-coa.py {portnumber}')
|
print("usage: client-coa.py {portnumber}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary, int(sys.argv[1]))
|
main(dictionary, int(sys.argv[1]))
|
||||||
|
|||||||
@@ -11,27 +11,29 @@ from pyrad.dictionary import Dictionary
|
|||||||
|
|
||||||
def main(path_to_dictionary, coa_type, nas_identifier):
|
def main(path_to_dictionary, coa_type, nas_identifier):
|
||||||
# create coa client
|
# create coa client
|
||||||
client = Client(server='127.0.0.1',
|
client = Client(
|
||||||
secret=b'Kah3choteereethiejeimaeziecumi',
|
server="127.0.0.1",
|
||||||
dict=Dictionary(path_to_dictionary))
|
secret=b"Kah3choteereethiejeimaeziecumi",
|
||||||
|
dict=Dictionary(path_to_dictionary),
|
||||||
|
)
|
||||||
|
|
||||||
# set coa timeout
|
# set coa timeout
|
||||||
client.timeout = 30
|
client.timeout = 30
|
||||||
|
|
||||||
# create coa request packet
|
# create coa request packet
|
||||||
attributes = {
|
attributes = {
|
||||||
'Acct-Session-Id': '1337',
|
"Acct-Session-Id": "1337",
|
||||||
'NAS-Identifier': nas_identifier,
|
"NAS-Identifier": nas_identifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
if coa_type == 'coa':
|
if coa_type == "coa":
|
||||||
# create coa request
|
# create coa request
|
||||||
request = client.CreateCoAPacket(**attributes)
|
request = client.CreateCoAPacket(**attributes)
|
||||||
elif coa_type == 'dis':
|
elif coa_type == "dis":
|
||||||
# create disconnect request
|
# create disconnect request
|
||||||
request = client.CreateCoAPacket(
|
request = client.CreateCoAPacket(
|
||||||
code=pyrad.packet.DisconnectRequest,
|
code=pyrad.packet.DisconnectRequest, **attributes
|
||||||
**attributes)
|
)
|
||||||
else:
|
else:
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@@ -41,11 +43,11 @@ def main(path_to_dictionary, coa_type, nas_identifier):
|
|||||||
print(result.code)
|
print(result.code)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) != 3:
|
if len(sys.argv) != 3:
|
||||||
print('usage: coa.py {coa|dis} daemon-1234')
|
print("usage: coa.py {coa|dis} daemon-1234")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
|
|
||||||
main(dictionary, sys.argv[1], sys.argv[2])
|
main(dictionary, sys.argv[1], sys.argv[2])
|
||||||
|
|||||||
@@ -8,46 +8,52 @@ import pyrad.packet
|
|||||||
from pyrad import server
|
from pyrad import server
|
||||||
from pyrad.dictionary import Dictionary
|
from pyrad.dictionary import Dictionary
|
||||||
|
|
||||||
logging.basicConfig(filename='pyrad.log', level='DEBUG',
|
logging.basicConfig(
|
||||||
format='%(asctime)s [%(levelname)-8s] %(message)s')
|
filename="pyrad.log",
|
||||||
|
level="DEBUG",
|
||||||
|
format="%(asctime)s [%(levelname)-8s] %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_attributes(packet):
|
def print_attributes(packet):
|
||||||
print('Attributes')
|
print("Attributes")
|
||||||
for key, value in packet.items():
|
for key, value in packet.items():
|
||||||
print(f'{key}: {value}')
|
print(f"{key}: {value}")
|
||||||
|
|
||||||
|
|
||||||
class FakeServer(server.Server):
|
class FakeServer(server.Server):
|
||||||
def HandleAuthPacket(self, packet):
|
def HandleAuthPacket(self, packet):
|
||||||
print('Received an authentication request')
|
print("Received an authentication request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet, **{
|
reply = self.CreateReplyPacket(
|
||||||
'Service-Type': 'Framed-User',
|
packet,
|
||||||
'Framed-IP-Address': '192.168.0.1',
|
**{
|
||||||
'Framed-IPv6-Prefix': 'fc66::/64'
|
"Service-Type": "Framed-User",
|
||||||
})
|
"Framed-IP-Address": "192.168.0.1",
|
||||||
|
"Framed-IPv6-Prefix": "fc66::/64",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
reply.code = pyrad.packet.AccessAccept
|
reply.code = pyrad.packet.AccessAccept
|
||||||
self.SendReplyPacket(packet.fd, reply)
|
self.SendReplyPacket(packet.fd, reply)
|
||||||
|
|
||||||
def HandleAcctPacket(self, packet):
|
def HandleAcctPacket(self, packet):
|
||||||
print('Received an accounting request')
|
print("Received an accounting request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
self.SendReplyPacket(packet.fd, reply)
|
self.SendReplyPacket(packet.fd, reply)
|
||||||
|
|
||||||
def HandleCoaPacket(self, packet):
|
def HandleCoaPacket(self, packet):
|
||||||
print('Received an coa request')
|
print("Received an coa request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
self.SendReplyPacket(packet.fd, reply)
|
self.SendReplyPacket(packet.fd, reply)
|
||||||
|
|
||||||
def HandleDisconnectPacket(self, packet):
|
def HandleDisconnectPacket(self, packet):
|
||||||
print('Received an disconnect request')
|
print("Received an disconnect request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
@@ -58,20 +64,18 @@ class FakeServer(server.Server):
|
|||||||
|
|
||||||
def main(path_to_dictionary):
|
def main(path_to_dictionary):
|
||||||
# create server and read dictionary
|
# create server and read dictionary
|
||||||
srv = FakeServer(dict=Dictionary(path_to_dictionary),
|
srv = FakeServer(dict=Dictionary(path_to_dictionary), coa_enabled=True)
|
||||||
coa_enabled=True)
|
|
||||||
|
|
||||||
# add clients (address, secret, name)
|
# add clients (address, secret, name)
|
||||||
srv.hosts['127.0.0.1'] = server.RemoteHost(
|
srv.hosts["127.0.0.1"] = server.RemoteHost(
|
||||||
'127.0.0.1',
|
"127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost"
|
||||||
b'Kah3choteereethiejeimaeziecumi',
|
)
|
||||||
'localhost')
|
srv.BindToAddress("0.0.0.0")
|
||||||
srv.BindToAddress('0.0.0.0')
|
|
||||||
|
|
||||||
# start server
|
# start server
|
||||||
srv.Run()
|
srv.Run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary)
|
main(dictionary)
|
||||||
|
|||||||
@@ -12,57 +12,67 @@ from pyrad.server import RemoteHost
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import uvloop
|
import uvloop
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
logging.basicConfig(level='DEBUG',
|
logging.basicConfig(
|
||||||
format='%(asctime)s [%(levelname)-8s] %(message)s')
|
level="DEBUG", format="%(asctime)s [%(levelname)-8s] %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_attributes(packet):
|
def print_attributes(packet):
|
||||||
print('Attributes returned by server:')
|
print("Attributes returned by server:")
|
||||||
for key, value in packet.items():
|
for key, value in packet.items():
|
||||||
print(f'{key}: {value}')
|
print(f"{key}: {value}")
|
||||||
|
|
||||||
|
|
||||||
class FakeServer(ServerAsync):
|
class FakeServer(ServerAsync):
|
||||||
def __init__(self, loop, dictionary):
|
def __init__(self, loop, dictionary):
|
||||||
|
|
||||||
ServerAsync.__init__(self, loop=loop, dictionary=dictionary,
|
ServerAsync.__init__(
|
||||||
enable_pkt_verify=True, debug=True)
|
self,
|
||||||
|
loop=loop,
|
||||||
|
dictionary=dictionary,
|
||||||
|
enable_pkt_verify=True,
|
||||||
|
debug=True,
|
||||||
|
)
|
||||||
|
|
||||||
def handle_auth_packet(self, protocol, packet, addr):
|
def handle_auth_packet(self, protocol, packet, addr):
|
||||||
print('Received an authentication request with id ', packet.id)
|
print("Received an authentication request with id ", packet.id)
|
||||||
print('Authenticator ', packet.authenticator.hex())
|
print("Authenticator ", packet.authenticator.hex())
|
||||||
print('Secret ', packet.secret)
|
print("Secret ", packet.secret)
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet, **{
|
reply = self.CreateReplyPacket(
|
||||||
'Service-Type': 'Framed-User',
|
packet,
|
||||||
'Framed-IP-Address': '192.168.0.1',
|
**{
|
||||||
'Framed-IPv6-Prefix': 'fc66::/64'
|
"Service-Type": "Framed-User",
|
||||||
})
|
"Framed-IP-Address": "192.168.0.1",
|
||||||
|
"Framed-IPv6-Prefix": "fc66::/64",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
reply.code = AccessAccept
|
reply.code = AccessAccept
|
||||||
protocol.send_response(reply, addr)
|
protocol.send_response(reply, addr)
|
||||||
|
|
||||||
def handle_acct_packet(self, protocol, packet, addr):
|
def handle_acct_packet(self, protocol, packet, addr):
|
||||||
print('Received an accounting request')
|
print("Received an accounting request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
protocol.send_response(reply, addr)
|
protocol.send_response(reply, addr)
|
||||||
|
|
||||||
def handle_coa_packet(self, protocol, packet, addr):
|
def handle_coa_packet(self, protocol, packet, addr):
|
||||||
print('Received an coa request')
|
print("Received an coa request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
protocol.send_response(reply, addr)
|
protocol.send_response(reply, addr)
|
||||||
|
|
||||||
def handle_disconnect_packet(self, protocol, packet, addr):
|
def handle_disconnect_packet(self, protocol, packet, addr):
|
||||||
print('Received an disconnect request')
|
print("Received an disconnect request")
|
||||||
print_attributes(packet)
|
print_attributes(packet)
|
||||||
|
|
||||||
reply = self.CreateReplyPacket(packet)
|
reply = self.CreateReplyPacket(packet)
|
||||||
@@ -77,17 +87,19 @@ def main(path_to_dictionary):
|
|||||||
server = FakeServer(loop=loop, dictionary=Dictionary(path_to_dictionary))
|
server = FakeServer(loop=loop, dictionary=Dictionary(path_to_dictionary))
|
||||||
|
|
||||||
# add clients (address, secret, name)
|
# add clients (address, secret, name)
|
||||||
server.hosts['127.0.0.1'] = RemoteHost('127.0.0.1',
|
server.hosts["127.0.0.1"] = RemoteHost(
|
||||||
b'Kah3choteereethiejeimaeziecumi',
|
"127.0.0.1", b"Kah3choteereethiejeimaeziecumi", "localhost"
|
||||||
'localhost')
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize transports
|
# Initialize transports
|
||||||
loop.run_until_complete(
|
loop.run_until_complete(
|
||||||
asyncio.ensure_future(
|
asyncio.ensure_future(
|
||||||
server.initialize_transports(enable_auth=True,
|
server.initialize_transports(
|
||||||
enable_acct=True,
|
enable_auth=True, enable_acct=True, enable_coa=True
|
||||||
enable_coa=True)))
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# start server
|
# start server
|
||||||
@@ -96,20 +108,22 @@ def main(path_to_dictionary):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Close transports
|
# Close transports
|
||||||
loop.run_until_complete(asyncio.ensure_future(
|
loop.run_until_complete(
|
||||||
server.deinitialize_transports()))
|
asyncio.ensure_future(server.deinitialize_transports())
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
print('Error: ', exc)
|
print("Error: ", exc)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
# Close transports
|
# Close transports
|
||||||
loop.run_until_complete(asyncio.ensure_future(
|
loop.run_until_complete(
|
||||||
server.deinitialize_transports()))
|
asyncio.ensure_future(server.deinitialize_transports())
|
||||||
|
)
|
||||||
|
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary)
|
main(dictionary)
|
||||||
|
|||||||
@@ -11,32 +11,33 @@ from pyrad.dictionary import Dictionary
|
|||||||
|
|
||||||
|
|
||||||
def main(path_to_dictionary):
|
def main(path_to_dictionary):
|
||||||
srv = Client(server='localhost',
|
srv = Client(
|
||||||
authport=18121,
|
server="localhost",
|
||||||
secret=b'test',
|
authport=18121,
|
||||||
dict=Dictionary(path_to_dictionary))
|
secret=b"test",
|
||||||
|
dict=Dictionary(path_to_dictionary),
|
||||||
|
)
|
||||||
|
|
||||||
req = srv.CreateAuthPacket(
|
req = srv.CreateAuthPacket(
|
||||||
code=pyrad.packet.StatusServer,
|
code=pyrad.packet.StatusServer, FreeRADIUS_Statistics_Type="All",
|
||||||
FreeRADIUS_Statistics_Type='All',
|
|
||||||
)
|
)
|
||||||
req.add_message_authenticator()
|
req.add_message_authenticator()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print('Sending FreeRADIUS status request')
|
print("Sending FreeRADIUS status request")
|
||||||
reply = srv.SendPacket(req)
|
reply = srv.SendPacket(req)
|
||||||
except pyrad.client.Timeout:
|
except pyrad.client.Timeout:
|
||||||
print('RADIUS server does not reply')
|
print("RADIUS server does not reply")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
except socket.error as error:
|
except socket.error as error:
|
||||||
print('Network error: ' + error[1])
|
print("Network error: " + error[1])
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
print('Attributes returned by server:')
|
print("Attributes returned by server:")
|
||||||
for key, value in reply.items():
|
for key, value in reply.items():
|
||||||
print(f'{key}: {value}')
|
print(f"{key}: {value}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
dictionary = path.join(path.dirname(path.abspath(__file__)), 'dictionary')
|
dictionary = path.join(path.dirname(path.abspath(__file__)), "dictionary")
|
||||||
main(dictionary)
|
main(dictionary)
|
||||||
|
|||||||
@@ -21,13 +21,11 @@ classifiers = [
|
|||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: System :: Systems Administration :: Authentication/Directory"
|
"Topic :: System :: Systems Administration :: Authentication/Directory"
|
||||||
|
|
||||||
]
|
]
|
||||||
packages = [
|
packages = [
|
||||||
{ include = "pyrad3"},
|
{ include = "src/pyrad3"},
|
||||||
]
|
]
|
||||||
include = [
|
include = [
|
||||||
"CHANGELOG.md",
|
|
||||||
"LICENSE-APACHE",
|
"LICENSE-APACHE",
|
||||||
"LICENSE-MIT",
|
"LICENSE-MIT",
|
||||||
"README.md",
|
"README.md",
|
||||||
@@ -40,9 +38,16 @@ repository = "https://github.com/pyradius/pyrad3"
|
|||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "^5.4"
|
pytest = "^5.4"
|
||||||
pytest-cov = "^2.5"
|
pytest-black = "^0.30"
|
||||||
|
pytest-cov = "^2.10"
|
||||||
|
pytest-mypy = "^0.6"
|
||||||
|
pytest-pylint = "^0.17"
|
||||||
|
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 100
|
line-length = 80
|
||||||
include = '\.py'
|
include = '\.py'
|
||||||
|
|
||||||
|
|
||||||
|
[tool.pylint.messages_control]
|
||||||
|
disable = "bad-continuation"
|
||||||
|
|||||||
221
pyrad3/client.py
221
pyrad3/client.py
@@ -1,221 +0,0 @@
|
|||||||
# client.py
|
|
||||||
#
|
|
||||||
# Copyright 2002-2007 Wichert Akkerman <wichert@wiggy.net>
|
|
||||||
|
|
||||||
__docformat__ = "epytext en"
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import select
|
|
||||||
import socket
|
|
||||||
import time
|
|
||||||
import struct
|
|
||||||
from pyrad import host
|
|
||||||
from pyrad import packet
|
|
||||||
|
|
||||||
EAP_CODE_REQUEST = 1
|
|
||||||
EAP_CODE_RESPONSE = 2
|
|
||||||
EAP_TYPE_IDENTITY = 1
|
|
||||||
|
|
||||||
|
|
||||||
class Timeout(Exception):
|
|
||||||
"""Simple exception class which is raised when a timeout occurs
|
|
||||||
while waiting for a RADIUS server to respond."""
|
|
||||||
|
|
||||||
|
|
||||||
class Client(host.Host):
|
|
||||||
"""Basic RADIUS client.
|
|
||||||
This class implements a basic RADIUS client. It can send requests
|
|
||||||
to a RADIUS server, taking care of timeouts and retries, and
|
|
||||||
validate its replies.
|
|
||||||
|
|
||||||
:ivar retries: number of times to retry sending a RADIUS request
|
|
||||||
:type retries: integer
|
|
||||||
:ivar timeout: number of seconds to wait for an answer
|
|
||||||
:type timeout: float
|
|
||||||
"""
|
|
||||||
def __init__(self, server, authport=1812, acctport=1813,
|
|
||||||
coaport=3799, secret=b'', dict=None, retries=3,
|
|
||||||
timeout=5):
|
|
||||||
|
|
||||||
"""Constructor.
|
|
||||||
|
|
||||||
:param server: hostname or IP address of RADIUS server
|
|
||||||
:type server: string
|
|
||||||
:param authport: port to use for authentication packets
|
|
||||||
:type authport: integer
|
|
||||||
:param acctport: port to use for accounting packets
|
|
||||||
:type acctport: integer
|
|
||||||
:param coaport: port to use for CoA packets
|
|
||||||
:type coaport: integer
|
|
||||||
:param secret: RADIUS secret
|
|
||||||
:type secret: string
|
|
||||||
:param dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary
|
|
||||||
"""
|
|
||||||
host.Host.__init__(self, authport, acctport, coaport, dict)
|
|
||||||
|
|
||||||
self.server = server
|
|
||||||
self.secret = secret
|
|
||||||
self._socket = None
|
|
||||||
self.retries = retries
|
|
||||||
self.timeout = timeout
|
|
||||||
self._poll = select.poll()
|
|
||||||
|
|
||||||
def bind(self, addr):
|
|
||||||
"""Bind socket to an address.
|
|
||||||
Binding the socket used for communicating to an address can be
|
|
||||||
usefull when working on a machine with multiple addresses.
|
|
||||||
|
|
||||||
:param addr: network address (hostname or IP) and port to bind to
|
|
||||||
:type addr: host,port tuple
|
|
||||||
"""
|
|
||||||
self._close_socket()
|
|
||||||
self._socket_open()
|
|
||||||
self._socket.bind(addr)
|
|
||||||
|
|
||||||
def _socket_open(self):
|
|
||||||
try:
|
|
||||||
family = socket.getaddrinfo(self.server, 'www')[0][0]
|
|
||||||
except:
|
|
||||||
family = socket.AF_INET
|
|
||||||
if not self._socket:
|
|
||||||
self._socket = socket.socket(family,
|
|
||||||
socket.SOCK_DGRAM)
|
|
||||||
self._socket.setsockopt(socket.SOL_SOCKET,
|
|
||||||
socket.SO_REUSEADDR, 1)
|
|
||||||
self._poll.register(self._socket, select.POLLIN)
|
|
||||||
|
|
||||||
def _close_socket(self):
|
|
||||||
if self._socket:
|
|
||||||
self._poll.unregister(self._socket)
|
|
||||||
self._socket.close()
|
|
||||||
self._socket = None
|
|
||||||
|
|
||||||
def CreateAuthPacket(self, **args):
|
|
||||||
"""Create a new RADIUS packet.
|
|
||||||
This utility function creates a new RADIUS packet which can
|
|
||||||
be used to communicate with the RADIUS server this client
|
|
||||||
talks to. This is initializing the new packet with the
|
|
||||||
dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.AuthPacket
|
|
||||||
"""
|
|
||||||
return host.Host.CreateAuthPacket(self, secret=self.secret, **args)
|
|
||||||
|
|
||||||
def CreateAcctPacket(self, **args):
|
|
||||||
"""Create a new RADIUS packet.
|
|
||||||
This utility function creates a new RADIUS packet which can
|
|
||||||
be used to communicate with the RADIUS server this client
|
|
||||||
talks to. This is initializing the new packet with the
|
|
||||||
dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.Packet
|
|
||||||
"""
|
|
||||||
return host.Host.CreateAcctPacket(self, secret=self.secret, **args)
|
|
||||||
|
|
||||||
def CreateCoAPacket(self, **args):
|
|
||||||
"""Create a new RADIUS packet.
|
|
||||||
This utility function creates a new RADIUS packet which can
|
|
||||||
be used to communicate with the RADIUS server this client
|
|
||||||
talks to. This is initializing the new packet with the
|
|
||||||
dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.Packet
|
|
||||||
"""
|
|
||||||
return host.Host.CreateCoAPacket(self, secret=self.secret, **args)
|
|
||||||
|
|
||||||
def _send_packet(self, pkt, port):
|
|
||||||
"""Send a packet to a RADIUS server.
|
|
||||||
|
|
||||||
:param pkt: the packet to send
|
|
||||||
:type pkt: pyrad.packet.Packet
|
|
||||||
:param port: UDP port to send packet to
|
|
||||||
:type port: integer
|
|
||||||
:return: the reply packet received
|
|
||||||
:rtype: pyrad.packet.Packet
|
|
||||||
:raise Timeout: RADIUS server does not reply
|
|
||||||
"""
|
|
||||||
self._socket_open()
|
|
||||||
|
|
||||||
for attempt in range(self.retries):
|
|
||||||
if attempt and pkt.code == packet.AccountingRequest:
|
|
||||||
if "Acct-Delay-Time" in pkt:
|
|
||||||
pkt["Acct-Delay-Time"] = \
|
|
||||||
pkt["Acct-Delay-Time"][0] + self.timeout
|
|
||||||
else:
|
|
||||||
pkt["Acct-Delay-Time"] = self.timeout
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
waitto = now + self.timeout
|
|
||||||
|
|
||||||
self._socket.sendto(pkt.RequestPacket(), (self.server, port))
|
|
||||||
|
|
||||||
while now < waitto:
|
|
||||||
ready = self._poll.poll((waitto - now) * 1000)
|
|
||||||
|
|
||||||
if ready:
|
|
||||||
rawreply = self._socket.recv(4096)
|
|
||||||
else:
|
|
||||||
now = time.time()
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
return pkt.VerifyReply(rawreply)
|
|
||||||
except packet.PacketError:
|
|
||||||
# TODO: report or error out maybe?
|
|
||||||
pass
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
raise Timeout
|
|
||||||
|
|
||||||
def SendPacket(self, pkt):
|
|
||||||
"""Send a packet to a RADIUS server.
|
|
||||||
|
|
||||||
:param pkt: the packet to send
|
|
||||||
:type pkt: pyrad.packet.Packet
|
|
||||||
:return: the reply packet received
|
|
||||||
:rtype: pyrad.packet.Packet
|
|
||||||
:raise Timeout: RADIUS server does not reply
|
|
||||||
"""
|
|
||||||
if isinstance(pkt, packet.AuthPacket):
|
|
||||||
if pkt.auth_type == 'eap-md5':
|
|
||||||
# Creating EAP-Identity
|
|
||||||
password = pkt[2][0] if 2 in pkt else pkt[1][0]
|
|
||||||
pkt[79] = [struct.pack('!BBHB%ds' % len(password),
|
|
||||||
EAP_CODE_RESPONSE,
|
|
||||||
packet.CurrentID,
|
|
||||||
len(password) + 5,
|
|
||||||
EAP_TYPE_IDENTITY,
|
|
||||||
password)]
|
|
||||||
reply = self._send_packet(pkt, self.authport)
|
|
||||||
if (reply
|
|
||||||
and reply.code == packet.AccessChallenge
|
|
||||||
and pkt.auth_type == 'eap-md5'
|
|
||||||
):
|
|
||||||
# Got an Access-Challenge
|
|
||||||
_eap_code, eap_id, _eap_size, _eap_type, eap_md5 = struct.unpack(
|
|
||||||
'!BBHB%ds' % (len(reply[79][0]) - 5), reply[79][0]
|
|
||||||
)
|
|
||||||
# Sending back an EAP-Type-MD5-Challenge
|
|
||||||
# Thank god for http://www.secdev.org/python/eapy.py
|
|
||||||
client_pw = pkt[2][0] if 2 in pkt else pkt[1][0]
|
|
||||||
md5_challenge = hashlib.md5(
|
|
||||||
struct.pack('!B', eap_id) + client_pw + eap_md5[1:]
|
|
||||||
).digest()
|
|
||||||
pkt[79] = [
|
|
||||||
struct.pack('!BBHBB', 2, eap_id, len(md5_challenge) + 6,
|
|
||||||
4, len(md5_challenge)) + md5_challenge
|
|
||||||
]
|
|
||||||
# Copy over Challenge-State
|
|
||||||
pkt[24] = reply[24]
|
|
||||||
reply = self._send_packet(pkt, self.authport)
|
|
||||||
return reply
|
|
||||||
|
|
||||||
if isinstance(pkt, packet.CoAPacket):
|
|
||||||
return self._send_packet(pkt, self.coaport)
|
|
||||||
|
|
||||||
return self._send_packet(pkt, self.acctport)
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
# curved.py
|
|
||||||
#
|
|
||||||
# Copyright 2002 Wichert Akkerman <wichert@wiggy.net>
|
|
||||||
|
|
||||||
"""Twisted integration code
|
|
||||||
"""
|
|
||||||
|
|
||||||
__docformat__ = 'epytext en'
|
|
||||||
|
|
||||||
from twisted.internet import protocol
|
|
||||||
from twisted.internet import reactor
|
|
||||||
from twisted.python import log
|
|
||||||
import sys
|
|
||||||
from pyrad import dictionary
|
|
||||||
from pyrad import host
|
|
||||||
from pyrad import packet
|
|
||||||
|
|
||||||
|
|
||||||
class PacketError(Exception):
|
|
||||||
"""Exception class for bogus packets
|
|
||||||
|
|
||||||
PacketError exceptions are only used inside the Server class to
|
|
||||||
abort processing of a packet.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class RADIUS(host.Host, protocol.DatagramProtocol):
|
|
||||||
def __init__(self, hosts={}, dict=dictionary.Dictionary()):
|
|
||||||
host.Host.__init__(self, dict=dict)
|
|
||||||
self.hosts = hosts
|
|
||||||
|
|
||||||
def processPacket(self, pkt):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def createPacket(self, **kwargs):
|
|
||||||
raise NotImplementedError('Attempted to use a pure base class')
|
|
||||||
|
|
||||||
def datagramReceived(self, datagram, source):
|
|
||||||
host, port = source
|
|
||||||
try:
|
|
||||||
pkt = self.CreatePacket(packet=datagram)
|
|
||||||
except packet.PacketError as err:
|
|
||||||
log.msg('Dropping invalid packet: ' + str(err))
|
|
||||||
return
|
|
||||||
|
|
||||||
if host not in self.hosts:
|
|
||||||
log.msg('Dropping packet from unknown host ' + host)
|
|
||||||
return
|
|
||||||
|
|
||||||
pkt.source = (host, port)
|
|
||||||
try:
|
|
||||||
self.processPacket(pkt)
|
|
||||||
except PacketError as err:
|
|
||||||
log.msg('Dropping packet from %s: %s' % (host, str(err)))
|
|
||||||
|
|
||||||
|
|
||||||
class RADIUSAccess(RADIUS):
|
|
||||||
def createPacket(self, **kwargs):
|
|
||||||
self.CreateAuthPacket(**kwargs)
|
|
||||||
|
|
||||||
def processPacket(self, pkt):
|
|
||||||
if pkt.code != packet.AccessRequest:
|
|
||||||
raise PacketError(
|
|
||||||
'non-AccessRequest packet on authentication socket')
|
|
||||||
|
|
||||||
|
|
||||||
class RADIUSAccounting(RADIUS):
|
|
||||||
def createPacket(self, **kwargs):
|
|
||||||
self.CreateAcctPacket(**kwargs)
|
|
||||||
|
|
||||||
def processPacket(self, pkt):
|
|
||||||
if pkt.code != packet.AccountingRequest:
|
|
||||||
raise PacketError(
|
|
||||||
'non-AccountingRequest packet on authentication socket')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
log.startLogging(sys.stdout, 0)
|
|
||||||
reactor.listenUDP(1812, RADIUSAccess())
|
|
||||||
reactor.listenUDP(1813, RADIUSAccounting())
|
|
||||||
reactor.run()
|
|
||||||
@@ -1,116 +0,0 @@
|
|||||||
# dictfile.py
|
|
||||||
#
|
|
||||||
# Copyright 2009 Kristoffer Gronlund <kristoffer.gronlund@purplescout.se>
|
|
||||||
|
|
||||||
""" Dictionary File
|
|
||||||
|
|
||||||
Implements an iterable file format that handles the
|
|
||||||
RADIUS $INCLUDE directives behind the scene.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
|
|
||||||
class _Node():
|
|
||||||
"""Dictionary file node
|
|
||||||
|
|
||||||
A single dictionary file.
|
|
||||||
"""
|
|
||||||
__slots__ = ('name', 'lines', 'current', 'length', 'dir')
|
|
||||||
|
|
||||||
def __init__(self, fd, name, parentdir):
|
|
||||||
self.lines = fd.readlines()
|
|
||||||
self.length = len(self.lines)
|
|
||||||
self.current = 0
|
|
||||||
self.name = os.path.basename(name)
|
|
||||||
path = os.path.dirname(name)
|
|
||||||
if os.path.isabs(path):
|
|
||||||
self.dir = path
|
|
||||||
else:
|
|
||||||
self.dir = os.path.join(parentdir, path)
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
if self.current >= self.length:
|
|
||||||
return None
|
|
||||||
self.current += 1
|
|
||||||
return self.lines[self.current - 1]
|
|
||||||
|
|
||||||
|
|
||||||
class DictFile():
|
|
||||||
"""Dictionary file class
|
|
||||||
|
|
||||||
An iterable file type that handles $INCLUDE
|
|
||||||
directives internally.
|
|
||||||
"""
|
|
||||||
__slots__ = ['stack']
|
|
||||||
|
|
||||||
def __init__(self, fil):
|
|
||||||
"""
|
|
||||||
@param fil: a dictionary file to parse
|
|
||||||
@type fil: string or file
|
|
||||||
"""
|
|
||||||
self.stack = []
|
|
||||||
self.__read_node(fil)
|
|
||||||
|
|
||||||
def __read_node(self, fil):
|
|
||||||
node = None
|
|
||||||
parentdir = self.__cur_dir()
|
|
||||||
if isinstance(fil, str):
|
|
||||||
fname = None
|
|
||||||
if os.path.isabs(fil):
|
|
||||||
fname = fil
|
|
||||||
else:
|
|
||||||
fname = os.path.join(parentdir, fil)
|
|
||||||
fd = open(fname, "rt")
|
|
||||||
node = _Node(fd, fil, parentdir)
|
|
||||||
fd.close()
|
|
||||||
else:
|
|
||||||
node = _Node(fil, '', parentdir)
|
|
||||||
self.stack.append(node)
|
|
||||||
|
|
||||||
def __cur_dir(self):
|
|
||||||
if self.stack:
|
|
||||||
return self.stack[-1].dir
|
|
||||||
return os.path.realpath(os.curdir)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __get_include(line):
|
|
||||||
line = line.split("#", 1)[0].strip()
|
|
||||||
tokens = line.split()
|
|
||||||
if tokens and tokens[0].upper() == '$INCLUDE':
|
|
||||||
return " ".join(tokens[1:])
|
|
||||||
return None
|
|
||||||
|
|
||||||
def line(self):
|
|
||||||
"""Returns line number of current file
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self.stack[-1].current
|
|
||||||
except (AttributeError, IndexError):
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def file(self):
|
|
||||||
"""Returns name of current file
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return self.stack[-1].name
|
|
||||||
except (AttributeError, IndexError):
|
|
||||||
return ''
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __next__(self):
|
|
||||||
while self.stack:
|
|
||||||
line = self.stack[-1].next()
|
|
||||||
if line is None:
|
|
||||||
self.stack.pop()
|
|
||||||
else:
|
|
||||||
inc = DictFile.__get_include(line)
|
|
||||||
if inc:
|
|
||||||
self.__read_node(inc)
|
|
||||||
else:
|
|
||||||
return line
|
|
||||||
raise StopIteration
|
|
||||||
|
|
||||||
next = __next__ # BBB for python <3
|
|
||||||
@@ -1,393 +0,0 @@
|
|||||||
# dictionary.py
|
|
||||||
#
|
|
||||||
# Copyright 2002,2005,2007,2016 Wichert Akkerman <wichert@wiggy.net>
|
|
||||||
"""
|
|
||||||
RADIUS uses dictionaries to define the attributes that can
|
|
||||||
be used in packets. The Dictionary class stores the attribute
|
|
||||||
definitions from one or more dictionary files.
|
|
||||||
|
|
||||||
Dictionary files are textfiles with one command per line.
|
|
||||||
Comments are specified by starting with a # character, and empty
|
|
||||||
lines are ignored.
|
|
||||||
|
|
||||||
The commands supported are::
|
|
||||||
|
|
||||||
ATTRIBUTE <attribute> <code> <type> [<vendor>]
|
|
||||||
specify an attribute and its type
|
|
||||||
|
|
||||||
VALUE <attribute> <valuename> <value>
|
|
||||||
specify a value attribute
|
|
||||||
|
|
||||||
VENDOR <name> <id>
|
|
||||||
specify a vendor ID
|
|
||||||
|
|
||||||
BEGIN-VENDOR <vendorname>
|
|
||||||
begin definition of vendor attributes
|
|
||||||
|
|
||||||
END-VENDOR <vendorname>
|
|
||||||
end definition of vendor attributes
|
|
||||||
|
|
||||||
|
|
||||||
The datatypes currently supported are:
|
|
||||||
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| type | description |
|
|
||||||
+===============+==============================================+
|
|
||||||
| string | ASCII string |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| ipaddr | IPv4 address |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| date | 32 bits UNIX |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| octets | arbitrary binary data |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| abinary | ascend binary data |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| ipv6addr | 16 octets in network byte order |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| ipv6prefix | 18 octets in network byte order |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| integer | 32 bits unsigned number |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| signed | 32 bits signed number |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| short | 16 bits unsigned number |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| byte | 8 bits unsigned number |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| tlv | Nested tag-length-value |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| integer64 | 64 bits unsigned number |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
|
|
||||||
These datatypes are parsed but not supported:
|
|
||||||
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| type | description |
|
|
||||||
+===============+==============================================+
|
|
||||||
| ifid | 8 octets in network byte order |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
| ether | 6 octets of hh:hh:hh:hh:hh:hh |
|
|
||||||
| | where 'h' is hex digits, upper or lowercase. |
|
|
||||||
+---------------+----------------------------------------------+
|
|
||||||
"""
|
|
||||||
from copy import copy
|
|
||||||
|
|
||||||
from pyrad import bidict
|
|
||||||
from pyrad import tools
|
|
||||||
from pyrad import dictfile
|
|
||||||
|
|
||||||
__docformat__ = 'epytext en'
|
|
||||||
|
|
||||||
|
|
||||||
DATATYPES = frozenset(['string', 'ipaddr', 'integer', 'date', 'octets',
|
|
||||||
'abinary', 'ipv6addr', 'ipv6prefix', 'short', 'byte',
|
|
||||||
'signed', 'ifid', 'ether', 'tlv', 'integer64'])
|
|
||||||
|
|
||||||
|
|
||||||
class ParseError(Exception):
|
|
||||||
"""Dictionary parser exceptions.
|
|
||||||
|
|
||||||
:ivar msg: Error message
|
|
||||||
:type msg: string
|
|
||||||
:ivar linenumber: Line number on which the error occurred
|
|
||||||
:type linenumber: integer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, msg=None, **data):
|
|
||||||
super().__init__()
|
|
||||||
self.msg = msg
|
|
||||||
self.file = data.get('file', '')
|
|
||||||
self.line = data.get('line', -1)
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
line = f'({self.line})' if self.line > -1 else ''
|
|
||||||
return f'{self.file}{line}: ParseError: {self.msg}'
|
|
||||||
|
|
||||||
|
|
||||||
class Attribute():
|
|
||||||
def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', values=None,
|
|
||||||
encrypt=0, has_tag=False):
|
|
||||||
if datatype not in DATATYPES:
|
|
||||||
raise ValueError('Invalid data type')
|
|
||||||
self.name = name
|
|
||||||
self.code = code
|
|
||||||
self.type = datatype
|
|
||||||
self.vendor = vendor
|
|
||||||
self.encrypt = encrypt
|
|
||||||
self.has_tag = has_tag
|
|
||||||
self.values = bidict.BiDict()
|
|
||||||
self.sub_attributes = {}
|
|
||||||
self.parent = None
|
|
||||||
self.is_sub_attribute = is_sub_attribute
|
|
||||||
if values:
|
|
||||||
for (key, value) in values.items():
|
|
||||||
self.values.add(key, value)
|
|
||||||
|
|
||||||
|
|
||||||
class Dictionary():
|
|
||||||
"""RADIUS dictionary class.
|
|
||||||
This class stores all information about vendors, attributes and their
|
|
||||||
values as defined in RADIUS dictionary files.
|
|
||||||
|
|
||||||
:ivar vendors: bidict mapping vendor name to vendor code
|
|
||||||
:type vendors: bidict
|
|
||||||
:ivar attrindex: bidict mapping
|
|
||||||
:type attrindex: bidict
|
|
||||||
:ivar attributes: bidict mapping attribute name to attribute class
|
|
||||||
:type attributes: bidict
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dict=None, *dicts):
|
|
||||||
"""
|
|
||||||
:param dict: path of dictionary file or file-like object to read
|
|
||||||
:type dict: string or file
|
|
||||||
:param dicts: list of dictionaries
|
|
||||||
:type dicts: sequence of strings or files
|
|
||||||
"""
|
|
||||||
self.vendors = bidict.BiDict()
|
|
||||||
self.vendors.add('', 0)
|
|
||||||
self.attrindex = bidict.BiDict()
|
|
||||||
self.attributes = {}
|
|
||||||
self.defer_parse = []
|
|
||||||
|
|
||||||
if dict:
|
|
||||||
self.read_dictionary(dict)
|
|
||||||
|
|
||||||
for i in dicts:
|
|
||||||
self.read_dictionary(i)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.attributes)
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.attributes[key]
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
|
||||||
return key in self.attributes
|
|
||||||
|
|
||||||
has_key = __contains__
|
|
||||||
|
|
||||||
def __parse_attribute(self, state, tokens):
|
|
||||||
if not len(tokens) in [4, 5]:
|
|
||||||
raise ParseError(
|
|
||||||
'Incorrect number of tokens for attribute definition',
|
|
||||||
name=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
vendor = state['vendor']
|
|
||||||
has_tag = False
|
|
||||||
encrypt = 0
|
|
||||||
if len(tokens) >= 5:
|
|
||||||
def keyval(o):
|
|
||||||
kv = o.split('=')
|
|
||||||
if len(kv) == 2:
|
|
||||||
return (kv[0], kv[1])
|
|
||||||
else:
|
|
||||||
return (kv[0], None)
|
|
||||||
options = [keyval(o) for o in tokens[4].split(',')]
|
|
||||||
|
|
||||||
for (key, val) in options:
|
|
||||||
if key == 'has_tag':
|
|
||||||
has_tag = True
|
|
||||||
elif key == 'encrypt':
|
|
||||||
if val not in ['1', '2', '3']:
|
|
||||||
raise ParseError(
|
|
||||||
f'Illegal attribute encryption: {val}',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
encrypt = int(val)
|
|
||||||
|
|
||||||
if (not has_tag) and encrypt == 0:
|
|
||||||
vendor = tokens[4]
|
|
||||||
if not self.vendors.has_forward(vendor):
|
|
||||||
if vendor == "concat":
|
|
||||||
# ignore attributes with concat (freeradius compat.)
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
raise ParseError('Unknown vendor ' + vendor,
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
(attribute, code, datatype) = tokens[1:4]
|
|
||||||
|
|
||||||
codes = code.split('.')
|
|
||||||
|
|
||||||
# Codes can be sent as hex, or octal or decimal string representations.
|
|
||||||
tmp = []
|
|
||||||
for c in codes:
|
|
||||||
if c.startswith('0x'):
|
|
||||||
tmp.append(int(c, 16))
|
|
||||||
elif c.startswith('0o'):
|
|
||||||
tmp.append(int(c, 8))
|
|
||||||
else:
|
|
||||||
tmp.append(int(c, 10))
|
|
||||||
codes = tmp
|
|
||||||
|
|
||||||
is_sub_attribute = (len(codes) > 1)
|
|
||||||
if len(codes) == 2:
|
|
||||||
code = int(codes[1])
|
|
||||||
parent_code = int(codes[0])
|
|
||||||
elif len(codes) == 1:
|
|
||||||
code = int(codes[0])
|
|
||||||
parent_code = None
|
|
||||||
else:
|
|
||||||
raise ParseError('nested tlvs are not supported')
|
|
||||||
|
|
||||||
datatype = datatype.split("[")[0]
|
|
||||||
|
|
||||||
if datatype not in DATATYPES:
|
|
||||||
raise ParseError('Illegal type: ' + datatype,
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
if vendor:
|
|
||||||
if is_sub_attribute:
|
|
||||||
key = (self.vendors.get_forward(vendor), parent_code, code)
|
|
||||||
else:
|
|
||||||
key = (self.vendors.get_forward(vendor), code)
|
|
||||||
else:
|
|
||||||
if is_sub_attribute:
|
|
||||||
key = (parent_code, code)
|
|
||||||
else:
|
|
||||||
key = code
|
|
||||||
|
|
||||||
self.attrindex.add(attribute, key)
|
|
||||||
self.attributes[attribute] = Attribute(attribute, code, datatype, is_sub_attribute,
|
|
||||||
vendor, encrypt=encrypt, has_tag=has_tag)
|
|
||||||
if datatype == 'tlv':
|
|
||||||
# save attribute in tlvs
|
|
||||||
state['tlvs'][code] = self.attributes[attribute]
|
|
||||||
if is_sub_attribute:
|
|
||||||
# save sub attribute in parent tlv and update their parent field
|
|
||||||
state['tlvs'][parent_code].sub_attributes[code] = attribute
|
|
||||||
self.attributes[attribute].parent = state['tlvs'][parent_code]
|
|
||||||
|
|
||||||
def __parse_value(self, state, tokens, defer):
|
|
||||||
if len(tokens) != 4:
|
|
||||||
raise ParseError('Incorrect number of tokens for value definition',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
(attr, key, value) = tokens[1:]
|
|
||||||
|
|
||||||
try:
|
|
||||||
adef = self.attributes[attr]
|
|
||||||
except KeyError:
|
|
||||||
if defer:
|
|
||||||
self.defer_parse.append((copy(state), copy(tokens)))
|
|
||||||
return
|
|
||||||
raise ParseError('Value defined for unknown attribute ' + attr,
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
if adef.type in ['integer', 'signed', 'short', 'byte', 'integer64']:
|
|
||||||
value = int(value, 0)
|
|
||||||
value = tools.EncodeAttr(adef.type, value)
|
|
||||||
self.attributes[attr].values.add(key, value)
|
|
||||||
|
|
||||||
def __parse_vendor(self, state, tokens):
|
|
||||||
if len(tokens) not in [3, 4]:
|
|
||||||
raise ParseError(
|
|
||||||
'Incorrect number of tokens for vendor definition',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
# Parse format specification, but do
|
|
||||||
# nothing about it for now
|
|
||||||
if len(tokens) == 4:
|
|
||||||
fmt = tokens[3].split('=')
|
|
||||||
if fmt[0] != 'format':
|
|
||||||
raise ParseError(
|
|
||||||
f"Unknown option '{fmt[0]}' for vendor definition",
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
try:
|
|
||||||
(t, l) = tuple(int(a) for a in fmt[1].split(','))
|
|
||||||
if t not in [1, 2, 4] or l not in [0, 1, 2]:
|
|
||||||
raise ParseError(
|
|
||||||
f'Unknown vendor format specification {fmt[1]}',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
except ValueError:
|
|
||||||
raise ParseError(
|
|
||||||
'Syntax error in vendor specification',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
(vendorname, vendor) = tokens[1:3]
|
|
||||||
self.vendors.add(vendorname, int(vendor, 0))
|
|
||||||
|
|
||||||
def __parse_begin_vendor(self, state, tokens):
|
|
||||||
if len(tokens) != 2:
|
|
||||||
raise ParseError(
|
|
||||||
'Incorrect number of tokens for begin-vendor statement',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
vendor = tokens[1]
|
|
||||||
|
|
||||||
if not self.vendors.has_forward(vendor):
|
|
||||||
raise ParseError(
|
|
||||||
f'Unknown vendor {vendor} in begin-vendor statement',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
state['vendor'] = vendor
|
|
||||||
|
|
||||||
def __parse_end_vendor(self, state, tokens):
|
|
||||||
if len(tokens) != 2:
|
|
||||||
raise ParseError(
|
|
||||||
'Incorrect number of tokens for end-vendor statement',
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
|
|
||||||
vendor = tokens[1]
|
|
||||||
|
|
||||||
if state['vendor'] != vendor:
|
|
||||||
raise ParseError(
|
|
||||||
'Ending non-open vendor' + vendor,
|
|
||||||
file=state['file'],
|
|
||||||
line=state['line'])
|
|
||||||
state['vendor'] = ''
|
|
||||||
|
|
||||||
def read_dictionary(self, file):
|
|
||||||
"""Parse a dictionary file.
|
|
||||||
Reads a RADIUS dictionary file and merges its contents into the
|
|
||||||
class instance.
|
|
||||||
|
|
||||||
:param file: Name of dictionary file to parse or a file-like object
|
|
||||||
:type file: string or file-like object
|
|
||||||
"""
|
|
||||||
|
|
||||||
fil = dictfile.DictFile(file)
|
|
||||||
|
|
||||||
state = {}
|
|
||||||
state['vendor'] = ''
|
|
||||||
state['tlvs'] = {}
|
|
||||||
self.defer_parse = []
|
|
||||||
for line in fil:
|
|
||||||
state['file'] = fil.file()
|
|
||||||
state['line'] = fil.line()
|
|
||||||
line = line.split('#', 1)[0].strip()
|
|
||||||
|
|
||||||
tokens = line.split()
|
|
||||||
if not tokens:
|
|
||||||
continue
|
|
||||||
|
|
||||||
key = tokens[0].upper()
|
|
||||||
if key == 'ATTRIBUTE':
|
|
||||||
self.__parse_attribute(state, tokens)
|
|
||||||
elif key == 'VALUE':
|
|
||||||
self.__parse_value(state, tokens, True)
|
|
||||||
elif key == 'VENDOR':
|
|
||||||
self.__parse_vendor(state, tokens)
|
|
||||||
elif key == 'BEGIN-VENDOR':
|
|
||||||
self.__parse_begin_vendor(state, tokens)
|
|
||||||
elif key == 'END-VENDOR':
|
|
||||||
self.__parse_end_vendor(state, tokens)
|
|
||||||
|
|
||||||
for state, tokens in self.defer_parse:
|
|
||||||
key = tokens[0].upper()
|
|
||||||
if key == 'VALUE':
|
|
||||||
self.__parse_value(state, tokens, False)
|
|
||||||
self.defer_parse = []
|
|
||||||
100
pyrad3/host.py
100
pyrad3/host.py
@@ -1,100 +0,0 @@
|
|||||||
# host.py
|
|
||||||
#
|
|
||||||
# Copyright 2003,2007 Wichert Akkerman <wichert@wiggy.net>
|
|
||||||
from pyrad import packet
|
|
||||||
|
|
||||||
|
|
||||||
class Host(object):
|
|
||||||
"""Generic RADIUS capable host.
|
|
||||||
|
|
||||||
:ivar dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary
|
|
||||||
:ivar authport: port to listen on for authentication packets
|
|
||||||
:type authport: integer
|
|
||||||
:ivar acctport: port to listen on for accounting packets
|
|
||||||
:type acctport: integer
|
|
||||||
"""
|
|
||||||
def __init__(self, authport=1812, acctport=1813, coaport=3799, dict=None):
|
|
||||||
"""Constructor
|
|
||||||
|
|
||||||
:param authport: port to listen on for authentication packets
|
|
||||||
:type authport: integer
|
|
||||||
:param acctport: port to listen on for accounting packets
|
|
||||||
:type acctport: integer
|
|
||||||
:param coaport: port to listen on for CoA packets
|
|
||||||
:type coaport: integer
|
|
||||||
:param dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary
|
|
||||||
"""
|
|
||||||
self.dict = dict
|
|
||||||
self.authport = authport
|
|
||||||
self.acctport = acctport
|
|
||||||
self.coaport = coaport
|
|
||||||
|
|
||||||
def CreatePacket(self, **args):
|
|
||||||
"""Create a new RADIUS packet.
|
|
||||||
This utility function creates a new RADIUS authentication
|
|
||||||
packet which can be used to communicate with the RADIUS server
|
|
||||||
this client talks to. This is initializing the new packet with
|
|
||||||
the dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.Packet
|
|
||||||
"""
|
|
||||||
return packet.Packet(dict=self.dict, **args)
|
|
||||||
|
|
||||||
def CreateAuthPacket(self, **args):
|
|
||||||
"""Create a new authentication RADIUS packet.
|
|
||||||
This utility function creates a new RADIUS authentication
|
|
||||||
packet which can be used to communicate with the RADIUS server
|
|
||||||
this client talks to. This is initializing the new packet with
|
|
||||||
the dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.AuthPacket
|
|
||||||
"""
|
|
||||||
return packet.AuthPacket(dict=self.dict, **args)
|
|
||||||
|
|
||||||
def CreateAcctPacket(self, **args):
|
|
||||||
"""Create a new accounting RADIUS packet.
|
|
||||||
This utility function creates a new accouting RADIUS packet
|
|
||||||
which can be used to communicate with the RADIUS server this
|
|
||||||
client talks to. This is initializing the new packet with the
|
|
||||||
dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.AcctPacket
|
|
||||||
"""
|
|
||||||
return packet.AcctPacket(dict=self.dict, **args)
|
|
||||||
|
|
||||||
def CreateCoAPacket(self, **args):
|
|
||||||
"""Create a new CoA RADIUS packet.
|
|
||||||
This utility function creates a new CoA RADIUS packet
|
|
||||||
which can be used to communicate with the RADIUS server this
|
|
||||||
client talks to. This is initializing the new packet with the
|
|
||||||
dictionary and secret used for the client.
|
|
||||||
|
|
||||||
:return: a new empty packet instance
|
|
||||||
:rtype: pyrad.packet.CoAPacket
|
|
||||||
"""
|
|
||||||
return packet.CoAPacket(dict=self.dict, **args)
|
|
||||||
|
|
||||||
def SendPacket(self, fd, pkt):
|
|
||||||
"""Send a packet.
|
|
||||||
|
|
||||||
:param fd: socket to send packet with
|
|
||||||
:type fd: socket class instance
|
|
||||||
:param pkt: packet to send
|
|
||||||
:type pkt: Packet class instance
|
|
||||||
"""
|
|
||||||
fd.sendto(pkt.Packet(), pkt.source)
|
|
||||||
|
|
||||||
def SendReplyPacket(self, fd, pkt):
|
|
||||||
"""Send a packet.
|
|
||||||
|
|
||||||
:param fd: socket to send packet with
|
|
||||||
:type fd: socket class instance
|
|
||||||
:param pkt: packet to send
|
|
||||||
:type pkt: Packet class instance
|
|
||||||
"""
|
|
||||||
fd.sendto(pkt.ReplyPacket(), pkt.source)
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
# Copyright 2020 Istvan Ruzman
|
|
||||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
|
||||||
from pyrad3 import new_packet as packet
|
|
||||||
|
|
||||||
class Host:
|
|
||||||
def __init__(self, secret, radius_dict,
|
|
||||||
authport=1812, acctport=1813, coaport=3799,
|
|
||||||
timeout=30, retries=3):
|
|
||||||
self.secret = secret
|
|
||||||
self.dictionary = radius_dict
|
|
||||||
|
|
||||||
self.authport = authport
|
|
||||||
self.acctport = acctport
|
|
||||||
self.coaport = coaport
|
|
||||||
|
|
||||||
self.timeout = timeout
|
|
||||||
self.retries = retries
|
|
||||||
|
|
||||||
def create_packet(self, **kwargs):
|
|
||||||
return packet.Packet(self, **kwargs)
|
|
||||||
|
|
||||||
def create_auth_packet(self, **kwargs):
|
|
||||||
return packet.AuthPacket(self, **kwargs)
|
|
||||||
|
|
||||||
def create_acct_packet(self, **kwargs):
|
|
||||||
return packet.AcctPacket(self, **kwargs)
|
|
||||||
|
|
||||||
def create_coa_packet(self, **kwargs):
|
|
||||||
return packet.CoAPacket(self, **kwargs)
|
|
||||||
@@ -1,143 +0,0 @@
|
|||||||
# Copyright 2020 Istvan Ruzman
|
|
||||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import secrets
|
|
||||||
import struct
|
|
||||||
|
|
||||||
# Packet codes
|
|
||||||
AccessRequest = 1
|
|
||||||
AccessAccept = 2
|
|
||||||
AccessReject = 3
|
|
||||||
AccountingRequest = 4
|
|
||||||
AccountingResponse = 5
|
|
||||||
AccessChallenge = 11
|
|
||||||
StatusServer = 12
|
|
||||||
StatusClient = 13
|
|
||||||
DisconnectRequest = 40
|
|
||||||
DisconnectACK = 41
|
|
||||||
DisconnectNAK = 42
|
|
||||||
CoARequest = 43
|
|
||||||
CoAACK = 44
|
|
||||||
CoANAK = 45
|
|
||||||
|
|
||||||
|
|
||||||
class PacketError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class AuthError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
RANDOM_GENERATOR = secrets.SystemRandom()
|
|
||||||
|
|
||||||
|
|
||||||
class Packet(OrderedDict):
|
|
||||||
def __init__(self, host, code, radius_id, *, request=None, **attributes):
|
|
||||||
super().__init__(**attributes)
|
|
||||||
self.code = code
|
|
||||||
self.id = radius_id
|
|
||||||
self.host = host
|
|
||||||
self.request = request
|
|
||||||
|
|
||||||
def send(self):
|
|
||||||
self.host._send_packet(self)
|
|
||||||
|
|
||||||
def verify_reply(self, raw_reply):
|
|
||||||
if self.id != raw_reply[1]:
|
|
||||||
raise PacketError("Response has a wrong id")
|
|
||||||
|
|
||||||
radius_hash = self.calculate_radius_hash(raw_reply)
|
|
||||||
|
|
||||||
if radius_hash != raw_reply[4:20]:
|
|
||||||
raise PacketError("Reply Packet has a wrong authenticator")
|
|
||||||
|
|
||||||
return self.parse_raw_packet(raw_reply)
|
|
||||||
|
|
||||||
def calculate_radius_hash(self, data, authenticator=None):
|
|
||||||
if authenticator is None:
|
|
||||||
authenticator = self.authenticator
|
|
||||||
return hashlib.md5(data[0:4] + authenticator +
|
|
||||||
data[20:] + self.secret).digest()
|
|
||||||
|
|
||||||
def parse_raw_reply(self, raw_packet):
|
|
||||||
(code, radius_id, length, authenticator) = self._parse_header(raw_packet)
|
|
||||||
attrs = self.parse_attributes(raw_packet)
|
|
||||||
return Packet(self.host, code, radius_id, **attrs, request=self)
|
|
||||||
|
|
||||||
def _parse_header(self, raw_packet):
|
|
||||||
try:
|
|
||||||
header = struct.unpack('!BBH16s', raw_packet)
|
|
||||||
except struct.error:
|
|
||||||
raise PacketError('Packet header is corrupt')
|
|
||||||
|
|
||||||
length = header[3]
|
|
||||||
if len(raw_packet) != length:
|
|
||||||
raise PacketError('Packet has invalid length')
|
|
||||||
if length > 4096:
|
|
||||||
raise PacketError(f'Packet length is too big ({length})')
|
|
||||||
return header
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacket(Packet):
|
|
||||||
def __init__(self, host, radius_id, auth_type, *,
|
|
||||||
code=AccessRequest, request=None,
|
|
||||||
**attributes):
|
|
||||||
super().__init__(host, code, radius_id,
|
|
||||||
request=request, **attributes)
|
|
||||||
self.auth_type = auth_type
|
|
||||||
if code == AccessRequest:
|
|
||||||
self.authenticator = secrets.token_bytes(16)
|
|
||||||
|
|
||||||
def create_accept(self, **attributes):
|
|
||||||
return AuthPacket(self.host, self.id, self.auth_type,
|
|
||||||
request=self,
|
|
||||||
code=AccessAccept
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
def create_reject(self, **attributes):
|
|
||||||
return AuthPacket(self.host, self.id, self.auth_type,
|
|
||||||
request=self,
|
|
||||||
code=AccessReject,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
def create_challange(self, **attributes):
|
|
||||||
return AuthPacket(self.host, self.id, self.auth_type,
|
|
||||||
request=self,
|
|
||||||
code=AccessChallenge,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
|
|
||||||
class AcctPacket(Packet):
|
|
||||||
def __init__(self, host, radius_id, *,
|
|
||||||
code=AccountingRequest, request=None,
|
|
||||||
**attributes):
|
|
||||||
super().__init__(host, code, radius_id,
|
|
||||||
request=request, **attributes)
|
|
||||||
|
|
||||||
def create_response(self, **attributes):
|
|
||||||
return AcctPacket(self.host, self.id,
|
|
||||||
code=AccountingResponse,
|
|
||||||
request=self,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
|
|
||||||
class CoAPacket(Packet):
|
|
||||||
def __init__(self, host, radius_id, *,
|
|
||||||
code=CoARequest, request=None,
|
|
||||||
**attributes):
|
|
||||||
super().__init__(host, code, radius_id,
|
|
||||||
request=request, **attributes)
|
|
||||||
|
|
||||||
def create_ack(self, **attributes):
|
|
||||||
return CoAPacket(self.host, self.id,
|
|
||||||
code=CoAACK,
|
|
||||||
request=self,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
def create_nack(self, **attributes):
|
|
||||||
return CoAPacket(self.host, self.id,
|
|
||||||
code=CoANACK,
|
|
||||||
request=self,
|
|
||||||
**attributes)
|
|
||||||
869
pyrad3/packet.py
869
pyrad3/packet.py
@@ -1,869 +0,0 @@
|
|||||||
# packet.py
|
|
||||||
#
|
|
||||||
# Copyright 2002-2005,2007 Wichert Akkerman <wichert@wiggy.net>
|
|
||||||
#
|
|
||||||
# A RADIUS packet as defined in RFC 2138
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import secrets
|
|
||||||
import struct
|
|
||||||
from pyrad import tools
|
|
||||||
|
|
||||||
md5_constructor = hashlib.md5
|
|
||||||
|
|
||||||
# Packet codes
|
|
||||||
AccessRequest = 1
|
|
||||||
AccessAccept = 2
|
|
||||||
AccessReject = 3
|
|
||||||
AccountingRequest = 4
|
|
||||||
AccountingResponse = 5
|
|
||||||
AccessChallenge = 11
|
|
||||||
StatusServer = 12
|
|
||||||
StatusClient = 13
|
|
||||||
DisconnectRequest = 40
|
|
||||||
DisconnectACK = 41
|
|
||||||
DisconnectNAK = 42
|
|
||||||
CoARequest = 43
|
|
||||||
CoAACK = 44
|
|
||||||
CoANAK = 45
|
|
||||||
|
|
||||||
# Use cryptographic-safe random generator as provided by the OS.
|
|
||||||
random_generator = secrets.SystemRandom()
|
|
||||||
|
|
||||||
# Current ID
|
|
||||||
CurrentID = random_generator.randrange(1, 255)
|
|
||||||
|
|
||||||
|
|
||||||
class PacketError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Packet(OrderedDict):
|
|
||||||
"""Packet acts like a standard python map to provide simple access
|
|
||||||
to the RADIUS attributes. Since RADIUS allows for repeated
|
|
||||||
attributes the value will always be a sequence. pyrad makes sure
|
|
||||||
to preserve the ordering when encoding and decoding packets.
|
|
||||||
|
|
||||||
There are two ways to use the map intereface: if attribute
|
|
||||||
names are used pyrad take care of en-/decoding data. If
|
|
||||||
the attribute type number (or a vendor ID/attribute type
|
|
||||||
tuple for vendor attributes) is used you work with the
|
|
||||||
raw data.
|
|
||||||
|
|
||||||
Normally you will not use this class directly, but one of the
|
|
||||||
:obj:`AuthPacket` or :obj:`AcctPacket` classes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, code=0, id=None, secret=b'', authenticator=None,
|
|
||||||
**attributes):
|
|
||||||
"""Constructor
|
|
||||||
|
|
||||||
:param dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary class
|
|
||||||
:param secret: secret needed to communicate with a RADIUS server
|
|
||||||
:type secret: string
|
|
||||||
:param id: packet identification number
|
|
||||||
:type id: integer (8 bits)
|
|
||||||
:param code: packet type code
|
|
||||||
:type code: integer (8bits)
|
|
||||||
:param packet: raw packet to decode
|
|
||||||
:type packet: string
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.code = code
|
|
||||||
if id is not None:
|
|
||||||
self.id = id
|
|
||||||
else:
|
|
||||||
self.id = CreateID()
|
|
||||||
if not isinstance(secret, bytes):
|
|
||||||
raise TypeError('secret must be a binary string')
|
|
||||||
self.secret = secret
|
|
||||||
if authenticator is not None and \
|
|
||||||
not isinstance(authenticator, bytes):
|
|
||||||
raise TypeError('authenticator must be a binary string')
|
|
||||||
self.authenticator = authenticator
|
|
||||||
self.message_authenticator = None
|
|
||||||
|
|
||||||
if 'dict' in attributes:
|
|
||||||
self.dict = attributes['dict']
|
|
||||||
|
|
||||||
if 'packet' in attributes:
|
|
||||||
self.raw_packet = attributes['packet']
|
|
||||||
self.DecodePacket(attributes['packet'])
|
|
||||||
|
|
||||||
if 'message_authenticator' in attributes:
|
|
||||||
self.message_authenticator = attributes['message_authenticator']
|
|
||||||
|
|
||||||
for (key, value) in attributes.items():
|
|
||||||
if key in [
|
|
||||||
'dict', 'fd', 'packet',
|
|
||||||
'message_authenticator',
|
|
||||||
]:
|
|
||||||
continue
|
|
||||||
key = key.replace('_', '-')
|
|
||||||
self.AddAttribute(key, value)
|
|
||||||
|
|
||||||
def add_message_authenticator(self):
|
|
||||||
self.message_authenticator = True
|
|
||||||
# Maintain a zero octets content for md5 and hmac calculation.
|
|
||||||
self['Message-Authenticator'] = 16 * b'\00'
|
|
||||||
|
|
||||||
if self.id is None:
|
|
||||||
self.id = self.CreateID()
|
|
||||||
|
|
||||||
if self.authenticator is None and self.code == AccessRequest:
|
|
||||||
self.authenticator = self.CreateAuthenticator()
|
|
||||||
self._refresh_message_authenticator()
|
|
||||||
|
|
||||||
def get_message_authenticator(self):
|
|
||||||
self._refresh_message_authenticator()
|
|
||||||
return self.message_authenticator
|
|
||||||
|
|
||||||
def _refresh_message_authenticator(self):
|
|
||||||
hmac_constructor = hmac.new(self.secret, digestmod=hashlib.md5)
|
|
||||||
|
|
||||||
# Maintain a zero octets content for md5 and hmac calculation.
|
|
||||||
self['Message-Authenticator'] = 16 * b'\00'
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
|
|
||||||
header = struct.pack('!BBH', self.code, self.id,
|
|
||||||
(20 + len(attr)))
|
|
||||||
|
|
||||||
hmac_constructor.update(header[0:4])
|
|
||||||
if self.code in (AccountingRequest, DisconnectRequest,
|
|
||||||
CoARequest, AccountingResponse):
|
|
||||||
hmac_constructor.update(16 * b'\00')
|
|
||||||
else:
|
|
||||||
# NOTE: self.authenticator on reply packet is initialized
|
|
||||||
# with request authenticator by design.
|
|
||||||
# For AccessAccept, AccessReject and AccessChallenge
|
|
||||||
# it is needed use original Authenticator.
|
|
||||||
# For AccessAccept, AccessReject and AccessChallenge
|
|
||||||
# it is needed use original Authenticator.
|
|
||||||
if self.authenticator is None:
|
|
||||||
raise Exception('No authenticator found')
|
|
||||||
hmac_constructor.update(self.authenticator)
|
|
||||||
|
|
||||||
hmac_constructor.update(attr)
|
|
||||||
self['Message-Authenticator'] = hmac_constructor.digest()
|
|
||||||
|
|
||||||
def verify_message_authenticator(self,
|
|
||||||
original_authenticator=None,
|
|
||||||
original_code=None):
|
|
||||||
"""Verify packet Message-Authenticator.
|
|
||||||
|
|
||||||
:return: False if verification failed else True
|
|
||||||
:rtype: boolean
|
|
||||||
"""
|
|
||||||
if self.message_authenticator is None:
|
|
||||||
raise Exception('No Message-Authenticator AVP present')
|
|
||||||
|
|
||||||
prev_ma = self['Message-Authenticator']
|
|
||||||
|
|
||||||
self['Message-Authenticator'] = 16 * b'\00'
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
|
|
||||||
header = struct.pack('!BBH', self.code, self.id,
|
|
||||||
(20 + len(attr)))
|
|
||||||
|
|
||||||
hmac_constructor = hmac.new(self.secret, digestmod=hashlib.md5)
|
|
||||||
hmac_constructor.update(header)
|
|
||||||
if self.code in (AccountingRequest, DisconnectRequest,
|
|
||||||
CoARequest, AccountingResponse):
|
|
||||||
if original_code is None or original_code != StatusServer:
|
|
||||||
# TODO: Handle Status-Server response correctly.
|
|
||||||
hmac_constructor.update(16 * b'\00')
|
|
||||||
elif self.code in (AccessAccept, AccessChallenge,
|
|
||||||
AccessReject):
|
|
||||||
if original_authenticator is None:
|
|
||||||
if self.authenticator:
|
|
||||||
# NOTE: self.authenticator on reply packet is initialized
|
|
||||||
# with request authenticator by design.
|
|
||||||
original_authenticator = self.authenticator
|
|
||||||
else:
|
|
||||||
raise Exception('Missing original authenticator')
|
|
||||||
|
|
||||||
hmac_constructor.update(original_authenticator)
|
|
||||||
else:
|
|
||||||
# On Access-Request and Status-Server use dynamic authenticator
|
|
||||||
hmac_constructor.update(self.authenticator)
|
|
||||||
|
|
||||||
hmac_constructor.update(attr)
|
|
||||||
self['Message-Authenticator'] = prev_ma[0]
|
|
||||||
return prev_ma[0] == hmac_constructor.digest()
|
|
||||||
|
|
||||||
def CreateReply(self, **attributes):
|
|
||||||
"""Create a new packet as a reply to this one. This method
|
|
||||||
makes sure the authenticator and secret are copied over
|
|
||||||
to the new instance.
|
|
||||||
"""
|
|
||||||
return Packet(id=self.id, secret=self.secret,
|
|
||||||
authenticator=self.authenticator, dict=self.dict,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
def _decode_value(self, attr, value):
|
|
||||||
try:
|
|
||||||
return attr.values.get_backward(value)
|
|
||||||
except KeyError:
|
|
||||||
return tools.DecodeAttr(attr.type, value)
|
|
||||||
|
|
||||||
def _encode_value(self, attr, value):
|
|
||||||
try:
|
|
||||||
result = attr.values.get_forward(value)
|
|
||||||
except KeyError:
|
|
||||||
result = tools.EncodeAttr(attr.type, value)
|
|
||||||
|
|
||||||
if attr.encrypt == 2:
|
|
||||||
# salt encrypt attribute
|
|
||||||
result = self.SaltCrypt(result)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _encode_key_values(self, key, values):
|
|
||||||
if not isinstance(key, str):
|
|
||||||
return (key, values)
|
|
||||||
|
|
||||||
if not isinstance(values, (list, tuple)):
|
|
||||||
values = [values]
|
|
||||||
|
|
||||||
key, _, tag = key.partition(":")
|
|
||||||
attr = self.dict.attributes[key]
|
|
||||||
key = self._encode_key(key)
|
|
||||||
if attr.has_tag:
|
|
||||||
tag = '0' if tag == '' else tag
|
|
||||||
tag = struct.pack('B', int(tag))
|
|
||||||
if attr.type == "integer":
|
|
||||||
# When a tagged value has the type int only 3 bytes are used
|
|
||||||
# the first byte is the tag itself, so we need to shorten our int
|
|
||||||
return (key, [tag + self._encode_value(attr, v)[1:] for v in values])
|
|
||||||
else:
|
|
||||||
return (key, [tag + self._encode_value(attr, v) for v in values])
|
|
||||||
else:
|
|
||||||
return (key, [self._encode_value(attr, v) for v in values])
|
|
||||||
|
|
||||||
def _encode_key(self, key):
|
|
||||||
if not isinstance(key, str):
|
|
||||||
return key
|
|
||||||
|
|
||||||
attr = self.dict.attributes[key]
|
|
||||||
# sub attribute keys don't need vendor
|
|
||||||
if attr.vendor and not attr.is_sub_attribute:
|
|
||||||
return (self.dict.vendors.get_forward(attr.vendor), attr.code)
|
|
||||||
else:
|
|
||||||
return attr.code
|
|
||||||
|
|
||||||
def _decode_key(self, key):
|
|
||||||
"""Turn a key into a string if possible"""
|
|
||||||
try:
|
|
||||||
return self.dict.attrindex.get_backward(key)
|
|
||||||
except KeyError:
|
|
||||||
return key
|
|
||||||
|
|
||||||
def AddAttribute(self, key, value):
|
|
||||||
"""Add an attribute to the packet.
|
|
||||||
|
|
||||||
:param key: attribute name or identification
|
|
||||||
:type key: string, attribute code or (vendor code, attribute code)
|
|
||||||
tuple
|
|
||||||
:param value: value
|
|
||||||
:type value: depends on type of attribute
|
|
||||||
"""
|
|
||||||
attr = self.dict.attributes[key.partition(':')[0]]
|
|
||||||
|
|
||||||
(key, value) = self._encode_key_values(key, value)
|
|
||||||
|
|
||||||
if attr.is_sub_attribute:
|
|
||||||
tlv = self.setdefault(self._encode_key(attr.parent.name), {})
|
|
||||||
encoded = tlv.setdefault(key, [])
|
|
||||||
else:
|
|
||||||
encoded = self.setdefault(key, [])
|
|
||||||
|
|
||||||
encoded.extend(value)
|
|
||||||
|
|
||||||
def get(self, key, failobj=None):
|
|
||||||
try:
|
|
||||||
res = self.__getitem__(key)
|
|
||||||
except KeyError:
|
|
||||||
res = failobj
|
|
||||||
return res
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
if not isinstance(key, str):
|
|
||||||
return super().__getitem__(key)
|
|
||||||
|
|
||||||
values = super().__getitem__(self._encode_key(key))
|
|
||||||
attr = self.dict.attributes[key]
|
|
||||||
if attr.type == 'tlv': # return map from sub attribute code to its values
|
|
||||||
res = {}
|
|
||||||
for (sub_attr_key, sub_attr_val) in values.items():
|
|
||||||
sub_attr_name = attr.sub_attributes[sub_attr_key]
|
|
||||||
sub_attr = self.dict.attributes[sub_attr_name]
|
|
||||||
for v in sub_attr_val:
|
|
||||||
res.setdefault(sub_attr_name, []).append(self._decode_value(sub_attr, v))
|
|
||||||
return res
|
|
||||||
else:
|
|
||||||
res = []
|
|
||||||
for v in values:
|
|
||||||
res.append(self._decode_value(attr, v))
|
|
||||||
return res
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
|
||||||
try:
|
|
||||||
return super().__contains__(self._encode_key(key))
|
|
||||||
except KeyError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
has_key = __contains__
|
|
||||||
|
|
||||||
def __delitem__(self, key):
|
|
||||||
super().__delitem__(self._encode_key(key))
|
|
||||||
|
|
||||||
def __setitem__(self, key, item):
|
|
||||||
if isinstance(key, str):
|
|
||||||
(key, item) = self._encode_key_values(key, item)
|
|
||||||
super().__setitem__(key, item)
|
|
||||||
|
|
||||||
def keys(self):
|
|
||||||
return [self._decode_key(key) for key in OrderedDict.keys(self)]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def CreateAuthenticator():
|
|
||||||
"""Create a packet authenticator. All RADIUS packets contain a sixteen
|
|
||||||
byte authenticator which is used to authenticate replies from the
|
|
||||||
RADIUS server and in the password hiding algorithm. This function
|
|
||||||
returns a suitable random string that can be used as an authenticator.
|
|
||||||
|
|
||||||
:return: valid packet authenticator
|
|
||||||
:rtype: binary string
|
|
||||||
"""
|
|
||||||
|
|
||||||
return secrets.token_bytes(16)
|
|
||||||
|
|
||||||
def CreateID(self):
|
|
||||||
"""Create a packet ID. All RADIUS requests have a ID which is used to
|
|
||||||
identify a request. This is used to detect retries and replay attacks.
|
|
||||||
This function returns a suitable random number that can be used as ID.
|
|
||||||
|
|
||||||
:return: ID number
|
|
||||||
:rtype: integer
|
|
||||||
|
|
||||||
"""
|
|
||||||
return int.from_bytes(secrets.token_bytes(1), 'little')
|
|
||||||
|
|
||||||
def ReplyPacket(self):
|
|
||||||
"""Create a ready-to-transmit authentication reply packet.
|
|
||||||
Returns a RADIUS packet which can be directly transmitted
|
|
||||||
to a RADIUS server. This differs with Packet() in how
|
|
||||||
the authenticator is calculated.
|
|
||||||
|
|
||||||
:return: raw packet
|
|
||||||
:rtype: string
|
|
||||||
"""
|
|
||||||
assert self.authenticator
|
|
||||||
assert self.secret is not None
|
|
||||||
|
|
||||||
if self.message_authenticator:
|
|
||||||
self._refresh_message_authenticator()
|
|
||||||
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
header = struct.pack('!BBH', self.code, self.id, (20 + len(attr)))
|
|
||||||
|
|
||||||
authenticator = md5_constructor(header[0:4] + self.authenticator
|
|
||||||
+ attr + self.secret).digest()
|
|
||||||
|
|
||||||
return header + authenticator + attr
|
|
||||||
|
|
||||||
def VerifyReply(self, rawreply):
|
|
||||||
if int(rawreply[1]) != self.id:
|
|
||||||
raise PacketError("Reply Packet has wrong id")
|
|
||||||
|
|
||||||
# The Authenticator field in an Accounting-Response packet is called
|
|
||||||
# the Response Authenticator, and contains a one-way MD5 hash
|
|
||||||
# calculated over a stream of octets consisting of the Accounting
|
|
||||||
# Response Code, Identifier, Length, the Request Authenticator field
|
|
||||||
# from the Accounting-Request packet being replied to, and the
|
|
||||||
# response attributes if any, followed by the shared secret. The
|
|
||||||
# resulting 16 octet MD5 hash value is stored in the Authenticator
|
|
||||||
# field of the Accounting-Response packet.
|
|
||||||
hash = md5_constructor(rawreply[0:4] + self.authenticator +
|
|
||||||
rawreply[20:] + self.secret).digest()
|
|
||||||
|
|
||||||
if hash != rawreply[4:20]:
|
|
||||||
raise PacketError("Reply Packet has a wrong authenticator")
|
|
||||||
return self.CreateReply(packet=rawreply)
|
|
||||||
|
|
||||||
def _pkt_encode_attribute(self, key, value):
|
|
||||||
if isinstance(key, tuple):
|
|
||||||
value = struct.pack('!L', key[0]) + \
|
|
||||||
self._pkt_encode_attribute(key[1], value)
|
|
||||||
key = 26
|
|
||||||
|
|
||||||
return struct.pack('!BB', key, (len(value) + 2)) + value
|
|
||||||
|
|
||||||
def _pkt_encode_tlv(self, tlv_key, tlv_value):
|
|
||||||
tlv_attr = self.dict.attributes[self._decode_key(tlv_key)]
|
|
||||||
curr_avp = b''
|
|
||||||
avps = []
|
|
||||||
max_sub_attribute_len = max(map(lambda item: len(item[1]), tlv_value.items()))
|
|
||||||
for i in range(max_sub_attribute_len):
|
|
||||||
sub_attr_encoding = b''
|
|
||||||
for (code, datalst) in tlv_value.items():
|
|
||||||
if i < len(datalst):
|
|
||||||
sub_attr_encoding += self._pkt_encode_attribute(code, datalst[i])
|
|
||||||
# split above 255. assuming len of one instance of all sub tlvs is lower than 255
|
|
||||||
if (len(sub_attr_encoding) + len(curr_avp)) < 245:
|
|
||||||
curr_avp += sub_attr_encoding
|
|
||||||
else:
|
|
||||||
avps.append(curr_avp)
|
|
||||||
curr_avp = sub_attr_encoding
|
|
||||||
avps.append(curr_avp)
|
|
||||||
tlv_avps = []
|
|
||||||
for avp in avps:
|
|
||||||
value = struct.pack('!BB', tlv_attr.code, (len(avp) + 2)) + avp
|
|
||||||
tlv_avps.append(value)
|
|
||||||
if tlv_attr.vendor:
|
|
||||||
vendor_avps = b''
|
|
||||||
for avp in tlv_avps:
|
|
||||||
vendor_avps += struct.pack(
|
|
||||||
'!BBL', 26, (len(avp) + 6),
|
|
||||||
self.dict.vendors.get_forward(tlv_attr.vendor)
|
|
||||||
) + avp
|
|
||||||
return vendor_avps
|
|
||||||
else:
|
|
||||||
return b''.join(tlv_avps)
|
|
||||||
|
|
||||||
def _pkt_encode_attributes(self):
|
|
||||||
result = b''
|
|
||||||
for (code, datalst) in self.items():
|
|
||||||
attribute = self.dict.attributes.get(self._decode_key(code))
|
|
||||||
if attribute and attribute.type == 'tlv':
|
|
||||||
result += self._pkt_encode_tlv(code, datalst)
|
|
||||||
else:
|
|
||||||
for data in datalst:
|
|
||||||
result += self._pkt_encode_attribute(code, data)
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _pkt_decode_vendor_attribute(self, data):
|
|
||||||
# Check if this packet is long enough to be in the
|
|
||||||
# RFC2865 recommended form
|
|
||||||
if len(data) < 6:
|
|
||||||
return [(26, data)]
|
|
||||||
|
|
||||||
(vendor, atype, length) = struct.unpack('!LBB', data[:6])[0:3]
|
|
||||||
attribute = self.dict.attributes.get(self._decode_key((vendor, atype)))
|
|
||||||
try:
|
|
||||||
if attribute and attribute.type == 'tlv':
|
|
||||||
self._pkt_decode_tlv_attribute((vendor, atype), data[6:length + 4])
|
|
||||||
tlvs = [] # tlv is added to the packet inside _pkt_decode_tlv_attribute
|
|
||||||
else:
|
|
||||||
tlvs = [((vendor, atype), data[6:length + 4])]
|
|
||||||
except:
|
|
||||||
return [(26, data)]
|
|
||||||
|
|
||||||
sumlength = 4 + length
|
|
||||||
while len(data) > sumlength:
|
|
||||||
try:
|
|
||||||
atype, length = struct.unpack('!BB', data[sumlength:sumlength+2])[0:2]
|
|
||||||
except:
|
|
||||||
return [(26, data)]
|
|
||||||
tlvs.append(((vendor, atype), data[sumlength+2:sumlength+length]))
|
|
||||||
sumlength += length
|
|
||||||
return tlvs
|
|
||||||
|
|
||||||
def _pkt_decode_tlv_attribute(self, code, data):
|
|
||||||
sub_attributes = self.setdefault(code, {})
|
|
||||||
loc = 0
|
|
||||||
|
|
||||||
while loc < len(data):
|
|
||||||
atype, length = struct.unpack('!BB', data[loc:loc+2])[0:2]
|
|
||||||
sub_attributes.setdefault(atype, []).append(data[loc+2:loc+length])
|
|
||||||
loc += length
|
|
||||||
|
|
||||||
def DecodePacket(self, packet):
|
|
||||||
"""Initialize the object from raw packet data. Decode a packet as
|
|
||||||
received from the network and decode it.
|
|
||||||
|
|
||||||
:param packet: raw packet
|
|
||||||
:type packet: string"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
(self.code, self.id, length, self.authenticator) = \
|
|
||||||
struct.unpack('!BBH16s', packet[0:20])
|
|
||||||
|
|
||||||
except struct.error:
|
|
||||||
raise PacketError('Packet header is corrupt')
|
|
||||||
if len(packet) != length:
|
|
||||||
raise PacketError('Packet has invalid length')
|
|
||||||
if length > 4096:
|
|
||||||
raise PacketError(f'Packet length is too long ({length})')
|
|
||||||
|
|
||||||
self.clear()
|
|
||||||
|
|
||||||
packet = packet[20:]
|
|
||||||
while packet:
|
|
||||||
try:
|
|
||||||
(key, attrlen) = struct.unpack('!BB', packet[0:2])
|
|
||||||
except struct.error:
|
|
||||||
raise PacketError('Attribute header is corrupt')
|
|
||||||
|
|
||||||
if attrlen < 2:
|
|
||||||
raise PacketError(f'Attribute length is too small (attrlen)')
|
|
||||||
|
|
||||||
value = packet[2:attrlen]
|
|
||||||
attribute = self.dict.attributes.get(self._decode_key(key))
|
|
||||||
if key == 26:
|
|
||||||
for (key, value) in self._pkt_decode_vendor_attribute(value):
|
|
||||||
self.setdefault(key, []).append(value)
|
|
||||||
elif key == 80:
|
|
||||||
# POST: Message Authenticator AVP is present.
|
|
||||||
self.message_authenticator = True
|
|
||||||
self.setdefault(key, []).append(value)
|
|
||||||
elif attribute and attribute.type == 'tlv':
|
|
||||||
self._pkt_decode_tlv_attribute(key, value)
|
|
||||||
else:
|
|
||||||
self.setdefault(key, []).append(value)
|
|
||||||
|
|
||||||
packet = packet[attrlen:]
|
|
||||||
|
|
||||||
def SaltCrypt(self, value):
|
|
||||||
"""Salt Encryption
|
|
||||||
|
|
||||||
:param value: plaintext value
|
|
||||||
:type password: unicode string
|
|
||||||
:return: obfuscated version of the value
|
|
||||||
:rtype: binary string
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(value, str):
|
|
||||||
value = value.encode('utf-8')
|
|
||||||
|
|
||||||
if self.authenticator is None:
|
|
||||||
# self.authenticator = self.CreateAuthenticator()
|
|
||||||
self.authenticator = 16 * b'\x00'
|
|
||||||
|
|
||||||
random_value = 32768 + random_generator.randrange(0, 32767)
|
|
||||||
result = struct.pack('!H', random_value)
|
|
||||||
|
|
||||||
length = struct.pack("B", len(value))
|
|
||||||
buf = length + value
|
|
||||||
if len(buf) % 16 != 0:
|
|
||||||
buf += b'\x00' * (16 - (len(buf) % 16))
|
|
||||||
|
|
||||||
last = self.authenticator + result
|
|
||||||
while buf:
|
|
||||||
cur_hash = md5_constructor(self.secret + last).digest()
|
|
||||||
for b, h in zip(buf, cur_hash):
|
|
||||||
result += bytes([b ^ h])
|
|
||||||
last = result[-16:]
|
|
||||||
buf = buf[16:]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacket(Packet):
|
|
||||||
def __init__(self, code=AccessRequest, id=None, secret=b'',
|
|
||||||
authenticator=None, auth_type='pap', **attributes):
|
|
||||||
"""Constructor
|
|
||||||
|
|
||||||
:param code: packet type code
|
|
||||||
:type code: integer (8bits)
|
|
||||||
:param id: packet identification number
|
|
||||||
:type id: integer (8 bits)
|
|
||||||
:param secret: secret needed to communicate with a RADIUS server
|
|
||||||
:type secret: string
|
|
||||||
|
|
||||||
:param dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary class
|
|
||||||
|
|
||||||
:param packet: raw packet to decode
|
|
||||||
:type packet: string
|
|
||||||
"""
|
|
||||||
|
|
||||||
Packet.__init__(self, code, id, secret, authenticator, **attributes)
|
|
||||||
self.auth_type = auth_type
|
|
||||||
|
|
||||||
def CreateReply(self, **attributes):
|
|
||||||
"""Create a new packet as a reply to this one. This method
|
|
||||||
makes sure the authenticator and secret are copied over
|
|
||||||
to the new instance.
|
|
||||||
"""
|
|
||||||
return AuthPacket(AccessAccept, self.id,
|
|
||||||
self.secret, self.authenticator, dict=self.dict,
|
|
||||||
auth_type=self.auth_type, **attributes)
|
|
||||||
|
|
||||||
def RequestPacket(self):
|
|
||||||
"""Create a ready-to-transmit authentication request packet.
|
|
||||||
Return a RADIUS packet which can be directly transmitted
|
|
||||||
to a RADIUS server.
|
|
||||||
|
|
||||||
:return: raw packet
|
|
||||||
:rtype: string
|
|
||||||
"""
|
|
||||||
if self.authenticator is None:
|
|
||||||
self.authenticator = self.CreateAuthenticator()
|
|
||||||
|
|
||||||
if self.id is None:
|
|
||||||
self.id = self.CreateID()
|
|
||||||
|
|
||||||
if self.message_authenticator:
|
|
||||||
self._refresh_message_authenticator()
|
|
||||||
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
if self.auth_type == 'eap-md5':
|
|
||||||
header = struct.pack(
|
|
||||||
'!BBH16s', self.code, self.id, (20 + 18 + len(attr)), self.authenticator
|
|
||||||
)
|
|
||||||
msg = header \
|
|
||||||
+ attr \
|
|
||||||
+ struct.pack('!BB', 80, struct.calcsize('!BB16s')),
|
|
||||||
digest = hmac.new(self.secret, msg, digestmod=hashlib.md5).digest()
|
|
||||||
return msg + digest
|
|
||||||
|
|
||||||
header = struct.pack('!BBH16s', self.code, self.id,
|
|
||||||
(20 + len(attr)), self.authenticator)
|
|
||||||
|
|
||||||
return header + attr
|
|
||||||
|
|
||||||
def PwDecrypt(self, password):
|
|
||||||
"""Obfuscate a RADIUS password. RADIUS hides passwords in packets by
|
|
||||||
using an algorithm based on the MD5 hash of the packet authenticator
|
|
||||||
and RADIUS secret. This function reverses the obfuscation process.
|
|
||||||
|
|
||||||
:param password: obfuscated form of password
|
|
||||||
:type password: binary string
|
|
||||||
:return: plaintext password
|
|
||||||
:rtype: unicode string
|
|
||||||
"""
|
|
||||||
pw = self.radius_password_pseudo_hash(password).rstrip(b'\x00')
|
|
||||||
|
|
||||||
return pw.decode('utf-8')
|
|
||||||
|
|
||||||
def PwCrypt(self, password):
|
|
||||||
"""Obfuscate password.
|
|
||||||
RADIUS hides passwords in packets by using an algorithm
|
|
||||||
based on the MD5 hash of the packet authenticator and RADIUS
|
|
||||||
secret. If no authenticator has been set before calling PwCrypt
|
|
||||||
one is created automatically. Changing the authenticator after
|
|
||||||
setting a password that has been encrypted using this function
|
|
||||||
will not work.
|
|
||||||
|
|
||||||
:param password: plaintext password
|
|
||||||
:type password: unicode string
|
|
||||||
:return: obfuscated version of the password
|
|
||||||
:rtype: binary string
|
|
||||||
"""
|
|
||||||
if self.authenticator is None:
|
|
||||||
self.authenticator = self.CreateAuthenticator()
|
|
||||||
|
|
||||||
if isinstance(password, str):
|
|
||||||
password = password.encode('utf-8')
|
|
||||||
|
|
||||||
buf = password
|
|
||||||
if len(password) % 16 != 0:
|
|
||||||
buf += b'\x00' * (16 - (len(password) % 16))
|
|
||||||
|
|
||||||
return self.radius_password_pseudo_hash(buf)
|
|
||||||
|
|
||||||
def radius_password_pseudo_hash(self, password):
|
|
||||||
result = b''
|
|
||||||
buf = password
|
|
||||||
last = self.authenticator
|
|
||||||
|
|
||||||
while buf:
|
|
||||||
cur_hash = md5_constructor(self.secret + last).digest()
|
|
||||||
for b, h in zip(buf, cur_hash):
|
|
||||||
result += bytes([b ^ h])
|
|
||||||
|
|
||||||
(last, buf) = (buf[:16], buf[16:])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def VerifyChapPasswd(self, userpwd):
|
|
||||||
""" Verify RADIUS ChapPasswd
|
|
||||||
|
|
||||||
:param userpwd: plaintext password
|
|
||||||
:type userpwd: str
|
|
||||||
:return: is verify ok
|
|
||||||
:rtype: bool
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not self.authenticator:
|
|
||||||
self.authenticator = self.CreateAuthenticator()
|
|
||||||
|
|
||||||
if isinstance(userpwd, str):
|
|
||||||
userpwd = userpwd.strip().encode('utf-8')
|
|
||||||
|
|
||||||
chap_password = tools.DecodeOctets(self.get(3)[0])
|
|
||||||
if len(chap_password) != 17:
|
|
||||||
return False
|
|
||||||
|
|
||||||
chapid = chr(chap_password[0]).encode('utf-8')
|
|
||||||
password = chap_password[1:]
|
|
||||||
|
|
||||||
challenge = self.authenticator
|
|
||||||
if 'CHAP-Challenge' in self:
|
|
||||||
challenge = self['CHAP-Challenge'][0]
|
|
||||||
return password == md5_constructor(chapid + userpwd + challenge).digest()
|
|
||||||
|
|
||||||
def VerifyAuthRequest(self):
|
|
||||||
"""Verify request authenticator.
|
|
||||||
|
|
||||||
:return: True if verification failed else False
|
|
||||||
:rtype: boolean
|
|
||||||
"""
|
|
||||||
assert self.raw_packet
|
|
||||||
hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' +
|
|
||||||
self.raw_packet[20:] + self.secret).digest()
|
|
||||||
return hash == self.authenticator
|
|
||||||
|
|
||||||
|
|
||||||
class AcctPacket(Packet):
|
|
||||||
"""RADIUS accounting packets. This class is a specialization
|
|
||||||
of the generic :obj:`Packet` class for accounting packets.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, code=AccountingRequest, id=None, secret=b'',
|
|
||||||
authenticator=None, **attributes):
|
|
||||||
"""Constructor
|
|
||||||
|
|
||||||
:param dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary class
|
|
||||||
:param secret: secret needed to communicate with a RADIUS server
|
|
||||||
:type secret: string
|
|
||||||
:param id: packet identification number
|
|
||||||
:type id: integer (8 bits)
|
|
||||||
:param code: packet type code
|
|
||||||
:type code: integer (8bits)
|
|
||||||
:param packet: raw packet to decode
|
|
||||||
:type packet: string
|
|
||||||
"""
|
|
||||||
Packet.__init__(self, code, id, secret, authenticator, **attributes)
|
|
||||||
|
|
||||||
def CreateReply(self, **attributes):
|
|
||||||
"""Create a new packet as a reply to this one. This method
|
|
||||||
makes sure the authenticator and secret are copied over
|
|
||||||
to the new instance.
|
|
||||||
"""
|
|
||||||
return AcctPacket(AccountingResponse, self.id,
|
|
||||||
self.secret, self.authenticator, dict=self.dict,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
def VerifyAcctRequest(self):
|
|
||||||
"""Verify request authenticator.
|
|
||||||
|
|
||||||
:return: False if verification failed else True
|
|
||||||
:rtype: boolean
|
|
||||||
"""
|
|
||||||
assert self.raw_packet
|
|
||||||
|
|
||||||
hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' +
|
|
||||||
self.raw_packet[20:] + self.secret).digest()
|
|
||||||
|
|
||||||
return hash == self.authenticator
|
|
||||||
|
|
||||||
def RequestPacket(self):
|
|
||||||
"""Create a ready-to-transmit authentication request packet.
|
|
||||||
Return a RADIUS packet which can be directly transmitted
|
|
||||||
to a RADIUS server.
|
|
||||||
|
|
||||||
:return: raw packet
|
|
||||||
:rtype: string
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.id is None:
|
|
||||||
self.id = self.CreateID()
|
|
||||||
|
|
||||||
if self.message_authenticator:
|
|
||||||
self._refresh_message_authenticator()
|
|
||||||
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
header = struct.pack('!BBH', self.code, self.id, (20 + len(attr)))
|
|
||||||
self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' +
|
|
||||||
attr + self.secret).digest()
|
|
||||||
|
|
||||||
ans = header + self.authenticator + attr
|
|
||||||
|
|
||||||
return ans
|
|
||||||
|
|
||||||
|
|
||||||
class CoAPacket(Packet):
|
|
||||||
"""RADIUS CoA packets. This class is a specialization
|
|
||||||
of the generic :obj:`Packet` class for CoA packets.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, code=CoARequest, id=None, secret=b'',
|
|
||||||
authenticator=None, **attributes):
|
|
||||||
"""Constructor
|
|
||||||
|
|
||||||
:param dict: RADIUS dictionary
|
|
||||||
:type dict: pyrad.dictionary.Dictionary class
|
|
||||||
:param secret: secret needed to communicate with a RADIUS server
|
|
||||||
:type secret: string
|
|
||||||
:param id: packet identification number
|
|
||||||
:type id: integer (8 bits)
|
|
||||||
:param code: packet type code
|
|
||||||
:type code: integer (8bits)
|
|
||||||
:param packet: raw packet to decode
|
|
||||||
:type packet: string
|
|
||||||
"""
|
|
||||||
Packet.__init__(self, code, id, secret, authenticator, **attributes)
|
|
||||||
|
|
||||||
def CreateReply(self, **attributes):
|
|
||||||
"""Create a new packet as a reply to this one. This method
|
|
||||||
makes sure the authenticator and secret are copied over
|
|
||||||
to the new instance.
|
|
||||||
"""
|
|
||||||
return CoAPacket(CoAACK, self.id,
|
|
||||||
self.secret, self.authenticator, dict=self.dict,
|
|
||||||
**attributes)
|
|
||||||
|
|
||||||
def VerifyCoARequest(self):
|
|
||||||
"""Verify request authenticator.
|
|
||||||
|
|
||||||
:return: False if verification failed else True
|
|
||||||
:rtype: boolean
|
|
||||||
"""
|
|
||||||
assert self.raw_packet
|
|
||||||
hash = md5_constructor(self.raw_packet[0:4] + 16 * b'\x00' +
|
|
||||||
self.raw_packet[20:] + self.secret).digest()
|
|
||||||
return hash == self.authenticator
|
|
||||||
|
|
||||||
def RequestPacket(self):
|
|
||||||
"""Create a ready-to-transmit CoA request packet.
|
|
||||||
Return a RADIUS packet which can be directly transmitted
|
|
||||||
to a RADIUS server.
|
|
||||||
|
|
||||||
:return: raw packet
|
|
||||||
:rtype: string
|
|
||||||
"""
|
|
||||||
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
|
|
||||||
if self.id is None:
|
|
||||||
self.id = self.CreateID()
|
|
||||||
|
|
||||||
header = struct.pack('!BBH', self.code, self.id, (20 + len(attr)))
|
|
||||||
self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' +
|
|
||||||
attr + self.secret).digest()
|
|
||||||
|
|
||||||
if self.message_authenticator:
|
|
||||||
self._refresh_message_authenticator()
|
|
||||||
attr = self._pkt_encode_attributes()
|
|
||||||
self.authenticator = md5_constructor(header[0:4] + 16 * b'\x00' +
|
|
||||||
attr + self.secret).digest()
|
|
||||||
|
|
||||||
return header + self.authenticator + attr
|
|
||||||
|
|
||||||
|
|
||||||
def CreateID():
|
|
||||||
"""Generate a packet ID.
|
|
||||||
|
|
||||||
:return: packet ID
|
|
||||||
:rtype: 8 bit integer
|
|
||||||
"""
|
|
||||||
global CurrentID
|
|
||||||
|
|
||||||
CurrentID = (CurrentID + 1) % 256
|
|
||||||
return CurrentID
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
import pyrad
|
|
||||||
import sys
|
|
||||||
|
|
||||||
pyrad # keep pyflakes happy
|
|
||||||
home = sys.modules["pyrad"].__path__[0]
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
# A simple dictionary
|
|
||||||
|
|
||||||
ATTRIBUTE User-Name 1 string
|
|
||||||
ATTRIBUTE User-Password 2 string encrypt=1
|
|
||||||
ATTRIBUTE CHAP-Password 3 octets
|
|
||||||
ATTRIBUTE CHAP-Challenge 60 octets
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
# A failing dictionary
|
|
||||||
|
|
||||||
VALUE Not-Defined Undefined-Value 1
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
# A simple dictionary
|
|
||||||
|
|
||||||
ATTRIBUTE Test-String 1 string
|
|
||||||
ATTRIBUTE Test-Octets 2 octets
|
|
||||||
ATTRIBUTE Test-Integer 3 integer
|
|
||||||
|
|
||||||
VALUE Test-Integer Zero 0
|
|
||||||
VALUE Test-Integer One 1
|
|
||||||
VALUE Test-Integer Two 2
|
|
||||||
VALUE Test-Integer Three 3
|
|
||||||
VALUE Test-Integer Four 4
|
|
||||||
|
|
||||||
ATTRIBUTE Test-Tlv 4 tlv
|
|
||||||
ATTRIBUTE Test-Tlv-Str 4.1 string
|
|
||||||
ATTRIBUTE Test-Tlv-Int 4.2 integer
|
|
||||||
|
|
||||||
VENDOR Simplon 16
|
|
||||||
|
|
||||||
|
|
||||||
BEGIN-VENDOR Simplon
|
|
||||||
ATTRIBUTE Simplon-Number 1 integer
|
|
||||||
ATTRIBUTE Simplon-String 2 string
|
|
||||||
|
|
||||||
VALUE Simplon-Number Zero 0
|
|
||||||
VALUE Simplon-Number One 1
|
|
||||||
VALUE Simplon-Number Two 2
|
|
||||||
VALUE Simplon-Number Three 3
|
|
||||||
VALUE Simplon-Number Four 4
|
|
||||||
|
|
||||||
ATTRIBUTE Simplon-Tlv 3 tlv
|
|
||||||
ATTRIBUTE Simplon-Tlv-Str 3.1 string
|
|
||||||
ATTRIBUTE Simplon-Tlv-Int 3.2 integer
|
|
||||||
|
|
||||||
END-VENDOR Simplon
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
# A simple dictionary
|
|
||||||
|
|
||||||
ATTRIBUTE Test-String 1 string
|
|
||||||
ATTRIBUTE Test-Octets 2 octets
|
|
||||||
ATTRIBUTE Test-Integer 3 integer
|
|
||||||
ATTRIBUTE Test-Ip-Address 4 ipaddr
|
|
||||||
ATTRIBUTE Test-Ipv6-Address 5 ipv6addr
|
|
||||||
ATTRIBUTE Test-If-Id 6 ifid
|
|
||||||
ATTRIBUTE Test-Date 7 date
|
|
||||||
ATTRIBUTE Test-Abinary 8 abinary
|
|
||||||
ATTRIBUTE Test-Tlv 9 tlv
|
|
||||||
ATTRIBUTE Test-Tlv-Str 9.1 string
|
|
||||||
ATTRIBUTE Test-Tlv-Int 9.2 integer
|
|
||||||
ATTRIBUTE Test-Integer64 10 integer64
|
|
||||||
ATTRIBUTE Test-Integer64-Hex 0x0a integer64
|
|
||||||
ATTRIBUTE Test-Integer64-Oct 0o12 integer64
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# A simple dictionary
|
|
||||||
|
|
||||||
ATTRIBUTE Tunnel-Password 2 string encrypt=2
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
import fcntl
|
|
||||||
import os
|
|
||||||
from pyrad.packet import PacketError
|
|
||||||
|
|
||||||
|
|
||||||
class MockPacket:
|
|
||||||
reply = object()
|
|
||||||
|
|
||||||
def __init__(self, code, verify=False, error=False):
|
|
||||||
self.code = code
|
|
||||||
self.data = {}
|
|
||||||
self.verify = verify
|
|
||||||
self.error = error
|
|
||||||
|
|
||||||
def CreateReply(self, packet=None):
|
|
||||||
if self.error:
|
|
||||||
raise PacketError
|
|
||||||
return self.reply
|
|
||||||
|
|
||||||
def VerifyReply(self, reply, rawreply):
|
|
||||||
return self.verify
|
|
||||||
|
|
||||||
def RequestPacket(self):
|
|
||||||
return "request packet"
|
|
||||||
|
|
||||||
def __contains__(self, key):
|
|
||||||
return key in self.data
|
|
||||||
has_key = __contains__
|
|
||||||
|
|
||||||
def __setitem__(self, key, value):
|
|
||||||
self.data[key] = [value]
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return self.data[key]
|
|
||||||
|
|
||||||
|
|
||||||
class MockSocket:
|
|
||||||
def __init__(self, domain, type, data=None):
|
|
||||||
self.domain = domain
|
|
||||||
self.type = type
|
|
||||||
self.closed = False
|
|
||||||
self.options = []
|
|
||||||
self.address = None
|
|
||||||
self.output = []
|
|
||||||
|
|
||||||
if data is not None:
|
|
||||||
(self.read_end, self.write_end) = os.pipe()
|
|
||||||
fcntl.fcntl(self.write_end, fcntl.F_SETFL, os.O_NONBLOCK)
|
|
||||||
os.write(self.write_end, data)
|
|
||||||
self.data = data
|
|
||||||
else:
|
|
||||||
self.read_end = 1
|
|
||||||
self.write_end = None
|
|
||||||
|
|
||||||
def fileno(self):
|
|
||||||
return self.read_end
|
|
||||||
|
|
||||||
def bind(self, address):
|
|
||||||
self.address = address
|
|
||||||
|
|
||||||
def recv(self, buffer):
|
|
||||||
return self.data[:buffer]
|
|
||||||
|
|
||||||
def sendto(self, data, target):
|
|
||||||
self.output.append((data, target))
|
|
||||||
|
|
||||||
def setsockopt(self, level, opt, value):
|
|
||||||
self.options.append((level, opt, value))
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
self.closed = True
|
|
||||||
|
|
||||||
|
|
||||||
class MockFinished(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MockPoll:
|
|
||||||
results = []
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.registry = {}
|
|
||||||
|
|
||||||
def register(self, fd, options):
|
|
||||||
self.registry[fd] = options
|
|
||||||
|
|
||||||
def unregister(self, fd):
|
|
||||||
try:
|
|
||||||
del self.registry[fd]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def poll(self, timeout=None):
|
|
||||||
for result in self.results:
|
|
||||||
yield result
|
|
||||||
raise MockFinished
|
|
||||||
|
|
||||||
|
|
||||||
def origkey(klass):
|
|
||||||
return "_originals_" + klass.__name__
|
|
||||||
|
|
||||||
|
|
||||||
def MockClassMethod(klass, name, myfunc=None):
|
|
||||||
def func(self, *args, **kwargs):
|
|
||||||
if not hasattr(self, "called"):
|
|
||||||
self.called = []
|
|
||||||
self.called.append((name, args, kwargs))
|
|
||||||
|
|
||||||
key = origkey(klass)
|
|
||||||
if not hasattr(klass, key):
|
|
||||||
setattr(klass, key, {})
|
|
||||||
getattr(klass, key)[name] = getattr(klass, name)
|
|
||||||
if myfunc is None:
|
|
||||||
setattr(klass, name, func)
|
|
||||||
else:
|
|
||||||
setattr(klass, name, myfunc)
|
|
||||||
|
|
||||||
|
|
||||||
def UnmockClassMethods(klass):
|
|
||||||
key = origkey(klass)
|
|
||||||
if not hasattr(klass, key):
|
|
||||||
return
|
|
||||||
for (name, func) in getattr(klass, key).items():
|
|
||||||
setattr(klass, name, func)
|
|
||||||
|
|
||||||
delattr(klass, key)
|
|
||||||
|
|
||||||
|
|
||||||
class MockFd:
|
|
||||||
data = object()
|
|
||||||
source = object()
|
|
||||||
|
|
||||||
def __init__(self, fd=0):
|
|
||||||
self.fd = fd
|
|
||||||
|
|
||||||
def fileno(self):
|
|
||||||
return self.fd
|
|
||||||
|
|
||||||
def recvfrom(self, size):
|
|
||||||
self.size = size
|
|
||||||
return (self.data, self.source)
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
import operator
|
|
||||||
import unittest
|
|
||||||
from pyrad.bidict import BiDict
|
|
||||||
|
|
||||||
|
|
||||||
class BiDictTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.bidict = BiDict()
|
|
||||||
|
|
||||||
def testStartEmpty(self):
|
|
||||||
self.assertEqual(len(self.bidict), 0)
|
|
||||||
self.assertEqual(len(self.bidict.forward), 0)
|
|
||||||
self.assertEqual(len(self.bidict.backward), 0)
|
|
||||||
|
|
||||||
def testLength(self):
|
|
||||||
self.assertEqual(len(self.bidict), 0)
|
|
||||||
self.bidict.add("from", "to")
|
|
||||||
self.assertEqual(len(self.bidict), 1)
|
|
||||||
del self.bidict["from"]
|
|
||||||
self.assertEqual(len(self.bidict), 0)
|
|
||||||
|
|
||||||
def testDeletion(self):
|
|
||||||
self.assertRaises(KeyError, operator.delitem, self.bidict, "missing")
|
|
||||||
self.bidict.add("missing", "present")
|
|
||||||
del self.bidict["missing"]
|
|
||||||
|
|
||||||
def testBackwardDeletion(self):
|
|
||||||
self.assertRaises(KeyError, operator.delitem, self.bidict, "missing")
|
|
||||||
self.bidict.add("missing", "present")
|
|
||||||
del self.bidict["present"]
|
|
||||||
self.assertEqual(self.bidict.has_forward("missing"), False)
|
|
||||||
|
|
||||||
def testForwardAccess(self):
|
|
||||||
self.bidict.add("shake", "vanilla")
|
|
||||||
self.bidict.add("pie", "custard")
|
|
||||||
self.assertEqual(self.bidict.has_forward("shake"), True)
|
|
||||||
self.assertEqual(self.bidict.get_forward("shake"), "vanilla")
|
|
||||||
self.assertEqual(self.bidict.has_forward("pie"), True)
|
|
||||||
self.assertEqual(self.bidict.get_forward("pie"), "custard")
|
|
||||||
self.assertEqual(self.bidict.has_forward("missing"), False)
|
|
||||||
self.assertRaises(KeyError, self.bidict.get_forward, "missing")
|
|
||||||
|
|
||||||
def testBackwardAccess(self):
|
|
||||||
self.bidict.add("shake", "vanilla")
|
|
||||||
self.bidict.add("pie", "custard")
|
|
||||||
self.assertEqual(self.bidict.has_backward("vanilla"), True)
|
|
||||||
self.assertEqual(self.bidict.get_backward("vanilla"), "shake")
|
|
||||||
self.assertEqual(self.bidict.has_backward("missing"), False)
|
|
||||||
self.assertRaises(KeyError, self.bidict.get_backward, "missing")
|
|
||||||
|
|
||||||
def testItemAccessor(self):
|
|
||||||
self.bidict.add("shake", "vanilla")
|
|
||||||
self.bidict.add("pie", "custard")
|
|
||||||
self.assertRaises(KeyError, operator.getitem, self.bidict, "missing")
|
|
||||||
self.assertEquals(self.bidict["shake"], "vanilla")
|
|
||||||
self.assertEquals(self.bidict["pie"], "custard")
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
import select
|
|
||||||
import socket
|
|
||||||
import unittest
|
|
||||||
from pyrad.client import Client
|
|
||||||
from pyrad.client import Timeout
|
|
||||||
from pyrad.packet import AuthPacket
|
|
||||||
from pyrad.packet import AcctPacket
|
|
||||||
from pyrad.packet import AccessRequest
|
|
||||||
from pyrad.packet import AccountingRequest
|
|
||||||
from pyrad.tests.mock import MockPacket
|
|
||||||
from pyrad.tests.mock import MockPoll
|
|
||||||
from pyrad.tests.mock import MockSocket
|
|
||||||
|
|
||||||
BIND_IP = "127.0.0.1"
|
|
||||||
BIND_PORT = 53535
|
|
||||||
|
|
||||||
|
|
||||||
class ConstructionTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = object()
|
|
||||||
|
|
||||||
def testSimpleConstruction(self):
|
|
||||||
client = Client(self.server)
|
|
||||||
self.failUnless(client.server is self.server)
|
|
||||||
self.assertEqual(client.authport, 1812)
|
|
||||||
self.assertEqual(client.acctport, 1813)
|
|
||||||
self.assertEqual(client.secret, b'')
|
|
||||||
self.assertEqual(client.retries, 3)
|
|
||||||
self.assertEqual(client.timeout, 5)
|
|
||||||
self.failUnless(client.dict is None)
|
|
||||||
|
|
||||||
def testParameterOrder(self):
|
|
||||||
marker = object()
|
|
||||||
client = Client(self.server, 123, 456, 789, "secret", marker)
|
|
||||||
self.failUnless(client.server is self.server)
|
|
||||||
self.assertEqual(client.authport, 123)
|
|
||||||
self.assertEqual(client.acctport, 456)
|
|
||||||
self.assertEqual(client.coaport, 789)
|
|
||||||
self.assertEqual(client.secret, "secret")
|
|
||||||
self.failUnless(client.dict is marker)
|
|
||||||
|
|
||||||
def testNamedParameters(self):
|
|
||||||
marker = object()
|
|
||||||
client = Client(server=self.server, authport=123, acctport=456,
|
|
||||||
secret="secret", dict=marker)
|
|
||||||
self.failUnless(client.server is self.server)
|
|
||||||
self.assertEqual(client.authport, 123)
|
|
||||||
self.assertEqual(client.acctport, 456)
|
|
||||||
self.assertEqual(client.secret, "secret")
|
|
||||||
self.failUnless(client.dict is marker)
|
|
||||||
|
|
||||||
|
|
||||||
class SocketTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = object()
|
|
||||||
self.client = Client(self.server)
|
|
||||||
self.orgsocket = socket.socket
|
|
||||||
socket.socket = MockSocket
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
socket.socket = self.orgsocket
|
|
||||||
|
|
||||||
def testReopen(self):
|
|
||||||
self.client._socket_open()
|
|
||||||
sock = self.client._socket
|
|
||||||
self.client._socket_open()
|
|
||||||
self.failUnless(sock is self.client._socket)
|
|
||||||
|
|
||||||
def testBind(self):
|
|
||||||
self.client.bind((BIND_IP, BIND_PORT))
|
|
||||||
self.assertEqual(self.client._socket.address, (BIND_IP, BIND_PORT))
|
|
||||||
self.assertEqual(self.client._socket.options,
|
|
||||||
[(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)])
|
|
||||||
|
|
||||||
def testBindClosesSocket(self):
|
|
||||||
s = MockSocket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
self.client._socket = s
|
|
||||||
self.client._poll = MockPoll()
|
|
||||||
self.client.bind((BIND_IP, BIND_PORT))
|
|
||||||
self.assertEqual(s.closed, True)
|
|
||||||
|
|
||||||
def testSendPacket(self):
|
|
||||||
def MockSend(self, pkt, port):
|
|
||||||
self._mock_pkt = pkt
|
|
||||||
self._mock_port = port
|
|
||||||
|
|
||||||
_send_packet = Client._send_packet
|
|
||||||
Client._send_packet = MockSend
|
|
||||||
|
|
||||||
self.client.SendPacket(AuthPacket())
|
|
||||||
self.assertEqual(self.client._mock_port, self.client.authport)
|
|
||||||
|
|
||||||
self.client.SendPacket(AcctPacket())
|
|
||||||
self.assertEqual(self.client._mock_port, self.client.acctport)
|
|
||||||
|
|
||||||
Client._send_packet = _send_packet
|
|
||||||
|
|
||||||
def testNoRetries(self):
|
|
||||||
self.client.retries = 0
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, None, None)
|
|
||||||
|
|
||||||
def testSingleRetry(self):
|
|
||||||
self.client.retries = 1
|
|
||||||
self.client.timeout = 0
|
|
||||||
packet = MockPacket(AccessRequest)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
self.assertEqual(self.client._socket.output,
|
|
||||||
[("request packet", (self.server, 432))])
|
|
||||||
|
|
||||||
def testDoubleRetry(self):
|
|
||||||
self.client.retries = 2
|
|
||||||
self.client.timeout = 0
|
|
||||||
packet = MockPacket(AccessRequest)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
self.assertEqual(
|
|
||||||
self.client._socket.output,
|
|
||||||
[("request packet", (self.server, 432)),
|
|
||||||
("request packet", (self.server, 432))])
|
|
||||||
|
|
||||||
def testAuthDelay(self):
|
|
||||||
self.client.retries = 2
|
|
||||||
self.client.timeout = 1
|
|
||||||
packet = MockPacket(AccessRequest)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
self.failIf("Acct-Delay-Time" in packet)
|
|
||||||
|
|
||||||
def testSingleAccountDelay(self):
|
|
||||||
self.client.retries = 2
|
|
||||||
self.client.timeout = 1
|
|
||||||
packet = MockPacket(AccountingRequest)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
self.assertEqual(packet["Acct-Delay-Time"], [1])
|
|
||||||
|
|
||||||
def testDoubleAccountDelay(self):
|
|
||||||
self.client.retries = 3
|
|
||||||
self.client.timeout = 1
|
|
||||||
packet = MockPacket(AccountingRequest)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
self.assertEqual(packet["Acct-Delay-Time"], [2])
|
|
||||||
|
|
||||||
def testIgnorePacketError(self):
|
|
||||||
self.client.retries = 1
|
|
||||||
self.client.timeout = 1
|
|
||||||
self.client._socket = MockSocket(1, 2, b"valid reply")
|
|
||||||
packet = MockPacket(AccountingRequest, verify=True, error=True)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
|
|
||||||
def testValidReply(self):
|
|
||||||
self.client.retries = 1
|
|
||||||
self.client.timeout = 1
|
|
||||||
self.client._socket = MockSocket(1, 2, b"valid reply")
|
|
||||||
self.client._poll = MockPoll()
|
|
||||||
MockPoll.results = [(1, select.POLLIN)]
|
|
||||||
packet = MockPacket(AccountingRequest, verify=True)
|
|
||||||
reply = self.client._send_packet(packet, 432)
|
|
||||||
self.failUnless(reply is packet.reply)
|
|
||||||
|
|
||||||
def testInvalidReply(self):
|
|
||||||
self.client.retries = 1
|
|
||||||
self.client.timeout = 1
|
|
||||||
self.client._socket = MockSocket(1, 2, b"invalid reply")
|
|
||||||
MockPoll.results = [(1, select.POLLIN)]
|
|
||||||
packet = MockPacket(AccountingRequest, verify=False)
|
|
||||||
self.assertRaises(Timeout, self.client._send_packet, packet, 432)
|
|
||||||
|
|
||||||
|
|
||||||
class OtherTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = object()
|
|
||||||
self.client = Client(self.server, secret=b'zeer geheim')
|
|
||||||
|
|
||||||
def testCreateAuthPacket(self):
|
|
||||||
packet = self.client.CreateAuthPacket(id=15)
|
|
||||||
self.failUnless(isinstance(packet, AuthPacket))
|
|
||||||
self.failUnless(packet.dict is self.client.dict)
|
|
||||||
self.assertEqual(packet.id, 15)
|
|
||||||
self.assertEqual(packet.secret, b'zeer geheim')
|
|
||||||
|
|
||||||
def testCreateAcctPacket(self):
|
|
||||||
packet = self.client.CreateAcctPacket(id=15)
|
|
||||||
self.failUnless(isinstance(packet, AcctPacket))
|
|
||||||
self.failUnless(packet.dict is self.client.dict)
|
|
||||||
self.assertEqual(packet.id, 15)
|
|
||||||
self.assertEqual(packet.secret, b'zeer geheim')
|
|
||||||
@@ -1,332 +0,0 @@
|
|||||||
import unittest
|
|
||||||
import operator
|
|
||||||
import os
|
|
||||||
from io import StringIO
|
|
||||||
from pyrad.tests import home
|
|
||||||
from pyrad.dictionary import Attribute
|
|
||||||
from pyrad.dictionary import Dictionary
|
|
||||||
from pyrad.dictionary import ParseError
|
|
||||||
from pyrad.tools import DecodeAttr
|
|
||||||
from pyrad.dictfile import DictFile
|
|
||||||
|
|
||||||
|
|
||||||
class AttributeTests(unittest.TestCase):
|
|
||||||
def testInvalidDataType(self):
|
|
||||||
self.assertRaises(ValueError, Attribute, 'name', 'code', 'datatype')
|
|
||||||
|
|
||||||
def testConstructionParameters(self):
|
|
||||||
attr = Attribute('name', 'code', 'integer', False, 'vendor')
|
|
||||||
self.assertEqual(attr.name, 'name')
|
|
||||||
self.assertEqual(attr.code, 'code')
|
|
||||||
self.assertEqual(attr.type, 'integer')
|
|
||||||
self.assertEqual(attr.is_sub_attribute, False)
|
|
||||||
self.assertEqual(attr.vendor, 'vendor')
|
|
||||||
self.assertEqual(len(attr.values), 0)
|
|
||||||
self.assertEqual(len(attr.sub_attributes), 0)
|
|
||||||
|
|
||||||
def testNamedConstructionParameters(self):
|
|
||||||
attr = Attribute(name='name', code='code', datatype='integer',
|
|
||||||
vendor='vendor')
|
|
||||||
self.assertEqual(attr.name, 'name')
|
|
||||||
self.assertEqual(attr.code, 'code')
|
|
||||||
self.assertEqual(attr.type, 'integer')
|
|
||||||
self.assertEqual(attr.vendor, 'vendor')
|
|
||||||
self.assertEqual(len(attr.values), 0)
|
|
||||||
|
|
||||||
def testValues(self):
|
|
||||||
attr = Attribute('name', 'code', 'integer', False, 'vendor',
|
|
||||||
dict(pie='custard', shake='vanilla'))
|
|
||||||
self.assertEqual(len(attr.values), 2)
|
|
||||||
self.assertEqual(attr.values['shake'], 'vanilla')
|
|
||||||
|
|
||||||
|
|
||||||
class DictionaryInterfaceTests(unittest.TestCase):
|
|
||||||
def testEmptyDictionary(self):
|
|
||||||
dict = Dictionary()
|
|
||||||
self.assertEqual(len(dict), 0)
|
|
||||||
|
|
||||||
def testContainment(self):
|
|
||||||
dict = Dictionary()
|
|
||||||
self.assertEqual('test' in dict, False)
|
|
||||||
dict.attributes['test'] = 'dummy'
|
|
||||||
self.assertEqual('test' in dict, True)
|
|
||||||
|
|
||||||
def testReadonlyContainer(self):
|
|
||||||
dict = Dictionary()
|
|
||||||
self.assertRaises(TypeError,
|
|
||||||
operator.setitem, dict, 'test', 'dummy')
|
|
||||||
self.assertRaises(AttributeError,
|
|
||||||
operator.attrgetter('clear'), dict)
|
|
||||||
self.assertRaises(AttributeError,
|
|
||||||
operator.attrgetter('update'), dict)
|
|
||||||
|
|
||||||
|
|
||||||
class DictionaryParsingTests(unittest.TestCase):
|
|
||||||
|
|
||||||
simple_dict_values = [
|
|
||||||
('Test-String', 1, 'string'),
|
|
||||||
('Test-Octets', 2, 'octets'),
|
|
||||||
('Test-Integer', 0x03, 'integer'),
|
|
||||||
('Test-Ip-Address', 4, 'ipaddr'),
|
|
||||||
('Test-Ipv6-Address', 5, 'ipv6addr'),
|
|
||||||
('Test-If-Id', 6, 'ifid'),
|
|
||||||
('Test-Date', 7, 'date'),
|
|
||||||
('Test-Abinary', 8, 'abinary'),
|
|
||||||
('Test-Tlv', 9, 'tlv'),
|
|
||||||
('Test-Tlv-Str', 1, 'string'),
|
|
||||||
('Test-Tlv-Int', 2, 'integer'),
|
|
||||||
('Test-Integer64', 10, 'integer64'),
|
|
||||||
('Test-Integer64-Hex', 10, 'integer64'),
|
|
||||||
('Test-Integer64-Oct', 10, 'integer64'),
|
|
||||||
]
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'simple'))
|
|
||||||
|
|
||||||
def testParseEmptyDictionary(self):
|
|
||||||
dict = Dictionary(StringIO(''))
|
|
||||||
self.assertEqual(len(dict), 0)
|
|
||||||
|
|
||||||
def testParseMultipleDictionaries(self):
|
|
||||||
dict = Dictionary(StringIO(''))
|
|
||||||
self.assertEqual(len(dict), 0)
|
|
||||||
one = StringIO('ATTRIBUTE Test-First 1 string')
|
|
||||||
two = StringIO('ATTRIBUTE Test-Second 2 string')
|
|
||||||
dict = Dictionary(StringIO(''), one, two)
|
|
||||||
self.assertEqual(len(dict), 2)
|
|
||||||
|
|
||||||
def testParseSimpleDictionary(self):
|
|
||||||
self.assertEqual(len(self.dict), len(self.simple_dict_values))
|
|
||||||
for (attr, code, type) in self.simple_dict_values:
|
|
||||||
attr = self.dict[attr]
|
|
||||||
self.assertEqual(attr.code, code)
|
|
||||||
self.assertEqual(attr.type, type)
|
|
||||||
|
|
||||||
def testAttributeTooFewColumnsError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(
|
|
||||||
StringIO('ATTRIBUTE Oops-Too-Few-Columns'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('attribute' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testAttributeUnknownTypeError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 dummy'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('dummy' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testAttributeUnknownVendorError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 Simplon'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('Simplon' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testAttributeOptions(self):
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'ATTRIBUTE Option-Type 1 string has_tag,encrypt=1'))
|
|
||||||
self.assertEqual(self.dict['Option-Type'].has_tag, True)
|
|
||||||
self.assertEqual(self.dict['Option-Type'].encrypt, 1)
|
|
||||||
|
|
||||||
def testAttributeEncryptionError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'ATTRIBUTE Test-Type 1 string encrypt=4'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('encrypt' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testValueTooFewColumnsError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('VALUE Oops-Too-Few-Columns'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('value' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testValueForUnknownAttributeError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'VALUE Test-Attribute Test-Text 1'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('unknown attribute' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testIntegerValueParsing(self):
|
|
||||||
self.assertEqual(len(self.dict['Test-Integer'].values), 0)
|
|
||||||
self.dict.read_dictionary(StringIO('VALUE Test-Integer Value-Six 5'))
|
|
||||||
self.assertEqual(len(self.dict['Test-Integer'].values), 1)
|
|
||||||
self.assertEqual(
|
|
||||||
DecodeAttr('integer', self.dict['Test-Integer'].values['Value-Six']),
|
|
||||||
5)
|
|
||||||
|
|
||||||
def testInteger64ValueParsing(self):
|
|
||||||
self.assertEqual(len(self.dict['Test-Integer64'].values), 0)
|
|
||||||
self.dict.read_dictionary(StringIO('VALUE Test-Integer64 Value-Six 5'))
|
|
||||||
self.assertEqual(len(self.dict['Test-Integer64'].values), 1)
|
|
||||||
self.assertEqual(
|
|
||||||
DecodeAttr('integer64', self.dict['Test-Integer64'].values['Value-Six']),
|
|
||||||
5)
|
|
||||||
|
|
||||||
def testStringValueParsing(self):
|
|
||||||
self.assertEqual(len(self.dict['Test-String'].values), 0)
|
|
||||||
self.dict.read_dictionary(StringIO('VALUE Test-String Value-Custard custardpie'))
|
|
||||||
self.assertEqual(len(self.dict['Test-String'].values), 1)
|
|
||||||
self.assertEqual(
|
|
||||||
DecodeAttr('string', self.dict['Test-String'].values['Value-Custard']),
|
|
||||||
'custardpie')
|
|
||||||
|
|
||||||
def testTlvParsing(self):
|
|
||||||
self.assertEqual(len(self.dict['Test-Tlv'].sub_attributes), 2)
|
|
||||||
self.assertEqual(self.dict['Test-Tlv'].sub_attributes,
|
|
||||||
{1: 'Test-Tlv-Str', 2: 'Test-Tlv-Int'})
|
|
||||||
|
|
||||||
def testSubTlvParsing(self):
|
|
||||||
for (attr, _, _) in self.simple_dict_values:
|
|
||||||
if attr.startswith('Test-Tlv-'):
|
|
||||||
self.assertEqual(self.dict[attr].is_sub_attribute, True)
|
|
||||||
self.assertEqual(self.dict[attr].parent, self.dict['Test-Tlv'])
|
|
||||||
else:
|
|
||||||
self.assertEqual(self.dict[attr].is_sub_attribute, False)
|
|
||||||
self.assertEqual(self.dict[attr].parent, None)
|
|
||||||
|
|
||||||
# tlv with vendor
|
|
||||||
full_dict = Dictionary(os.path.join(self.path, 'full'))
|
|
||||||
self.assertEqual(full_dict['Simplon-Tlv-Str'].is_sub_attribute, True)
|
|
||||||
self.assertEqual(full_dict['Simplon-Tlv-Str'].parent, full_dict['Simplon-Tlv'])
|
|
||||||
self.assertEqual(full_dict['Simplon-Tlv-Int'].is_sub_attribute, True)
|
|
||||||
self.assertEqual(full_dict['Simplon-Tlv-Int'].parent, full_dict['Simplon-Tlv'])
|
|
||||||
|
|
||||||
def testVenderTooFewColumnsError(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('VENDOR Simplon'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('vendor' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testVendorParsing(self):
|
|
||||||
self.assertRaises(ParseError, self.dict.read_dictionary,
|
|
||||||
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
|
|
||||||
self.dict.read_dictionary(StringIO('VENDOR Simplon 42'))
|
|
||||||
self.assertEqual(self.dict.vendors['Simplon'], 42)
|
|
||||||
self.dict.read_dictionary(StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
|
|
||||||
self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1))
|
|
||||||
|
|
||||||
def testVendorOptionError(self):
|
|
||||||
self.assertRaises(ParseError, self.dict.read_dictionary,
|
|
||||||
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('VENDOR Simplon 42 badoption'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('option' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testVendorFormatError(self):
|
|
||||||
self.assertRaises(ParseError, self.dict.read_dictionary,
|
|
||||||
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('VENDOR Simplon 42 format=5,4'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('format' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testVendorFormatSyntaxError(self):
|
|
||||||
self.assertRaises(ParseError, self.dict.read_dictionary,
|
|
||||||
StringIO('ATTRIBUTE Test-Type 1 integer Simplon'))
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('VENDOR Simplon 42 format=a,1'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('Syntax' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testBeginVendorTooFewColumns(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('BEGIN-VENDOR'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('begin-vendor' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testBeginVendorUnknownVendor(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('BEGIN-VENDOR Simplon'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('Simplon' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testBeginVendorParsing(self):
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'VENDOR Simplon 42\n'
|
|
||||||
'BEGIN-VENDOR Simplon\n'
|
|
||||||
'ATTRIBUTE Test-Type 1 integer'))
|
|
||||||
self.assertEqual(self.dict.attrindex['Test-Type'], (42, 1))
|
|
||||||
|
|
||||||
def testEndVendorUnknownVendor(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO('END-VENDOR'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('end-vendor' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testEndVendorUnbalanced(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'VENDOR Simplon 42\n'
|
|
||||||
'BEGIN-VENDOR Simplon\n'
|
|
||||||
'END-VENDOR Oops\n'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('Oops' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testEndVendorParsing(self):
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'VENDOR Simplon 42\n'
|
|
||||||
'BEGIN-VENDOR Simplon\n'
|
|
||||||
'END-VENDOR Simplon\n'
|
|
||||||
'ATTRIBUTE Test-Type 1 integer'))
|
|
||||||
self.assertEqual(self.dict.attrindex['Test-Type'], 1)
|
|
||||||
|
|
||||||
def testInclude(self):
|
|
||||||
try:
|
|
||||||
self.dict.read_dictionary(StringIO(
|
|
||||||
'$INCLUDE this_file_does_not_exist\n'
|
|
||||||
'VENDOR Simplon 42\n'
|
|
||||||
'BEGIN-VENDOR Simplon\n'
|
|
||||||
'END-VENDOR Simplon\n'
|
|
||||||
'ATTRIBUTE Test-Type 1 integer'))
|
|
||||||
except IOError as e:
|
|
||||||
self.assertEqual('this_file_does_not_exist' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testDictFilePostParse(self):
|
|
||||||
f = DictFile(StringIO(
|
|
||||||
'VENDOR Simplon 42\n'))
|
|
||||||
for _ in f:
|
|
||||||
pass
|
|
||||||
self.assertEqual(f.file(), '')
|
|
||||||
self.assertEqual(f.line(), -1)
|
|
||||||
|
|
||||||
def testDictFileParseError(self):
|
|
||||||
tmpdict = Dictionary()
|
|
||||||
try:
|
|
||||||
tmpdict.read_dictionary(os.path.join(self.path, 'dictfiletest'))
|
|
||||||
except ParseError as e:
|
|
||||||
self.assertEqual('dictfiletest' in str(e), True)
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
import unittest
|
|
||||||
from pyrad.host import Host
|
|
||||||
from pyrad.packet import Packet
|
|
||||||
from pyrad.packet import AuthPacket
|
|
||||||
from pyrad.packet import AcctPacket
|
|
||||||
|
|
||||||
|
|
||||||
class ConstructionTests(unittest.TestCase):
|
|
||||||
def testSimpleConstruction(self):
|
|
||||||
host = Host()
|
|
||||||
self.assertEqual(host.authport, 1812)
|
|
||||||
self.assertEqual(host.acctport, 1813)
|
|
||||||
|
|
||||||
def testParameterOrder(self):
|
|
||||||
host = Host(123, 456, 789, 101)
|
|
||||||
self.assertEqual(host.authport, 123)
|
|
||||||
self.assertEqual(host.acctport, 456)
|
|
||||||
self.assertEqual(host.coaport, 789)
|
|
||||||
self.assertEqual(host.dict, 101)
|
|
||||||
|
|
||||||
def testNamedParameters(self):
|
|
||||||
host = Host(authport=123, acctport=456, coaport=789, dict=101)
|
|
||||||
self.assertEqual(host.authport, 123)
|
|
||||||
self.assertEqual(host.acctport, 456)
|
|
||||||
self.assertEqual(host.coaport, 789)
|
|
||||||
self.assertEqual(host.dict, 101)
|
|
||||||
|
|
||||||
|
|
||||||
class PacketCreationTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.host = Host()
|
|
||||||
|
|
||||||
def testCreatePacket(self):
|
|
||||||
packet = self.host.CreatePacket(id=15)
|
|
||||||
self.failUnless(isinstance(packet, Packet))
|
|
||||||
self.failUnless(packet.dict is self.host.dict)
|
|
||||||
self.assertEqual(packet.id, 15)
|
|
||||||
|
|
||||||
def testCreateAuthPacket(self):
|
|
||||||
packet = self.host.CreateAuthPacket(id=15)
|
|
||||||
self.failUnless(isinstance(packet, AuthPacket))
|
|
||||||
self.failUnless(packet.dict is self.host.dict)
|
|
||||||
self.assertEqual(packet.id, 15)
|
|
||||||
|
|
||||||
def testCreateAcctPacket(self):
|
|
||||||
packet = self.host.CreateAcctPacket(id=15)
|
|
||||||
self.failUnless(isinstance(packet, AcctPacket))
|
|
||||||
self.failUnless(packet.dict is self.host.dict)
|
|
||||||
self.assertEqual(packet.id, 15)
|
|
||||||
|
|
||||||
|
|
||||||
class MockPacket:
|
|
||||||
packet = object()
|
|
||||||
replypacket = object()
|
|
||||||
source = object()
|
|
||||||
|
|
||||||
def Packet(self):
|
|
||||||
return self.packet
|
|
||||||
|
|
||||||
def ReplyPacket(self):
|
|
||||||
return self.replypacket
|
|
||||||
|
|
||||||
|
|
||||||
class MockFd:
|
|
||||||
data = None
|
|
||||||
target = None
|
|
||||||
|
|
||||||
def sendto(self, data, target):
|
|
||||||
self.data = data
|
|
||||||
self.target = target
|
|
||||||
|
|
||||||
|
|
||||||
class PacketSendTest(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.host = Host()
|
|
||||||
self.fd = MockFd()
|
|
||||||
self.packet = MockPacket()
|
|
||||||
|
|
||||||
def testSendPacket(self):
|
|
||||||
self.host.SendPacket(self.fd, self.packet)
|
|
||||||
self.failUnless(self.fd.data is self.packet.packet)
|
|
||||||
self.failUnless(self.fd.target is self.packet.source)
|
|
||||||
|
|
||||||
def testSendReplyPacket(self):
|
|
||||||
self.host.SendReplyPacket(self.fd, self.packet)
|
|
||||||
self.failUnless(self.fd.data is self.packet.replypacket)
|
|
||||||
self.failUnless(self.fd.target is self.packet.source)
|
|
||||||
@@ -1,535 +0,0 @@
|
|||||||
import os
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from pyrad import packet
|
|
||||||
from pyrad.client import Client
|
|
||||||
from pyrad.tests import home
|
|
||||||
from pyrad.dictionary import Dictionary
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
md5_constructor = hashlib.md5
|
|
||||||
|
|
||||||
|
|
||||||
class UtilityTests(unittest.TestCase):
|
|
||||||
def testGenerateID(self):
|
|
||||||
id = packet.CreateID()
|
|
||||||
self.assertTrue(isinstance(id, int))
|
|
||||||
newid = packet.CreateID()
|
|
||||||
self.assertNotEqual(id, newid)
|
|
||||||
|
|
||||||
|
|
||||||
class PacketConstructionTests(unittest.TestCase):
|
|
||||||
klass = packet.Packet
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'simple'))
|
|
||||||
|
|
||||||
def testBasicConstructor(self):
|
|
||||||
pkt = self.klass()
|
|
||||||
self.assertTrue(isinstance(pkt.code, int))
|
|
||||||
self.assertTrue(isinstance(pkt.id, int))
|
|
||||||
self.assertTrue(isinstance(pkt.secret, bytes))
|
|
||||||
|
|
||||||
def testNamedConstructor(self):
|
|
||||||
pkt = self.klass(code=26, id=38, secret=b'secret',
|
|
||||||
authenticator=b'authenticator',
|
|
||||||
dict='fakedict')
|
|
||||||
self.assertEqual(pkt.code, 26)
|
|
||||||
self.assertEqual(pkt.id, 38)
|
|
||||||
self.assertEqual(pkt.secret, b'secret')
|
|
||||||
self.assertEqual(pkt.authenticator, b'authenticator')
|
|
||||||
self.assertEqual(pkt.dict, 'fakedict')
|
|
||||||
|
|
||||||
def testConstructWithDictionary(self):
|
|
||||||
pkt = self.klass(dict=self.dict)
|
|
||||||
self.assertTrue(pkt.dict is self.dict)
|
|
||||||
|
|
||||||
def testConstructorIgnoredParameters(self):
|
|
||||||
marker = []
|
|
||||||
pkt = self.klass(fd=marker)
|
|
||||||
self.assertFalse(getattr(pkt, 'fd', None) is marker)
|
|
||||||
|
|
||||||
def testSecretMustBeBytestring(self):
|
|
||||||
self.assertRaises(TypeError, self.klass, secret='secret')
|
|
||||||
|
|
||||||
def testConstructorWithAttributes(self):
|
|
||||||
pkt = self.klass(**{'Test-String': 'this works', 'dict': self.dict})
|
|
||||||
self.assertEqual(pkt['Test-String'], ['this works'])
|
|
||||||
|
|
||||||
def testConstructorWithTlvAttribute(self):
|
|
||||||
pkt = self.klass(**{
|
|
||||||
'Test-Tlv-Str': 'this works',
|
|
||||||
'Test-Tlv-Int': 10,
|
|
||||||
'dict': self.dict
|
|
||||||
})
|
|
||||||
self.assertEqual(
|
|
||||||
pkt['Test-Tlv'],
|
|
||||||
{'Test-Tlv-Str': ['this works'], 'Test-Tlv-Int': [10]}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PacketTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'full'))
|
|
||||||
self.packet = packet.Packet(
|
|
||||||
id=0, secret=b'secret',
|
|
||||||
authenticator=b'01234567890ABCDEF', dict=self.dict)
|
|
||||||
|
|
||||||
def testCreateReply(self):
|
|
||||||
reply = self.packet.CreateReply(**{'Test-Integer': 10})
|
|
||||||
self.assertEqual(reply.id, self.packet.id)
|
|
||||||
self.assertEqual(reply.secret, self.packet.secret)
|
|
||||||
self.assertEqual(reply.authenticator, self.packet.authenticator)
|
|
||||||
self.assertEqual(reply['Test-Integer'], [10])
|
|
||||||
|
|
||||||
def testAttributeAccess(self):
|
|
||||||
self.packet['Test-Integer'] = 10
|
|
||||||
self.assertEqual(self.packet['Test-Integer'], [10])
|
|
||||||
self.assertEqual(self.packet[3], [b'\x00\x00\x00\x0a'])
|
|
||||||
|
|
||||||
self.packet['Test-String'] = 'dummy'
|
|
||||||
self.assertEqual(self.packet['Test-String'], ['dummy'])
|
|
||||||
self.assertEqual(self.packet[1], [b'dummy'])
|
|
||||||
|
|
||||||
def testAttributeValueAccess(self):
|
|
||||||
self.packet['Test-Integer'] = 'Three'
|
|
||||||
self.assertEqual(self.packet['Test-Integer'], ['Three'])
|
|
||||||
self.assertEqual(self.packet[3], [b'\x00\x00\x00\x03'])
|
|
||||||
|
|
||||||
def testVendorAttributeAccess(self):
|
|
||||||
self.packet['Simplon-Number'] = 10
|
|
||||||
self.assertEqual(self.packet['Simplon-Number'], [10])
|
|
||||||
self.assertEqual(self.packet[(16, 1)], [b'\x00\x00\x00\x0a'])
|
|
||||||
|
|
||||||
self.packet['Simplon-Number'] = 'Four'
|
|
||||||
self.assertEqual(self.packet['Simplon-Number'], ['Four'])
|
|
||||||
self.assertEqual(self.packet[(16, 1)], [b'\x00\x00\x00\x04'])
|
|
||||||
|
|
||||||
def testRawAttributeAccess(self):
|
|
||||||
marker = [b'']
|
|
||||||
self.packet[1] = marker
|
|
||||||
self.assertTrue(self.packet[1] is marker)
|
|
||||||
self.packet[(16, 1)] = marker
|
|
||||||
self.assertTrue(self.packet[(16, 1)] is marker)
|
|
||||||
|
|
||||||
def testHasKey(self):
|
|
||||||
self.assertEqual('Test-String' in self.packet, False)
|
|
||||||
self.packet['Test-String'] = 'dummy'
|
|
||||||
self.assertEqual('Test-String' in self.packet, True)
|
|
||||||
self.assertEqual(1 in self.packet, True)
|
|
||||||
|
|
||||||
def testHasKeyWithUnknownKey(self):
|
|
||||||
self.assertEqual('Unknown-Attribute' in self.packet, False)
|
|
||||||
|
|
||||||
def testDelItem(self):
|
|
||||||
self.packet['Test-String'] = 'dummy'
|
|
||||||
del self.packet['Test-String']
|
|
||||||
self.assertEqual('Test-String' in self.packet, False)
|
|
||||||
self.packet['Test-String'] = 'dummy'
|
|
||||||
del self.packet[1]
|
|
||||||
self.assertEqual('Test-String' in self.packet, False)
|
|
||||||
|
|
||||||
def testKeys(self):
|
|
||||||
self.assertEqual(self.packet.keys(), [])
|
|
||||||
self.packet['Test-String'] = 'dummy'
|
|
||||||
self.assertEqual(self.packet.keys(), ['Test-String'])
|
|
||||||
self.packet['Test-Integer'] = 10
|
|
||||||
self.assertEqual(self.packet.keys(), ['Test-String', 'Test-Integer'])
|
|
||||||
OrderedDict.__setitem__(self.packet, 12345, None)
|
|
||||||
self.assertEqual(self.packet.keys(),
|
|
||||||
['Test-String', 'Test-Integer', 12345])
|
|
||||||
|
|
||||||
def testCreateAuthenticator(self):
|
|
||||||
a = packet.Packet.CreateAuthenticator()
|
|
||||||
self.assertTrue(isinstance(a, bytes))
|
|
||||||
self.assertEqual(len(a), 16)
|
|
||||||
|
|
||||||
b = packet.Packet.CreateAuthenticator()
|
|
||||||
self.assertNotEqual(a, b)
|
|
||||||
|
|
||||||
def testGenerateID(self):
|
|
||||||
id = self.packet.CreateID()
|
|
||||||
self.assertTrue(isinstance(id, int))
|
|
||||||
newid = self.packet.CreateID()
|
|
||||||
self.assertNotEqual(id, newid)
|
|
||||||
|
|
||||||
def testReplyPacket(self):
|
|
||||||
reply = self.packet.ReplyPacket()
|
|
||||||
self.assertEqual(
|
|
||||||
reply,
|
|
||||||
(b'\x00\x00\x00\x14\xb0\x5e\x4b\xfb\xcc\x1c'
|
|
||||||
b'\x8c\x8e\xc4\x72\xac\xea\x87\x45\x63\xa7'))
|
|
||||||
|
|
||||||
def testVerifyReply(self):
|
|
||||||
reply = self.packet.CreateReply()
|
|
||||||
reply.id += 1
|
|
||||||
with self.assertRaises(packet.PacketError):
|
|
||||||
self.packet.VerifyReply(reply.ReplyPacket())
|
|
||||||
reply.id = self.packet.id
|
|
||||||
|
|
||||||
reply.secret = b'different'
|
|
||||||
with self.assertRaises(packet.PacketError):
|
|
||||||
self.packet.VerifyReply(reply.ReplyPacket())
|
|
||||||
reply.secret = self.packet.secret
|
|
||||||
|
|
||||||
reply.authenticator = b'X' * 16
|
|
||||||
with self.assertRaises(packet.PacketError):
|
|
||||||
self.packet.VerifyReply(reply.ReplyPacket())
|
|
||||||
|
|
||||||
def testPktEncodeAttribute(self):
|
|
||||||
encode = self.packet._pkt_encode_attribute
|
|
||||||
|
|
||||||
# Encode a normal attribute
|
|
||||||
self.assertEqual(
|
|
||||||
encode(1, b'value'),
|
|
||||||
b'\x01\x07value')
|
|
||||||
# Encode a vendor attribute
|
|
||||||
self.assertEqual(
|
|
||||||
encode((1, 2), b'value'),
|
|
||||||
b'\x1a\x0d\x00\x00\x00\x01\x02\x07value')
|
|
||||||
|
|
||||||
def testPktEncodeTlvAttribute(self):
|
|
||||||
encode = self.packet._pkt_encode_tlv
|
|
||||||
|
|
||||||
# Encode a normal tlv attribute
|
|
||||||
self.assertEqual(
|
|
||||||
encode(4, {1: [b'value'], 2: [b'\x00\x00\x00\x02']}),
|
|
||||||
b'\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02')
|
|
||||||
|
|
||||||
# Encode a normal tlv attribute with several sub attribute instances
|
|
||||||
self.assertEqual(
|
|
||||||
encode(4, {1: [b'value', b'other'], 2: [b'\x00\x00\x00\x02']}),
|
|
||||||
b'\x04\x16\x01\x07value\x02\x06\x00\x00\x00\x02\x01\x07other')
|
|
||||||
# Encode a vendor tlv attribute
|
|
||||||
self.assertEqual(
|
|
||||||
encode((16, 3), {1: [b'value'], 2: [b'\x00\x00\x00\x02']}),
|
|
||||||
b'\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02')
|
|
||||||
|
|
||||||
def testPktEncodeLongTlvAttribute(self):
|
|
||||||
encode = self.packet._pkt_encode_tlv
|
|
||||||
|
|
||||||
long_str = b'a' * 245
|
|
||||||
# Encode a long tlv attribute - check it is split between AVPs
|
|
||||||
self.assertEqual(
|
|
||||||
encode(4, {1: [b'value', long_str], 2: [b'\x00\x00\x00\x02']}),
|
|
||||||
b'\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x02\x04\xf9\x01\xf7' + long_str)
|
|
||||||
|
|
||||||
# Encode a long vendor tlv attribute
|
|
||||||
first_avp = b'\x1a\x15\x00\x00\x00\x10\x03\x0f\x01\x07value\x02\x06\x00\x00\x00\x02'
|
|
||||||
second_avp = b'\x1a\xff\x00\x00\x00\x10\x03\xf9\x01\xf7' + long_str
|
|
||||||
self.assertEqual(
|
|
||||||
encode((16, 3), {1: [b'value', long_str], 2: [b'\x00\x00\x00\x02']}),
|
|
||||||
first_avp + second_avp)
|
|
||||||
|
|
||||||
def testpkt_encode_attributes(self):
|
|
||||||
self.packet[1] = [b'value']
|
|
||||||
self.assertEqual(self.packet._pkt_encode_attributes(),
|
|
||||||
b'\x01\x07value')
|
|
||||||
|
|
||||||
self.packet.clear()
|
|
||||||
self.packet[(16, 2)] = [b'value']
|
|
||||||
self.assertEqual(self.packet._pkt_encode_attributes(),
|
|
||||||
b'\x1a\x0d\x00\x00\x00\x10\x02\x07value')
|
|
||||||
|
|
||||||
self.packet.clear()
|
|
||||||
self.packet[1] = [b'one', b'two', b'three']
|
|
||||||
self.assertEqual(self.packet._pkt_encode_attributes(),
|
|
||||||
b'\x01\x05one\x01\x05two\x01\x07three')
|
|
||||||
|
|
||||||
self.packet.clear()
|
|
||||||
self.packet[1] = [b'value']
|
|
||||||
self.packet[(16, 2)] = [b'value']
|
|
||||||
self.assertEqual(self.packet._pkt_encode_attributes(),
|
|
||||||
b'\x01\x07value\x1a\x0d\x00\x00\x00\x10\x02\x07value')
|
|
||||||
|
|
||||||
def testPktDecodeVendorAttribute(self):
|
|
||||||
decode = self.packet._pkt_decode_vendor_attribute
|
|
||||||
|
|
||||||
# Non-RFC2865 recommended form
|
|
||||||
self.assertEqual(decode(b''), [(26, b'')])
|
|
||||||
self.assertEqual(decode(b'12345'), [(26, b'12345')])
|
|
||||||
|
|
||||||
# Almost RFC2865 recommended form: bad length value
|
|
||||||
self.assertEqual(decode(b'\x00\x00\x00\x01\x02\x06value'),
|
|
||||||
[(26, b'\x00\x00\x00\x01\x02\x06value')])
|
|
||||||
|
|
||||||
# Proper RFC2865 recommended form
|
|
||||||
self.assertEqual(decode(b'\x00\x00\x00\x10\x02\x07value'),
|
|
||||||
[((16, 2), b'value')])
|
|
||||||
|
|
||||||
def testPktDecodeTlvAttribute(self):
|
|
||||||
decode = self.packet._pkt_decode_tlv_attribute
|
|
||||||
|
|
||||||
decode(4, b'\x01\x07value')
|
|
||||||
self.assertEqual(self.packet[4], {1: [b'value']})
|
|
||||||
|
|
||||||
# add another instance of the same sub attribute
|
|
||||||
decode(4, b'\x01\x07other')
|
|
||||||
self.assertEqual(self.packet[4], {1: [b'value', b'other']})
|
|
||||||
|
|
||||||
# add a different sub attribute
|
|
||||||
decode(4, b'\x02\x07\x00\x00\x00\x01')
|
|
||||||
self.assertEqual(self.packet[4], {
|
|
||||||
1: [b'value', b'other'],
|
|
||||||
2: [b'\x00\x00\x00\x01']
|
|
||||||
})
|
|
||||||
|
|
||||||
def testDecodePacketWithEmptyPacket(self):
|
|
||||||
try:
|
|
||||||
self.packet.DecodePacket(b'')
|
|
||||||
except packet.PacketError as e:
|
|
||||||
self.assertTrue('header is corrupt' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testDecodePacketWithInvalidLength(self):
|
|
||||||
try:
|
|
||||||
self.packet.DecodePacket(b'\x00\x00\x00\x001234567890123456')
|
|
||||||
except packet.PacketError as e:
|
|
||||||
self.assertTrue('invalid length' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testDecodePacketWithTooBigPacket(self):
|
|
||||||
try:
|
|
||||||
self.packet.DecodePacket(b'\x00\x00\x24\x00' + (0x2400 - 4) * b'X')
|
|
||||||
except packet.PacketError as e:
|
|
||||||
self.assertTrue('too long' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testDecodePacketWithPartialAttributes(self):
|
|
||||||
try:
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x151234567890123456\x00')
|
|
||||||
except packet.PacketError as e:
|
|
||||||
self.assertTrue('header is corrupt' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testDecodePacketWithoutAttributes(self):
|
|
||||||
self.packet.DecodePacket(b'\x01\x02\x00\x141234567890123456')
|
|
||||||
self.assertEqual(self.packet.code, 1)
|
|
||||||
self.assertEqual(self.packet.id, 2)
|
|
||||||
self.assertEqual(self.packet.authenticator, b'1234567890123456')
|
|
||||||
self.assertEqual(self.packet.keys(), [])
|
|
||||||
|
|
||||||
def testDecodePacketWithBadAttribute(self):
|
|
||||||
try:
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x161234567890123456\x00\x01')
|
|
||||||
except packet.PacketError as e:
|
|
||||||
self.assertTrue('too small' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testDecodePacketWithEmptyAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x161234567890123456\x01\x02')
|
|
||||||
self.assertEqual(self.packet[1], [b''])
|
|
||||||
|
|
||||||
def testDecodePacketWithAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x1b1234567890123456\x01\x07value')
|
|
||||||
self.assertEqual(self.packet[1], [b'value'])
|
|
||||||
|
|
||||||
def testDecodePacketWithTlvAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x1d1234567890123456\x04\x09\x01\x07value')
|
|
||||||
self.assertEqual(self.packet[4], {1: [b'value']})
|
|
||||||
|
|
||||||
def testDecodePacketWithVendorTlvAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x231234567890123456\x1a\x0f\x00\x00\x00\x10\x03\x09\x01\x07value')
|
|
||||||
self.assertEqual(self.packet[(16, 3)], {1: [b'value']})
|
|
||||||
|
|
||||||
def testDecodePacketWithTlvAttributeWith2SubAttributes(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x231234567890123456\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x09')
|
|
||||||
self.assertEqual(self.packet[4], {1: [b'value'], 2: [b'\x00\x00\x00\x09']})
|
|
||||||
|
|
||||||
def testDecodePacketWithSplitTlvAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x251234567890123456\x04\x09\x01\x07value\x04\x09\x02\x06\x00\x00\x00\x09')
|
|
||||||
self.assertEqual(self.packet[4], {1: [b'value'], 2: [b'\x00\x00\x00\x09']})
|
|
||||||
|
|
||||||
def testDecodePacketWithMultiValuedAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two')
|
|
||||||
self.assertEqual(self.packet[1], [b'one', b'two'])
|
|
||||||
|
|
||||||
def testDecodePacketWithTwoAttributes(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x1e1234567890123456\x01\x05one\x01\x05two')
|
|
||||||
self.assertEqual(self.packet[1], [b'one', b'two'])
|
|
||||||
|
|
||||||
def testDecodePacketWithVendorAttribute(self):
|
|
||||||
self.packet.DecodePacket(
|
|
||||||
b'\x01\x02\x00\x1b1234567890123456\x1a\x07value')
|
|
||||||
self.assertEqual(self.packet[26], [b'value'])
|
|
||||||
|
|
||||||
def testEncodeKeyValues(self):
|
|
||||||
self.assertEqual(self.packet._encode_key_values(1, '1234'), (1, '1234'))
|
|
||||||
|
|
||||||
def testEncodeKey(self):
|
|
||||||
self.assertEqual(self.packet._encode_key(1), 1)
|
|
||||||
|
|
||||||
def testAddAttribute(self):
|
|
||||||
self.packet.AddAttribute('Test-String', '1')
|
|
||||||
self.assertEqual(self.packet['Test-String'], ['1'])
|
|
||||||
self.packet.AddAttribute('Test-String', '1')
|
|
||||||
self.assertEqual(self.packet['Test-String'], ['1', '1'])
|
|
||||||
self.packet.AddAttribute('Test-String', ['2', '3'])
|
|
||||||
self.assertEqual(self.packet['Test-String'], ['1', '1', '2', '3'])
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacketConstructionTests(PacketConstructionTests):
|
|
||||||
klass = packet.AuthPacket
|
|
||||||
|
|
||||||
def testConstructorDefaults(self):
|
|
||||||
pkt = self.klass()
|
|
||||||
self.assertEqual(pkt.code, packet.AccessRequest)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacketTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'full'))
|
|
||||||
self.packet = packet.AuthPacket(
|
|
||||||
id=0, secret=b'secret',
|
|
||||||
authenticator=b'01234567890ABCDEF', dict=self.dict)
|
|
||||||
|
|
||||||
def testCreateReply(self):
|
|
||||||
reply = self.packet.CreateReply(**{'Test-Integer': 10})
|
|
||||||
self.assertEqual(reply.code, packet.AccessAccept)
|
|
||||||
self.assertEqual(reply.id, self.packet.id)
|
|
||||||
self.assertEqual(reply.secret, self.packet.secret)
|
|
||||||
self.assertEqual(reply.authenticator, self.packet.authenticator)
|
|
||||||
self.assertEqual(reply['Test-Integer'], [10])
|
|
||||||
|
|
||||||
def testRequestPacket(self):
|
|
||||||
self.assertEqual(self.packet.RequestPacket(),
|
|
||||||
b'\x01\x00\x00\x1401234567890ABCDE')
|
|
||||||
|
|
||||||
def testRequestPacketCreatesAuthenticator(self):
|
|
||||||
self.packet.authenticator = None
|
|
||||||
self.packet.RequestPacket()
|
|
||||||
self.assertTrue(self.packet.authenticator is not None)
|
|
||||||
|
|
||||||
def testRequestPacketCreatesID(self):
|
|
||||||
self.packet.id = None
|
|
||||||
self.packet.RequestPacket()
|
|
||||||
self.assertTrue(self.packet.id is not None)
|
|
||||||
|
|
||||||
def testPwCryptEmptyPassword(self):
|
|
||||||
self.assertEqual(self.packet.PwCrypt(''), b'')
|
|
||||||
|
|
||||||
def testPwCryptPassword(self):
|
|
||||||
self.assertEqual(self.packet.PwCrypt('Simplon'),
|
|
||||||
b'\xd3U;\xb23\r\x11\xba\x07\xe3\xa8*\xa8x\x14\x01')
|
|
||||||
|
|
||||||
def testPwCryptSetsAuthenticator(self):
|
|
||||||
self.packet.authenticator = None
|
|
||||||
self.packet.PwCrypt('')
|
|
||||||
self.assertTrue(self.packet.authenticator is not None)
|
|
||||||
|
|
||||||
def testPwDecryptEmptyPassword(self):
|
|
||||||
self.assertEqual(self.packet.PwDecrypt(b''), '')
|
|
||||||
|
|
||||||
def testPwDecryptPassword(self):
|
|
||||||
self.assertEqual(
|
|
||||||
self.packet.PwDecrypt(b'\xd3U;\xb23\r\x11\xba\x07\xe3\xa8*\xa8x\x14\x01'),
|
|
||||||
'Simplon')
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacketChapTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'chap'))
|
|
||||||
self.client = Client(server='localhost', secret=b'secret',
|
|
||||||
dict=self.dict)
|
|
||||||
|
|
||||||
def testVerifyChapPasswd(self):
|
|
||||||
chap_id = b'9'
|
|
||||||
chap_challenge = b'987654321'
|
|
||||||
chap_password = chap_id + md5_constructor(
|
|
||||||
chap_id + b'test_password' + chap_challenge).digest()
|
|
||||||
pkt = self.client.CreateAuthPacket(
|
|
||||||
code=packet.AccessChallenge,
|
|
||||||
authenticator=b'ABCDEFG',
|
|
||||||
User_Name='test_name',
|
|
||||||
CHAP_Challenge=chap_challenge,
|
|
||||||
CHAP_Password=chap_password
|
|
||||||
)
|
|
||||||
self.assertEqual(pkt['CHAP-Challenge'][0], chap_challenge)
|
|
||||||
self.assertEqual(pkt['CHAP-Password'][0], chap_password)
|
|
||||||
self.assertEqual(pkt.VerifyChapPasswd('test_password'), True)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacketSaltTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'tunnelPassword'))
|
|
||||||
self.packet = packet.Packet(id=0, secret=b'secret',
|
|
||||||
dict=self.dict)
|
|
||||||
|
|
||||||
def testSaltCrypt(self):
|
|
||||||
self.packet['Tunnel-Password:1'] = 'test'
|
|
||||||
# TODO: need to get a correct reference values
|
|
||||||
# self.assertEqual(self.packet['Tunnel-Password'], b'')
|
|
||||||
|
|
||||||
|
|
||||||
class AcctPacketConstructionTests(PacketConstructionTests):
|
|
||||||
klass = packet.AcctPacket
|
|
||||||
|
|
||||||
def testConstructorDefaults(self):
|
|
||||||
pkt = self.klass()
|
|
||||||
self.assertEqual(pkt.code, packet.AccountingRequest)
|
|
||||||
|
|
||||||
def testConstructorRawPacket(self):
|
|
||||||
raw = (b'\x00\x00\x00\x14\xb0\x5e\x4b\xfb\xcc\x1c'
|
|
||||||
b'\x8c\x8e\xc4\x72\xac\xea\x87\x45\x63\xa7')
|
|
||||||
pkt = self.klass(packet=raw)
|
|
||||||
self.assertEqual(pkt.raw_packet, raw)
|
|
||||||
|
|
||||||
|
|
||||||
class AcctPacketTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.path = os.path.join(home, 'tests', 'data')
|
|
||||||
self.dict = Dictionary(os.path.join(self.path, 'full'))
|
|
||||||
self.packet = packet.AcctPacket(
|
|
||||||
id=0, secret=b'secret',
|
|
||||||
authenticator=b'01234567890ABCDEF', dict=self.dict)
|
|
||||||
|
|
||||||
def testCreateReply(self):
|
|
||||||
reply = self.packet.CreateReply(**{'Test-Integer': 10})
|
|
||||||
self.assertEqual(reply.code, packet.AccountingResponse)
|
|
||||||
self.assertEqual(reply.id, self.packet.id)
|
|
||||||
self.assertEqual(reply.secret, self.packet.secret)
|
|
||||||
self.assertEqual(reply.authenticator, self.packet.authenticator)
|
|
||||||
self.assertEqual(reply['Test-Integer'], [10])
|
|
||||||
|
|
||||||
def testVerifyAcctRequest(self):
|
|
||||||
rawpacket = self.packet.RequestPacket()
|
|
||||||
pkt = packet.AcctPacket(secret=b'secret', packet=rawpacket)
|
|
||||||
self.assertEqual(pkt.VerifyAcctRequest(), True)
|
|
||||||
|
|
||||||
pkt.secret = b'different'
|
|
||||||
self.assertEqual(pkt.VerifyAcctRequest(), False)
|
|
||||||
pkt.secret = b'secret'
|
|
||||||
|
|
||||||
pkt.raw_packet = b'X' + pkt.raw_packet[1:]
|
|
||||||
self.assertEqual(pkt.VerifyAcctRequest(), False)
|
|
||||||
|
|
||||||
def testRequestPacket(self):
|
|
||||||
self.assertEqual(
|
|
||||||
self.packet.RequestPacket(),
|
|
||||||
b'\x04\x00\x00\x14\x95\xdf\x90\xccbn\xfb\x15G!\x13\xea\xfa>6\x0f')
|
|
||||||
|
|
||||||
def testRequestPacketSetsId(self):
|
|
||||||
self.packet.id = None
|
|
||||||
self.packet.RequestPacket()
|
|
||||||
self.assertTrue(self.packet.id is not None)
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
import select
|
|
||||||
import socket
|
|
||||||
import unittest
|
|
||||||
from pyrad.proxy import Proxy
|
|
||||||
from pyrad.packet import AccessAccept
|
|
||||||
from pyrad.packet import AccessRequest
|
|
||||||
from pyrad.server import ServerPacketError
|
|
||||||
from pyrad.server import Server
|
|
||||||
from pyrad.tests.mock import MockFd
|
|
||||||
from pyrad.tests.mock import MockPoll
|
|
||||||
from pyrad.tests.mock import MockSocket
|
|
||||||
from pyrad.tests.mock import MockClassMethod
|
|
||||||
from pyrad.tests.mock import UnmockClassMethods
|
|
||||||
|
|
||||||
|
|
||||||
class TrivialObject:
|
|
||||||
"""dummy object"""
|
|
||||||
|
|
||||||
|
|
||||||
class SocketTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.orgsocket = socket.socket
|
|
||||||
socket.socket = MockSocket
|
|
||||||
self.proxy = Proxy()
|
|
||||||
self.proxy._fdmap = {}
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
socket.socket = self.orgsocket
|
|
||||||
|
|
||||||
def testProxyFd(self):
|
|
||||||
self.proxy._poll = MockPoll()
|
|
||||||
self.proxy._prepare_sockets()
|
|
||||||
self.failUnless(isinstance(self.proxy._proxyfd, MockSocket))
|
|
||||||
self.assertEqual(list(self.proxy._fdmap.keys()), [1])
|
|
||||||
self.assertEqual(
|
|
||||||
self.proxy._poll.registry,
|
|
||||||
{1: select.POLLIN | select.POLLPRI | select.POLLERR})
|
|
||||||
|
|
||||||
|
|
||||||
class ProxyPacketHandlingTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.proxy = Proxy()
|
|
||||||
self.proxy.hosts['host'] = TrivialObject()
|
|
||||||
self.proxy.hosts['host'].secret = 'supersecret'
|
|
||||||
self.packet = TrivialObject()
|
|
||||||
self.packet.code = AccessAccept
|
|
||||||
self.packet.source = ('host', 'port')
|
|
||||||
|
|
||||||
def testHandleProxyPacketUnknownHost(self):
|
|
||||||
self.packet.source = ('stranger', 'port')
|
|
||||||
try:
|
|
||||||
self.proxy._handle_proxy_packet(self.packet)
|
|
||||||
except ServerPacketError as e:
|
|
||||||
self.failUnless('unknown host' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testHandleProxyPacketSetsSecret(self):
|
|
||||||
self.proxy._handle_proxy_packet(self.packet)
|
|
||||||
self.assertEqual(self.packet.secret, 'supersecret')
|
|
||||||
|
|
||||||
def testHandleProxyPacketHandlesWrongPacket(self):
|
|
||||||
self.packet.code = AccessRequest
|
|
||||||
try:
|
|
||||||
self.proxy._handle_proxy_packet(self.packet)
|
|
||||||
except ServerPacketError as e:
|
|
||||||
self.failUnless('non-response' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
|
|
||||||
class OtherTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.proxy = Proxy()
|
|
||||||
self.proxy._proxyfd = MockFd()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
UnmockClassMethods(Proxy)
|
|
||||||
UnmockClassMethods(Server)
|
|
||||||
|
|
||||||
def testProcessInputNonProxyPort(self):
|
|
||||||
fd = MockFd(fd=111)
|
|
||||||
MockClassMethod(Server, '_process_input')
|
|
||||||
self.proxy._process_input(fd)
|
|
||||||
self.assertEqual(
|
|
||||||
self.proxy.called,
|
|
||||||
[('_process_input', (fd,), {})])
|
|
||||||
|
|
||||||
def testProcessInput(self):
|
|
||||||
MockClassMethod(Proxy, '_grab_packet')
|
|
||||||
MockClassMethod(Proxy, '_handle_proxy_packet')
|
|
||||||
self.proxy._process_input(self.proxy._proxyfd)
|
|
||||||
self.assertEqual(
|
|
||||||
[x[0] for x in self.proxy.called],
|
|
||||||
['_grab_packet', '_handle_proxy_packet'])
|
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(select, 'poll'):
|
|
||||||
del SocketTests
|
|
||||||
@@ -1,329 +0,0 @@
|
|||||||
import select
|
|
||||||
import socket
|
|
||||||
import unittest
|
|
||||||
from pyrad.packet import PacketError
|
|
||||||
from pyrad.server import RemoteHost
|
|
||||||
from pyrad.server import Server
|
|
||||||
from pyrad.server import ServerPacketError
|
|
||||||
from pyrad.tests.mock import MockFinished
|
|
||||||
from pyrad.tests.mock import MockFd
|
|
||||||
from pyrad.tests.mock import MockPoll
|
|
||||||
from pyrad.tests.mock import MockSocket
|
|
||||||
from pyrad.tests.mock import MockClassMethod
|
|
||||||
from pyrad.tests.mock import UnmockClassMethods
|
|
||||||
from pyrad.packet import AccessRequest
|
|
||||||
from pyrad.packet import AccountingRequest
|
|
||||||
|
|
||||||
|
|
||||||
class TrivialObject:
|
|
||||||
"""dummy objec"""
|
|
||||||
|
|
||||||
|
|
||||||
class RemoteHostTests(unittest.TestCase):
|
|
||||||
def testSimpleConstruction(self):
|
|
||||||
host = RemoteHost('address', 'secret', 'name', 'authport', 'acctport', 'coaport')
|
|
||||||
self.assertEqual(host.address, 'address')
|
|
||||||
self.assertEqual(host.secret, 'secret')
|
|
||||||
self.assertEqual(host.name, 'name')
|
|
||||||
self.assertEqual(host.authport, 'authport')
|
|
||||||
self.assertEqual(host.acctport, 'acctport')
|
|
||||||
self.assertEqual(host.coaport, 'coaport')
|
|
||||||
|
|
||||||
def testNamedConstruction(self):
|
|
||||||
host = RemoteHost(
|
|
||||||
address='address', secret='secret', name='name',
|
|
||||||
authport='authport', acctport='acctport', coaport='coaport')
|
|
||||||
self.assertEqual(host.address, 'address')
|
|
||||||
self.assertEqual(host.secret, 'secret')
|
|
||||||
self.assertEqual(host.name, 'name')
|
|
||||||
self.assertEqual(host.authport, 'authport')
|
|
||||||
self.assertEqual(host.acctport, 'acctport')
|
|
||||||
self.assertEqual(host.coaport, 'coaport')
|
|
||||||
|
|
||||||
|
|
||||||
class ServerConstructiontests(unittest.TestCase):
|
|
||||||
def testSimpleConstruction(self):
|
|
||||||
server = Server()
|
|
||||||
self.assertEqual(server.authfds, [])
|
|
||||||
self.assertEqual(server.acctfds, [])
|
|
||||||
self.assertEqual(server.authport, 1812)
|
|
||||||
self.assertEqual(server.acctport, 1813)
|
|
||||||
self.assertEqual(server.coaport, 3799)
|
|
||||||
self.assertEqual(server.hosts, {})
|
|
||||||
|
|
||||||
def testParameterOrder(self):
|
|
||||||
server = Server([], 'authport', 'acctport', 'coaport', 'hosts', 'dict')
|
|
||||||
self.assertEqual(server.authfds, [])
|
|
||||||
self.assertEqual(server.acctfds, [])
|
|
||||||
self.assertEqual(server.authport, 'authport')
|
|
||||||
self.assertEqual(server.acctport, 'acctport')
|
|
||||||
self.assertEqual(server.coaport, 'coaport')
|
|
||||||
self.assertEqual(server.dict, 'dict')
|
|
||||||
|
|
||||||
def testBindDuringConstruction(self):
|
|
||||||
def BindToAddress(self, addr):
|
|
||||||
self.bound.append(addr)
|
|
||||||
bta = Server.BindToAddress
|
|
||||||
Server.BindToAddress = BindToAddress
|
|
||||||
|
|
||||||
Server.bound = []
|
|
||||||
server = Server(['one', 'two', 'three'])
|
|
||||||
self.assertEqual(server.bound, ['one', 'two', 'three'])
|
|
||||||
del Server.bound
|
|
||||||
|
|
||||||
Server.BindToAddress = bta
|
|
||||||
|
|
||||||
|
|
||||||
class SocketTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.orgsocket = socket.socket
|
|
||||||
socket.socket = MockSocket
|
|
||||||
self.server = Server()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
socket.socket = self.orgsocket
|
|
||||||
|
|
||||||
def testBind(self):
|
|
||||||
self.server.BindToAddress('192.168.13.13')
|
|
||||||
self.assertEqual(len(self.server.authfds), 1)
|
|
||||||
self.assertEqual(self.server.authfds[0].address,
|
|
||||||
('192.168.13.13', 1812))
|
|
||||||
|
|
||||||
self.assertEqual(len(self.server.acctfds), 1)
|
|
||||||
self.assertEqual(self.server.acctfds[0].address,
|
|
||||||
('192.168.13.13', 1813))
|
|
||||||
|
|
||||||
def testBindv6(self):
|
|
||||||
self.server.BindToAddress('2001:db8:123::1')
|
|
||||||
self.assertEqual(len(self.server.authfds), 1)
|
|
||||||
self.assertEqual(self.server.authfds[0].address,
|
|
||||||
('2001:db8:123::1', 1812))
|
|
||||||
|
|
||||||
self.assertEqual(len(self.server.acctfds), 1)
|
|
||||||
self.assertEqual(self.server.acctfds[0].address,
|
|
||||||
('2001:db8:123::1', 1813))
|
|
||||||
|
|
||||||
def testGrabPacket(self):
|
|
||||||
def gen(data):
|
|
||||||
res = TrivialObject()
|
|
||||||
res.data = data
|
|
||||||
return res
|
|
||||||
|
|
||||||
fd = MockFd()
|
|
||||||
fd.source = object()
|
|
||||||
pkt = self.server._grab_packet(gen, fd)
|
|
||||||
self.failUnless(isinstance(pkt, TrivialObject))
|
|
||||||
self.failUnless(pkt.fd is fd)
|
|
||||||
self.failUnless(pkt.source is fd.source)
|
|
||||||
self.failUnless(pkt.data is fd.data)
|
|
||||||
|
|
||||||
def testPrepareSocketNoFds(self):
|
|
||||||
self.server._poll = MockPoll()
|
|
||||||
self.server._prepare_sockets()
|
|
||||||
|
|
||||||
self.assertEqual(self.server._poll.registry, {})
|
|
||||||
self.assertEqual(self.server._realauthfds, [])
|
|
||||||
self.assertEqual(self.server._realacctfds, [])
|
|
||||||
|
|
||||||
def testPrepareSocketAuthFds(self):
|
|
||||||
self.server._poll = MockPoll()
|
|
||||||
self.server._fdmap = {}
|
|
||||||
self.server.authfds = [MockFd(12), MockFd(14)]
|
|
||||||
self.server._prepare_sockets()
|
|
||||||
|
|
||||||
self.assertEqual(list(self.server._fdmap.keys()), [12, 14])
|
|
||||||
self.assertEqual(
|
|
||||||
self.server._poll.registry,
|
|
||||||
{12: select.POLLIN | select.POLLPRI | select.POLLERR,
|
|
||||||
14: select.POLLIN | select.POLLPRI | select.POLLERR})
|
|
||||||
|
|
||||||
def testPrepareSocketAcctFds(self):
|
|
||||||
self.server._poll = MockPoll()
|
|
||||||
self.server._fdmap = {}
|
|
||||||
self.server.acctfds = [MockFd(12), MockFd(14)]
|
|
||||||
self.server._prepare_sockets()
|
|
||||||
|
|
||||||
self.assertEqual(list(self.server._fdmap.keys()), [12, 14])
|
|
||||||
self.assertEqual(
|
|
||||||
self.server._poll.registry,
|
|
||||||
{12: select.POLLIN | select.POLLPRI | select.POLLERR,
|
|
||||||
14: select.POLLIN | select.POLLPRI | select.POLLERR})
|
|
||||||
|
|
||||||
|
|
||||||
class AuthPacketHandlingTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = Server()
|
|
||||||
self.server.hosts['host'] = TrivialObject()
|
|
||||||
self.server.hosts['host'].secret = 'supersecret'
|
|
||||||
self.packet = TrivialObject()
|
|
||||||
self.packet.code = AccessRequest
|
|
||||||
self.packet.source = ('host', 'port')
|
|
||||||
|
|
||||||
def testHandleAuthPacketUnknownHost(self):
|
|
||||||
self.packet.source = ('stranger', 'port')
|
|
||||||
try:
|
|
||||||
self.server._handle_auth_packet(self.packet)
|
|
||||||
except ServerPacketError as e:
|
|
||||||
self.failUnless('unknown host' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testHandleAuthPacketWrongPort(self):
|
|
||||||
self.packet.code = AccountingRequest
|
|
||||||
try:
|
|
||||||
self.server._handle_auth_packet(self.packet)
|
|
||||||
except ServerPacketError as e:
|
|
||||||
self.failUnless('port' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testHandleAuthPacket(self):
|
|
||||||
def HandleAuthPacket(self, pkt):
|
|
||||||
self.handled = pkt
|
|
||||||
hap = Server.HandleAuthPacket
|
|
||||||
Server.HandleAuthPacket = HandleAuthPacket
|
|
||||||
|
|
||||||
self.server._handle_auth_packet(self.packet)
|
|
||||||
self.failUnless(self.server.handled is self.packet)
|
|
||||||
|
|
||||||
Server.HandleAuthPacket = hap
|
|
||||||
|
|
||||||
|
|
||||||
class AcctPacketHandlingTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = Server()
|
|
||||||
self.server.hosts['host'] = TrivialObject()
|
|
||||||
self.server.hosts['host'].secret = 'supersecret'
|
|
||||||
self.packet = TrivialObject()
|
|
||||||
self.packet.code = AccountingRequest
|
|
||||||
self.packet.source = ('host', 'port')
|
|
||||||
|
|
||||||
def testHandleAcctPacketUnknownHost(self):
|
|
||||||
self.packet.source = ('stranger', 'port')
|
|
||||||
try:
|
|
||||||
self.server._handle_acct_packet(self.packet)
|
|
||||||
except ServerPacketError as e:
|
|
||||||
self.failUnless('unknown host' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testHandleAcctPacketWrongPort(self):
|
|
||||||
self.packet.code = AccessRequest
|
|
||||||
try:
|
|
||||||
self.server._handle_acct_packet(self.packet)
|
|
||||||
except ServerPacketError as e:
|
|
||||||
self.failUnless('port' in str(e))
|
|
||||||
else:
|
|
||||||
self.fail()
|
|
||||||
|
|
||||||
def testHandleAcctPacket(self):
|
|
||||||
def HandleAcctPacket(self, pkt):
|
|
||||||
self.handled = pkt
|
|
||||||
hap = Server.HandleAcctPacket
|
|
||||||
Server.HandleAcctPacket = HandleAcctPacket
|
|
||||||
|
|
||||||
self.server._handle_acct_packet(self.packet)
|
|
||||||
self.failUnless(self.server.handled is self.packet)
|
|
||||||
|
|
||||||
Server.HandleAcctPacket = hap
|
|
||||||
|
|
||||||
|
|
||||||
class OtherTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = Server()
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
UnmockClassMethods(Server)
|
|
||||||
|
|
||||||
def testCreateReplyPacket(self):
|
|
||||||
class TrivialPacket:
|
|
||||||
source = object()
|
|
||||||
|
|
||||||
def CreateReply(self, **kw):
|
|
||||||
reply = TrivialObject()
|
|
||||||
reply.kw = kw
|
|
||||||
return reply
|
|
||||||
|
|
||||||
reply = self.server.CreateReplyPacket(
|
|
||||||
TrivialPacket(),
|
|
||||||
one='one', two='two')
|
|
||||||
self.failUnless(isinstance(reply, TrivialObject))
|
|
||||||
self.failUnless(reply.source is TrivialPacket.source)
|
|
||||||
self.assertEqual(reply.kw, dict(one='one', two='two'))
|
|
||||||
|
|
||||||
def testAuthProcessInput(self):
|
|
||||||
fd = MockFd(1)
|
|
||||||
self.server._realauthfds = [1]
|
|
||||||
MockClassMethod(Server, '_grab_packet')
|
|
||||||
MockClassMethod(Server, '_handle_auth_packet')
|
|
||||||
|
|
||||||
self.server._process_input(fd)
|
|
||||||
self.assertEqual(
|
|
||||||
[x[0] for x in self.server.called],
|
|
||||||
['_grab_packet', '_handle_auth_packet'])
|
|
||||||
self.assertEqual(self.server.called[0][1][1], fd)
|
|
||||||
|
|
||||||
def testAcctProcessInput(self):
|
|
||||||
fd = MockFd(1)
|
|
||||||
self.server._realauthfds = []
|
|
||||||
self.server._realacctfds = [1]
|
|
||||||
MockClassMethod(Server, '_grab_packet')
|
|
||||||
MockClassMethod(Server, '_handle_acct_packet')
|
|
||||||
|
|
||||||
self.server._process_input(fd)
|
|
||||||
self.assertEqual(
|
|
||||||
[x[0] for x in self.server.called],
|
|
||||||
['_grab_packet', '_handle_acct_packet'])
|
|
||||||
self.assertEqual(self.server.called[0][1][1], fd)
|
|
||||||
|
|
||||||
|
|
||||||
class ServerRunTests(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.server = Server()
|
|
||||||
self.origpoll = select.poll
|
|
||||||
select.poll = MockPoll
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
MockPoll.results = []
|
|
||||||
select.poll = self.origpoll
|
|
||||||
UnmockClassMethods(Server)
|
|
||||||
|
|
||||||
def testRunInitializes(self):
|
|
||||||
MockClassMethod(Server, '_prepare_sockets')
|
|
||||||
self.assertRaises(MockFinished, self.server.Run)
|
|
||||||
self.assertEqual(self.server.called, [('_prepare_sockets', (), {})])
|
|
||||||
self.failUnless(isinstance(self.server._fdmap, dict))
|
|
||||||
self.failUnless(isinstance(self.server._poll, MockPoll))
|
|
||||||
|
|
||||||
def testRunIgnoresPollErrors(self):
|
|
||||||
self.server.authfds = [MockFd()]
|
|
||||||
MockPoll.results = [(0, select.POLLERR)]
|
|
||||||
self.assertRaises(MockFinished, self.server.Run)
|
|
||||||
|
|
||||||
def testRunIgnoresServerPacketErrors(self):
|
|
||||||
def RaisePacketError(self, fd):
|
|
||||||
raise ServerPacketError
|
|
||||||
MockClassMethod(Server, '_process_input', RaisePacketError)
|
|
||||||
self.server.authfds = [MockFd()]
|
|
||||||
MockPoll.results = [(0, select.POLLIN)]
|
|
||||||
self.assertRaises(MockFinished, self.server.Run)
|
|
||||||
|
|
||||||
def testRunIgnoresPacketErrors(self):
|
|
||||||
def RaisePacketError(self, fd):
|
|
||||||
raise PacketError
|
|
||||||
MockClassMethod(Server, '_process_input', RaisePacketError)
|
|
||||||
self.server.authfds = [MockFd()]
|
|
||||||
MockPoll.results = [(0, select.POLLIN)]
|
|
||||||
self.assertRaises(MockFinished, self.server.Run)
|
|
||||||
|
|
||||||
def testRunRunsProcessInput(self):
|
|
||||||
MockClassMethod(Server, '_process_input')
|
|
||||||
self.server.authfds = fd = [MockFd()]
|
|
||||||
MockPoll.results = [(0, select.POLLIN)]
|
|
||||||
self.assertRaises(MockFinished, self.server.Run)
|
|
||||||
self.assertEqual(self.server.called, [('_process_input', (fd[0],), {})])
|
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(select, 'poll'):
|
|
||||||
del SocketTests
|
|
||||||
del ServerRunTests
|
|
||||||
@@ -1,119 +0,0 @@
|
|||||||
from ipaddress import AddressValueError
|
|
||||||
from pyrad import tools
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
|
|
||||||
class EncodingTests(unittest.TestCase):
|
|
||||||
def testStringEncoding(self):
|
|
||||||
self.assertRaises(ValueError, tools.EncodeString, 'x' * 254)
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeString('1234567890'),
|
|
||||||
b'1234567890')
|
|
||||||
|
|
||||||
def testInvalidStringEncodingRaisesTypeError(self):
|
|
||||||
self.assertRaises(TypeError, tools.EncodeString, 1)
|
|
||||||
|
|
||||||
def testAddressEncoding(self):
|
|
||||||
self.assertRaises(AddressValueError, tools.EncodeAddress, 'TEST123')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAddress('192.168.0.255'),
|
|
||||||
b'\xc0\xa8\x00\xff')
|
|
||||||
|
|
||||||
def testInvalidAddressEncodingRaisesTypeError(self):
|
|
||||||
self.assertRaises(TypeError, tools.EncodeAddress, 1)
|
|
||||||
|
|
||||||
def testIntegerEncoding(self):
|
|
||||||
self.assertEqual(tools.EncodeInteger(0x01020304), b'\x01\x02\x03\x04')
|
|
||||||
|
|
||||||
def testInteger64Encoding(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeInteger64(0xFFFFFFFFFFFFFFFF), b'\xff' * 8
|
|
||||||
)
|
|
||||||
|
|
||||||
def testUnsignedIntegerEncoding(self):
|
|
||||||
self.assertEqual(tools.EncodeInteger(0xFFFFFFFF), b'\xff\xff\xff\xff')
|
|
||||||
|
|
||||||
def testInvalidIntegerEncodingRaisesTypeError(self):
|
|
||||||
self.assertRaises(TypeError, tools.EncodeInteger, 'ONE')
|
|
||||||
|
|
||||||
def testDateEncoding(self):
|
|
||||||
self.assertEqual(tools.EncodeDate(0x01020304), b'\x01\x02\x03\x04')
|
|
||||||
|
|
||||||
def testInvalidDataEncodingRaisesTypeError(self):
|
|
||||||
self.assertRaises(TypeError, tools.EncodeDate, '1')
|
|
||||||
|
|
||||||
def testEncodeAscendBinary(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAscendBinary('family=ipv4 action=discard direction=in dst=10.10.255.254/32'),
|
|
||||||
b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00')
|
|
||||||
|
|
||||||
def testStringDecoding(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeString(b'1234567890'),
|
|
||||||
'1234567890')
|
|
||||||
|
|
||||||
def testAddressDecoding(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeAddress(b'\xc0\xa8\x00\xff'),
|
|
||||||
'192.168.0.255')
|
|
||||||
|
|
||||||
def testIntegerDecoding(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeInteger(b'\x01\x02\x03\x04'),
|
|
||||||
0x01020304)
|
|
||||||
|
|
||||||
def testInteger64Decoding(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeInteger64(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF
|
|
||||||
)
|
|
||||||
|
|
||||||
def testDateDecoding(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeDate(b'\x01\x02\x03\x04'),
|
|
||||||
0x01020304)
|
|
||||||
|
|
||||||
def testUnknownTypeEncoding(self):
|
|
||||||
self.assertRaises(ValueError, tools.EncodeAttr, 'unknown', None)
|
|
||||||
|
|
||||||
def testUnknownTypeDecoding(self):
|
|
||||||
self.assertRaises(ValueError, tools.DecodeAttr, 'unknown', None)
|
|
||||||
|
|
||||||
def testEncodeFunction(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('string', 'string'),
|
|
||||||
b'string')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('octets', b'string'),
|
|
||||||
b'string')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('ipaddr', '192.168.0.255'),
|
|
||||||
b'\xc0\xa8\x00\xff')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('integer', 0x01020304),
|
|
||||||
b'\x01\x02\x03\x04')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('date', 0x01020304),
|
|
||||||
b'\x01\x02\x03\x04')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('integer64', 0xFFFFFFFFFFFFFFFF),
|
|
||||||
b'\xff'*8)
|
|
||||||
|
|
||||||
def testDecodeFunction(self):
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeAttr('string', b'string'),
|
|
||||||
'string')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.EncodeAttr('octets', b'string'),
|
|
||||||
b'string')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeAttr('ipaddr', b'\xc0\xa8\x00\xff'),
|
|
||||||
'192.168.0.255')
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeAttr('integer', b'\x01\x02\x03\x04'),
|
|
||||||
0x01020304)
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeAttr('integer64', b'\xff'*8),
|
|
||||||
0xFFFFFFFFFFFFFFFF)
|
|
||||||
self.assertEqual(
|
|
||||||
tools.DecodeAttr('date', b'\x01\x02\x03\x04'),
|
|
||||||
0x01020304)
|
|
||||||
226
pyrad3/tools.py
226
pyrad3/tools.py
@@ -1,226 +0,0 @@
|
|||||||
# tools.py
|
|
||||||
#
|
|
||||||
# Utility functions
|
|
||||||
import binascii
|
|
||||||
import ipaddress
|
|
||||||
import struct
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeString(string):
|
|
||||||
if len(string) > 253:
|
|
||||||
raise ValueError('Can only encode strings of <= 253 characters')
|
|
||||||
if isinstance(string, str):
|
|
||||||
return string.encode('utf-8')
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeOctets(string):
|
|
||||||
if len(string) > 253:
|
|
||||||
raise ValueError('Can only encode strings of <= 253 characters')
|
|
||||||
|
|
||||||
if string.startswith(b'0x'):
|
|
||||||
hexstring = string.split(b'0x')[1]
|
|
||||||
return binascii.unhexlify(hexstring)
|
|
||||||
else:
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeAddress(addr):
|
|
||||||
if not isinstance(addr, str):
|
|
||||||
raise TypeError('Address has to be a string')
|
|
||||||
return ipaddress.IPv4Address(addr).packed
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeIPv6Prefix(addr):
|
|
||||||
if not isinstance(addr, str):
|
|
||||||
raise TypeError('IPv6 Prefix has to be a string')
|
|
||||||
ip = ipaddress.IPv6Network(addr)
|
|
||||||
return struct.pack('2B', *[0, ip.prefixlen]) + ip.network_address.packed
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeIPv6Address(addr):
|
|
||||||
if not isinstance(addr, str):
|
|
||||||
raise TypeError('IPv6 Address has to be a string')
|
|
||||||
return ipaddress.IPv6Address(addr).packed
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeAscendBinary(string):
|
|
||||||
"""
|
|
||||||
Format: List of type=value pairs sperated by spaces.
|
|
||||||
|
|
||||||
Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'
|
|
||||||
|
|
||||||
Type:
|
|
||||||
family ipv4(default) or ipv6
|
|
||||||
action discard(default) or accept
|
|
||||||
direction in(default) or out
|
|
||||||
src source prefix (default ignore)
|
|
||||||
dst destination prefix (default ignore)
|
|
||||||
proto protocol number / next-header number (default ignore)
|
|
||||||
sport source port (default ignore)
|
|
||||||
dport destination port (default ignore)
|
|
||||||
sportq source port qualifier (default 0)
|
|
||||||
dportq destination port qualifier (default 0)
|
|
||||||
|
|
||||||
Source/Destination Port Qualifier:
|
|
||||||
0 no compare
|
|
||||||
1 less than
|
|
||||||
2 equal to
|
|
||||||
3 greater than
|
|
||||||
4 not equal to
|
|
||||||
"""
|
|
||||||
|
|
||||||
terms = {
|
|
||||||
'family': b'\x01',
|
|
||||||
'action': b'\x00',
|
|
||||||
'direction': b'\x01',
|
|
||||||
'src': b'\x00\x00\x00\x00',
|
|
||||||
'dst': b'\x00\x00\x00\x00',
|
|
||||||
'srcl': b'\x00',
|
|
||||||
'dstl': b'\x00',
|
|
||||||
'proto': b'\x00',
|
|
||||||
'sport': b'\x00\x00',
|
|
||||||
'dport': b'\x00\x00',
|
|
||||||
'sportq': b'\x00',
|
|
||||||
'dportq': b'\x00'
|
|
||||||
}
|
|
||||||
|
|
||||||
for t in string.split(' '):
|
|
||||||
key, value = t.split('=')
|
|
||||||
if key == 'family' and value == 'ipv6':
|
|
||||||
terms[key] = b'\x03'
|
|
||||||
if terms['src'] == b'\x00\x00\x00\x00':
|
|
||||||
terms['src'] = 16 * b'\x00'
|
|
||||||
if terms['dst'] == b'\x00\x00\x00\x00':
|
|
||||||
terms['dst'] = 16 * b'\x00'
|
|
||||||
elif key == 'action' and value == 'accept':
|
|
||||||
terms[key] = b'\x01'
|
|
||||||
elif key == 'direction' and value == 'out':
|
|
||||||
terms[key] = b'\x00'
|
|
||||||
elif key in ('src', 'dst'):
|
|
||||||
ip = ipaddress.ip_network(value)
|
|
||||||
terms[key] = ip.network_address.packed
|
|
||||||
terms[key+'l'] = struct.pack('B', ip.prefixlen)
|
|
||||||
elif key in ('sport', 'dport'):
|
|
||||||
terms[key] = struct.pack('!H', int(value))
|
|
||||||
elif key in ('sportq', 'dportq', 'proto'):
|
|
||||||
terms[key] = struct.pack('B', int(value))
|
|
||||||
|
|
||||||
trailer = 8 * b'\x00'
|
|
||||||
|
|
||||||
result = b''.join((
|
|
||||||
terms['family'], terms['action'], terms['direction'], b'\x00',
|
|
||||||
terms['src'], terms['dst'], terms['srcl'], terms['dstl'], terms['proto'], b'\x00',
|
|
||||||
terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], b'\x00\x00', trailer))
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeInteger(num, format='!I'):
|
|
||||||
try:
|
|
||||||
num = int(num)
|
|
||||||
except ValueError:
|
|
||||||
raise TypeError('Can not encode non-integer as integer')
|
|
||||||
return struct.pack(format, num)
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeInteger64(num, format='!Q'):
|
|
||||||
try:
|
|
||||||
num = int(num)
|
|
||||||
except ValueError:
|
|
||||||
raise TypeError('Can not encode non-integer as integer64')
|
|
||||||
return struct.pack(format, num)
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeDate(num):
|
|
||||||
if not isinstance(num, int):
|
|
||||||
raise TypeError('Can not encode non-integer as date')
|
|
||||||
return struct.pack('!I', num)
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeString(string):
|
|
||||||
try:
|
|
||||||
return string.decode('utf-8')
|
|
||||||
except:
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeOctets(string):
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeAddress(addr):
|
|
||||||
return '.'.join((str(a) for a in struct.unpack('BBBB', addr)))
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeIPv6Prefix(addr):
|
|
||||||
addr = addr + b'\x00' * (18-len(addr))
|
|
||||||
prefix = addr[:2]
|
|
||||||
addr = addr[2:]
|
|
||||||
return str(ipaddress.ip_network((prefix, addr)))
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeIPv6Address(addr):
|
|
||||||
addr = addr + b'\x00' * (16-len(addr))
|
|
||||||
return str(ipaddress.IPv6Address(addr))
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeAscendBinary(string):
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeInteger(num, format='!I'):
|
|
||||||
return (struct.unpack(format, num))[0]
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeInteger64(num, format='!Q'):
|
|
||||||
return (struct.unpack(format, num))[0]
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeDate(num):
|
|
||||||
return (struct.unpack('!I', num))[0]
|
|
||||||
|
|
||||||
|
|
||||||
ENCODE_MAP = {
|
|
||||||
'string': EncodeString,
|
|
||||||
'octets': EncodeOctets,
|
|
||||||
'integer': EncodeInteger,
|
|
||||||
'ipaddr': EncodeAddress,
|
|
||||||
'ipv6prefix': EncodeIPv6Prefix,
|
|
||||||
'ipv6addr': EncodeIPv6Address,
|
|
||||||
'abinary': EncodeAscendBinary,
|
|
||||||
'signed': lambda value: EncodeInteger(value, '!i'),
|
|
||||||
'short': lambda value: EncodeInteger(value, '!H'),
|
|
||||||
'byte': lambda value: EncodeInteger(value, '!B'),
|
|
||||||
'date': EncodeDate,
|
|
||||||
'integer64': EncodeInteger64,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def EncodeAttr(datatype, value):
|
|
||||||
try:
|
|
||||||
return ENCODE_MAP[datatype](value)
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f'Unknown attribute type {datatype}')
|
|
||||||
|
|
||||||
|
|
||||||
DECODE_MAP = {
|
|
||||||
'string': DecodeString,
|
|
||||||
'octets': DecodeOctets,
|
|
||||||
'integer': DecodeInteger,
|
|
||||||
'ipaddr': DecodeAddress,
|
|
||||||
'ipv6prefix': DecodeIPv6Prefix,
|
|
||||||
'ipv6addr': DecodeIPv6Address,
|
|
||||||
'abinary': DecodeAscendBinary,
|
|
||||||
'signed': lambda value: DecodeInteger(value, '!i'),
|
|
||||||
'short': lambda value: DecodeInteger(value, '!H'),
|
|
||||||
'byte': lambda value: DecodeInteger(value, '!B'),
|
|
||||||
'date': DecodeDate,
|
|
||||||
'integer64': DecodeInteger64,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def DecodeAttr(datatype, value):
|
|
||||||
try:
|
|
||||||
return DECODE_MAP[datatype](value)
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f'Unknown attribute type {datatype}')
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
from pyrad.packet import Packet, random_generator
|
|
||||||
|
|
||||||
|
|
||||||
def salt_encrypt(packet: Packet, value: bytes) -> bytes:
|
|
||||||
length = struct.pack('B', len(value))
|
|
||||||
buf = length + value
|
|
||||||
buf += b'\x00' * (16 - (len(buf) % 16))
|
|
||||||
|
|
||||||
# First bit if the random value must be 1
|
|
||||||
random_value = 32768 + random_generator.randrange(0, 32767)
|
|
||||||
result = struct.pack('!H', random_value)
|
|
||||||
|
|
||||||
last = packet.authenticator + result
|
|
||||||
while buf:
|
|
||||||
cur_hash = md5_constructor(packet.secret + last).digest()
|
|
||||||
for b, h in zip(buf, cur_hash):
|
|
||||||
result += bytes([b ^ h])
|
|
||||||
last = result[-16:]
|
|
||||||
buf = buf[16:]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def validate_chap_password(packet: Packet, password: bytes) -> bool:
|
|
||||||
# TODO:
|
|
||||||
challange = packet.get('CHAP-Challenge', packet.authenticator)
|
|
||||||
return chap_password == hashlib.md5(chapid + password + challenge).digest()
|
|
||||||
|
|
||||||
|
|
||||||
def validate_pap_password(packet: Packet, password: bytes) -> bool:
|
|
||||||
obf_pass = password_encode(packet, password)
|
|
||||||
return packet['Password'] == obf_pass
|
|
||||||
|
|
||||||
|
|
||||||
def password_encode(packet: Packet, password: bytes) -> bytes:
|
|
||||||
password += b'\x00' * (16 - (len(password) % 16))
|
|
||||||
return obfuscation_algorithm(packet, password)
|
|
||||||
|
|
||||||
|
|
||||||
def password_decode(packet: Packet, password: bytes) -> bytes:
|
|
||||||
decoded = obfuscation_algorithm(packet, password)
|
|
||||||
return decoded.rstrip(b'\x00')
|
|
||||||
|
|
||||||
|
|
||||||
def obfuscation_algorithm(packet: Packet, password: bytes) -> bytes:
|
|
||||||
result = b''
|
|
||||||
buf = password
|
|
||||||
last = packet.authenticator
|
|
||||||
|
|
||||||
while buf:
|
|
||||||
cur_hash = md5_constructor(packet.secret + last)
|
|
||||||
for b, h in zip(buf, cur_hash):
|
|
||||||
result += bytes([b ^ h])
|
|
||||||
(last, buf) = (buf[:16], buf[16:])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
19
shell.nix
19
shell.nix
@@ -1,5 +1,20 @@
|
|||||||
let
|
let
|
||||||
pkgs = import <nixpkgs> {};
|
pkgs = import <nixpkgs> {};
|
||||||
python = pkgs.python36;
|
|
||||||
in
|
in
|
||||||
import ./default.nix { inherit pkgs python; }
|
pkgs.mkShell {
|
||||||
|
buildInputs = with pkgs; [
|
||||||
|
git
|
||||||
|
nixfmt
|
||||||
|
pkgs.python3
|
||||||
|
pkgs.python3Packages.black
|
||||||
|
pkgs.python3Packages.pytest
|
||||||
|
pkgs.python3Packages.pytest-black
|
||||||
|
pkgs.python3Packages.pytest-flake8
|
||||||
|
pkgs.python3Packages.pytest-mypy
|
||||||
|
pkgs.python3Packages.pytest-pylint
|
||||||
|
];
|
||||||
|
|
||||||
|
shellHook = ''
|
||||||
|
export PYTHONPATH=${./src}:$PYTHONPATH
|
||||||
|
'';
|
||||||
|
}
|
||||||
|
|||||||
@@ -38,9 +38,9 @@ This package contains four modules:
|
|||||||
|
|
||||||
__docformat__ = 'epytext en'
|
__docformat__ = 'epytext en'
|
||||||
|
|
||||||
__author__ = 'Christian Giese <developer@gicnet.de>'
|
__author__ = 'Istvan Ruzman <istvan@ruzman.eu>'
|
||||||
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
|
__url__ = 'http://pyrad.readthedocs.io/en/latest/?badge=latest'
|
||||||
__copyright__ = 'Copyright 2002-2020 Wichert Akkerman and Christian Giese. All rights reserved.'
|
__copyright__ = 'Copyright 2020 Istvan Ruzman'
|
||||||
__version__ = '2.3'
|
__version__ = '0.1.0'
|
||||||
|
|
||||||
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'dictfile', 'new_client', 'new_host', 'new_packet']
|
__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'utils']
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
# Bidirectional map
|
# Bidirectional map
|
||||||
|
|
||||||
|
|
||||||
class BiDict():
|
class BiDict:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.forward = {}
|
self.forward = {}
|
||||||
self.backward = {}
|
self.backward = {}
|
||||||
@@ -1,63 +1,94 @@
|
|||||||
# Copyright 2020 Istvan Ruzman
|
# Copyright 2020 Istvan Ruzman
|
||||||
# SPDX-License-Identifier: MIT OR Apache-2.0
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
"""Implementation of a simple but extensible RADIUS Client"""
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
from ipaddress import IPv4Address, IPv6Address
|
||||||
|
|
||||||
import select
|
import select
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from pyrad3 import new_packet as P
|
import pyrad3.packet as P
|
||||||
from pyrad3 import new_host
|
from pyrad3 import host
|
||||||
|
|
||||||
SUPPORTED_SEND_TYPES = [
|
SUPPORTED_SEND_TYPES = [
|
||||||
P.AccessRequest,
|
P.Code.AccessRequest,
|
||||||
P.AccountingRequest,
|
P.Code.AccountingRequest,
|
||||||
P.CoARequest,
|
P.Code.CoARequest,
|
||||||
]
|
]
|
||||||
|
|
||||||
PACKET_TYPE_PORT_MAPPING = {
|
PACKET_TYPE_PORT_MAPPING = {
|
||||||
P.AccessRequest: 'authport',
|
P.Code.AccessRequest: "authport",
|
||||||
P.AccountingRequest: 'acctport',
|
P.Code.AccountingRequest: "acctport",
|
||||||
P.CoARequest: 'coaport',
|
P.Code.CoARequest: "coaport",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Timeout(Exception):
|
class Timeout(Exception):
|
||||||
pass
|
"""Exception for wait timeouts"""
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedPacketType(Exception):
|
class UnsupportedPacketType(Exception):
|
||||||
pass
|
"""Exception for received packets"""
|
||||||
|
|
||||||
|
|
||||||
class Client(new_host.Host):
|
class Client(host.Host):
|
||||||
def __init__(self, server, secret, radius_dictionary, **kwargs):
|
"""A simple and extensible RADIUS Client."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server: Union[str, IPv4Address, IPv6Address],
|
||||||
|
secret: bytes,
|
||||||
|
radius_dictionary: dict,
|
||||||
|
interface: Optional[str],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
super().__init__(secret, radius_dictionary, **kwargs)
|
super().__init__(secret, radius_dictionary, **kwargs)
|
||||||
self.server = server
|
self.server = server
|
||||||
self._socket = None
|
self.interface = interface
|
||||||
self._poll = None
|
self._socket: Optional[socket.socket] = None
|
||||||
|
self._poll: Optional[select.poll] = None
|
||||||
|
|
||||||
def bind(self, addr):
|
def bind(self, addr):
|
||||||
|
"""Bind the Address to some socket"""
|
||||||
self._socket_close()
|
self._socket_close()
|
||||||
self._socket_open()
|
self._socket_open()
|
||||||
self._socket.bind(addr)
|
self._socket.bind(addr)
|
||||||
|
|
||||||
def _socket_open(self):
|
def _socket_open(self):
|
||||||
|
"""Open a client socket"""
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
family = socket.getaddrinfo(self.server, 'www')[0][0]
|
family = socket.getaddrinfo(self.server, "www")[0][0]
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
family = socket.AF_INET
|
family = socket.AF_INET
|
||||||
self._socket = socket.socket(family, socket.SOCK_DGRAM)
|
self._socket = socket.socket(family, socket.SOCK_DGRAM)
|
||||||
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
if self.interface is not None:
|
||||||
|
# This will fail on non-Linux systems - that's ok, because we don't
|
||||||
|
# have any implementation yet for non-Linux systems.
|
||||||
|
# Better to fail loudly than to do something unexpected silently.
|
||||||
|
self._socket.setsockopt(
|
||||||
|
socket.SOL_SOCKET,
|
||||||
|
socket.SO_BINDTODEVICE,
|
||||||
|
self.interface,
|
||||||
|
len(self.interface),
|
||||||
|
)
|
||||||
self._poll = select.poll()
|
self._poll = select.poll()
|
||||||
self._poll.register(self._socket, select.POLLIN)
|
self._poll.register(self._socket, select.POLLIN)
|
||||||
|
|
||||||
def _socket_close(self):
|
def _socket_close(self):
|
||||||
|
"""Close the Client socket"""
|
||||||
if self._socket is not None:
|
if self._socket is not None:
|
||||||
self._poll.unregister(self._socket)
|
self._poll.unregister(self._socket)
|
||||||
self._socket.close()
|
self._socket.close()
|
||||||
self._socket = None
|
self._socket = None
|
||||||
|
|
||||||
def _select_port(self, packet):
|
def _select_port(self, packet: P.Packet):
|
||||||
|
"""Select the RADIUS Port depending on the RADIUS Packet type"""
|
||||||
try:
|
try:
|
||||||
port_type = PACKET_TYPE_PORT_MAPPING[packet.code]
|
port_type = PACKET_TYPE_PORT_MAPPING[packet.code]
|
||||||
return getattr(self, port_type)
|
return getattr(self, port_type)
|
||||||
@@ -65,7 +96,11 @@ class Client(new_host.Host):
|
|||||||
pass
|
pass
|
||||||
raise UnsupportedPacketType(f"The packet type {packet.code} by Client")
|
raise UnsupportedPacketType(f"The packet type {packet.code} by Client")
|
||||||
|
|
||||||
def _send_packet(self, packet):
|
def _send_packet(self, packet: P.Packet):
|
||||||
|
"""Send a RADIUS Packet and wait for the reply"""
|
||||||
|
assert self._socket is not None
|
||||||
|
assert self._poll is not None
|
||||||
|
|
||||||
port = self._select_port(packet)
|
port = self._select_port(packet)
|
||||||
|
|
||||||
raw_packet = packet.serialize()
|
raw_packet = packet.serialize()
|
||||||
@@ -13,9 +13,7 @@ from pyrad.packet import Packet, AuthPacket, AcctPacket, CoAPacket
|
|||||||
|
|
||||||
|
|
||||||
class DatagramProtocolClient(asyncio.Protocol):
|
class DatagramProtocolClient(asyncio.Protocol):
|
||||||
|
def __init__(self, server, port, logger, client, retries=3, timeout=30):
|
||||||
def __init__(self, server, port, logger,
|
|
||||||
client, retries=3, timeout=30):
|
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self.port = port
|
self.port = port
|
||||||
self.server = server
|
self.server = server
|
||||||
@@ -42,22 +40,31 @@ class DatagramProtocolClient(asyncio.Protocol):
|
|||||||
# noinspection PyShadowingBuiltins
|
# noinspection PyShadowingBuiltins
|
||||||
for id, req in self.pending_requests.items():
|
for id, req in self.pending_requests.items():
|
||||||
|
|
||||||
secs = (req['send_date'] - now).seconds
|
secs = (req["send_date"] - now).seconds
|
||||||
if secs > self.timeout:
|
if secs > self.timeout:
|
||||||
if req['retries'] == self.retries:
|
if req["retries"] == self.retries:
|
||||||
self.logger.debug('[%s:%d] For request %d execute all retries',
|
self.logger.debug(
|
||||||
self.server, self.port, id)
|
"[%s:%d] For request %d execute all retries",
|
||||||
req['future'].set_exception(
|
self.server,
|
||||||
TimeoutError('Timeout on Reply')
|
self.port,
|
||||||
|
id,
|
||||||
|
)
|
||||||
|
req["future"].set_exception(
|
||||||
|
TimeoutError("Timeout on Reply")
|
||||||
)
|
)
|
||||||
req2delete.append(id)
|
req2delete.append(id)
|
||||||
else:
|
else:
|
||||||
# Send again packet
|
# Send again packet
|
||||||
req['send_date'] = now
|
req["send_date"] = now
|
||||||
req['retries'] += 1
|
req["retries"] += 1
|
||||||
self.logger.debug('[%s:%d] For request %d execute retry %d',
|
self.logger.debug(
|
||||||
self.server, self.port, id, req['retries'])
|
"[%s:%d] For request %d execute retry %d",
|
||||||
self.transport.sendto(req['packet'].RequestPacket())
|
self.server,
|
||||||
|
self.port,
|
||||||
|
id,
|
||||||
|
req["retries"],
|
||||||
|
)
|
||||||
|
self.transport.sendto(req["packet"].RequestPacket())
|
||||||
elif next_weak_up > secs:
|
elif next_weak_up > secs:
|
||||||
next_weak_up = secs
|
next_weak_up = secs
|
||||||
|
|
||||||
@@ -72,15 +79,15 @@ class DatagramProtocolClient(asyncio.Protocol):
|
|||||||
|
|
||||||
def send_packet(self, packet, future):
|
def send_packet(self, packet, future):
|
||||||
if packet.id in self.pending_requests:
|
if packet.id in self.pending_requests:
|
||||||
raise Exception(f'Packet with id {packet.id} already present')
|
raise Exception(f"Packet with id {packet.id} already present")
|
||||||
|
|
||||||
# Store packet on pending requests map
|
# Store packet on pending requests map
|
||||||
self.pending_requests[packet.id] = {
|
self.pending_requests[packet.id] = {
|
||||||
'packet': packet,
|
"packet": packet,
|
||||||
'creation_date': datetime.now(),
|
"creation_date": datetime.now(),
|
||||||
'retries': 0,
|
"retries": 0,
|
||||||
'future': future,
|
"future": future,
|
||||||
'send_date': datetime.now()
|
"send_date": datetime.now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# In queue packet raw on socket buffer
|
# In queue packet raw on socket buffer
|
||||||
@@ -88,10 +95,11 @@ class DatagramProtocolClient(asyncio.Protocol):
|
|||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
socket = transport.get_extra_info('socket')
|
socket = transport.get_extra_info("socket")
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
'[%s:%d] Transport created with binding in %s:%d',
|
"[%s:%d] Transport created with binding in %s:%d",
|
||||||
self.server, self.port,
|
self.server,
|
||||||
|
self.port,
|
||||||
socket.getsockname()[0],
|
socket.getsockname()[0],
|
||||||
socket.getsockname()[1],
|
socket.getsockname()[1],
|
||||||
)
|
)
|
||||||
@@ -99,38 +107,47 @@ class DatagramProtocolClient(asyncio.Protocol):
|
|||||||
pre_loop = asyncio.get_event_loop()
|
pre_loop = asyncio.get_event_loop()
|
||||||
asyncio.set_event_loop(loop=self.client.loop)
|
asyncio.set_event_loop(loop=self.client.loop)
|
||||||
# Start asynchronous timer handler
|
# Start asynchronous timer handler
|
||||||
self.timeout_future = asyncio.ensure_future(
|
self.timeout_future = asyncio.ensure_future(self.__timeout_handler__())
|
||||||
self.__timeout_handler__()
|
|
||||||
)
|
|
||||||
asyncio.set_event_loop(loop=pre_loop)
|
asyncio.set_event_loop(loop=pre_loop)
|
||||||
|
|
||||||
def error_received(self, exc):
|
def error_received(self, exc):
|
||||||
self.logger.error('[%s:%d] Error received: %s', self.server, self.port, exc)
|
self.logger.error(
|
||||||
|
"[%s:%d] Error received: %s", self.server, self.port, exc
|
||||||
|
)
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
if exc:
|
if exc:
|
||||||
self.logger.warn('[%s:%d] Connection lost: %s', self.server, self.port, str(exc))
|
self.logger.warn(
|
||||||
|
"[%s:%d] Connection lost: %s", self.server, self.port, str(exc)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.info('[%s:%d] Transport closed', self.server, self.port)
|
self.logger.info("[%s:%d] Transport closed", self.server, self.port)
|
||||||
|
|
||||||
# noinspection PyUnusedLocal
|
# noinspection PyUnusedLocal
|
||||||
def datagram_received(self, data, addr):
|
def datagram_received(self, data, addr):
|
||||||
try:
|
try:
|
||||||
req = self.pending_requests[data[0]]
|
req = self.pending_requests[data[0]]
|
||||||
reply = req.VerifyPacket(data)
|
reply = req.VerifyPacket(data)
|
||||||
req['future'].set_result(reply)
|
req["future"].set_result(reply)
|
||||||
# Remove request for map
|
# Remove request for map
|
||||||
del self.pending_requests[reply.id]
|
del self.pending_requests[reply.id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
self.logger.warn('[%s:%d] Ignore invalid reply: %s',
|
self.logger.warn(
|
||||||
self.server, self.port, data)
|
"[%s:%d] Ignore invalid reply: %s", self.server, self.port, data
|
||||||
|
)
|
||||||
except PacketError as exc:
|
except PacketError as exc:
|
||||||
self.logger.error('[%s:%d] Error on decode or verify packet: %s',
|
self.logger.error(
|
||||||
self.server, self.port, exc)
|
"[%s:%d] Error on decode or verify packet: %s",
|
||||||
|
self.server,
|
||||||
|
self.port,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
async def close_transport(self):
|
async def close_transport(self):
|
||||||
if self.transport:
|
if self.transport:
|
||||||
self.logger.debug('[%s:%d] Closing transport...', self.server, self.port)
|
self.logger.debug(
|
||||||
|
"[%s:%d] Closing transport...", self.server, self.port
|
||||||
|
)
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
self.transport = None
|
self.transport = None
|
||||||
if self.timeout_future:
|
if self.timeout_future:
|
||||||
@@ -143,7 +160,9 @@ class DatagramProtocolClient(asyncio.Protocol):
|
|||||||
return self.packet_id
|
return self.packet_id
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'DatagramProtocolClient(server?={self.server}, port={self.port})'
|
return (
|
||||||
|
f"DatagramProtocolClient(server?={self.server}, port={self.port})"
|
||||||
|
)
|
||||||
|
|
||||||
# Used as protocol_factory
|
# Used as protocol_factory
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
@@ -161,11 +180,21 @@ class ClientAsync:
|
|||||||
:ivar timeout: number of seconds to wait for an answer
|
:ivar timeout: number of seconds to wait for an answer
|
||||||
:type timeout: integer
|
:type timeout: integer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# noinspection PyShadowingBuiltins
|
# noinspection PyShadowingBuiltins
|
||||||
def __init__(self, server, auth_port=1812, acct_port=1813,
|
def __init__(
|
||||||
coa_port=3799, secret=b'', dict=None,
|
self,
|
||||||
loop=None, retries=3, timeout=30,
|
server,
|
||||||
logger_name='pyrad'):
|
auth_port=1812,
|
||||||
|
acct_port=1813,
|
||||||
|
coa_port=3799,
|
||||||
|
secret=b"",
|
||||||
|
dict=None,
|
||||||
|
loop=None,
|
||||||
|
retries=3,
|
||||||
|
timeout=30,
|
||||||
|
logger_name="pyrad",
|
||||||
|
):
|
||||||
|
|
||||||
"""Constructor.
|
"""Constructor.
|
||||||
|
|
||||||
@@ -205,23 +234,30 @@ class ClientAsync:
|
|||||||
self.protocol_coa = None
|
self.protocol_coa = None
|
||||||
self.coa_port = coa_port
|
self.coa_port = coa_port
|
||||||
|
|
||||||
async def initialize_transports(self, enable_acct=False,
|
async def initialize_transports(
|
||||||
enable_auth=False, enable_coa=False,
|
self,
|
||||||
local_addr=None, local_auth_port=None,
|
enable_acct=False,
|
||||||
local_acct_port=None, local_coa_port=None):
|
enable_auth=False,
|
||||||
|
enable_coa=False,
|
||||||
|
local_addr=None,
|
||||||
|
local_auth_port=None,
|
||||||
|
local_acct_port=None,
|
||||||
|
local_coa_port=None,
|
||||||
|
):
|
||||||
|
|
||||||
task_list = []
|
task_list = []
|
||||||
|
|
||||||
if not enable_acct and not enable_auth and not enable_coa:
|
if not enable_acct and not enable_auth and not enable_coa:
|
||||||
raise Exception('No transports selected')
|
raise Exception("No transports selected")
|
||||||
|
|
||||||
if enable_acct and not self.protocol_acct:
|
if enable_acct and not self.protocol_acct:
|
||||||
self.protocol_acct = DatagramProtocolClient(
|
self.protocol_acct = DatagramProtocolClient(
|
||||||
self.server,
|
self.server,
|
||||||
self.acct_port,
|
self.acct_port,
|
||||||
self.logger, self,
|
self.logger,
|
||||||
|
self,
|
||||||
retries=self.retries,
|
retries=self.retries,
|
||||||
timeout=self.timeout
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
bind_addr = None
|
bind_addr = None
|
||||||
if local_addr and local_acct_port:
|
if local_addr and local_acct_port:
|
||||||
@@ -231,7 +267,7 @@ class ClientAsync:
|
|||||||
self.protocol_acct,
|
self.protocol_acct,
|
||||||
reuse_port=True,
|
reuse_port=True,
|
||||||
remote_addr=(self.server, self.acct_port),
|
remote_addr=(self.server, self.acct_port),
|
||||||
local_addr=bind_addr
|
local_addr=bind_addr,
|
||||||
)
|
)
|
||||||
task_list.append(acct_connect)
|
task_list.append(acct_connect)
|
||||||
|
|
||||||
@@ -239,9 +275,10 @@ class ClientAsync:
|
|||||||
self.protocol_auth = DatagramProtocolClient(
|
self.protocol_auth = DatagramProtocolClient(
|
||||||
self.server,
|
self.server,
|
||||||
self.auth_port,
|
self.auth_port,
|
||||||
self.logger, self,
|
self.logger,
|
||||||
|
self,
|
||||||
retries=self.retries,
|
retries=self.retries,
|
||||||
timeout=self.timeout
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
bind_addr = None
|
bind_addr = None
|
||||||
if local_addr and local_auth_port:
|
if local_addr and local_auth_port:
|
||||||
@@ -251,7 +288,7 @@ class ClientAsync:
|
|||||||
self.protocol_auth,
|
self.protocol_auth,
|
||||||
reuse_port=True,
|
reuse_port=True,
|
||||||
remote_addr=(self.server, self.auth_port),
|
remote_addr=(self.server, self.auth_port),
|
||||||
local_addr=bind_addr
|
local_addr=bind_addr,
|
||||||
)
|
)
|
||||||
task_list.append(auth_connect)
|
task_list.append(auth_connect)
|
||||||
|
|
||||||
@@ -259,9 +296,10 @@ class ClientAsync:
|
|||||||
self.protocol_coa = DatagramProtocolClient(
|
self.protocol_coa = DatagramProtocolClient(
|
||||||
self.server,
|
self.server,
|
||||||
self.coa_port,
|
self.coa_port,
|
||||||
self.logger, self,
|
self.logger,
|
||||||
|
self,
|
||||||
retries=self.retries,
|
retries=self.retries,
|
||||||
timeout=self.timeout
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
bind_addr = None
|
bind_addr = None
|
||||||
if local_addr and local_coa_port:
|
if local_addr and local_coa_port:
|
||||||
@@ -271,22 +309,18 @@ class ClientAsync:
|
|||||||
self.protocol_coa,
|
self.protocol_coa,
|
||||||
reuse_port=True,
|
reuse_port=True,
|
||||||
remote_addr=(self.server, self.coa_port),
|
remote_addr=(self.server, self.coa_port),
|
||||||
local_addr=bind_addr
|
local_addr=bind_addr,
|
||||||
)
|
)
|
||||||
task_list.append(coa_connect)
|
task_list.append(coa_connect)
|
||||||
|
|
||||||
await asyncio.ensure_future(
|
await asyncio.ensure_future(
|
||||||
asyncio.gather(
|
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
|
||||||
*task_list,
|
|
||||||
return_exceptions=False,
|
|
||||||
),
|
|
||||||
loop=self.loop
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection SpellCheckingInspection
|
# noinspection SpellCheckingInspection
|
||||||
async def deinitialize_transports(self, deinit_coa=True,
|
async def deinitialize_transports(
|
||||||
deinit_auth=True,
|
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
|
||||||
deinit_acct=True):
|
):
|
||||||
if self.protocol_coa and deinit_coa:
|
if self.protocol_coa and deinit_coa:
|
||||||
await self.protocol_coa.close_transport()
|
await self.protocol_coa.close_transport()
|
||||||
del self.protocol_coa
|
del self.protocol_coa
|
||||||
@@ -312,11 +346,14 @@ class ClientAsync:
|
|||||||
:rtype: pyrad.packet.Packet
|
:rtype: pyrad.packet.Packet
|
||||||
"""
|
"""
|
||||||
if not self.protocol_auth:
|
if not self.protocol_auth:
|
||||||
raise Exception('Transport not initialized')
|
raise Exception("Transport not initialized")
|
||||||
|
|
||||||
return AuthPacket(dict=self.dict,
|
return AuthPacket(
|
||||||
id=self.protocol_auth.create_id(),
|
dict=self.dict,
|
||||||
secret=self.secret, **args)
|
id=self.protocol_auth.create_id(),
|
||||||
|
secret=self.secret,
|
||||||
|
**args,
|
||||||
|
)
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
# noinspection PyPep8Naming
|
||||||
def CreateAcctPacket(self, **args):
|
def CreateAcctPacket(self, **args):
|
||||||
@@ -330,11 +367,14 @@ class ClientAsync:
|
|||||||
:rtype: pyrad.packet.Packet
|
:rtype: pyrad.packet.Packet
|
||||||
"""
|
"""
|
||||||
if not self.protocol_acct:
|
if not self.protocol_acct:
|
||||||
raise Exception('Transport not initialized')
|
raise Exception("Transport not initialized")
|
||||||
|
|
||||||
return AcctPacket(id=self.protocol_acct.create_id(),
|
return AcctPacket(
|
||||||
dict=self.dict,
|
id=self.protocol_acct.create_id(),
|
||||||
secret=self.secret, **args)
|
dict=self.dict,
|
||||||
|
secret=self.secret,
|
||||||
|
**args,
|
||||||
|
)
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
# noinspection PyPep8Naming
|
||||||
def CreateCoAPacket(self, **args):
|
def CreateCoAPacket(self, **args):
|
||||||
@@ -349,20 +389,22 @@ class ClientAsync:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if not self.protocol_acct:
|
if not self.protocol_acct:
|
||||||
raise Exception('Transport not initialized')
|
raise Exception("Transport not initialized")
|
||||||
|
|
||||||
return CoAPacket(id=self.protocol_coa.create_id(),
|
return CoAPacket(
|
||||||
dict=self.dict,
|
id=self.protocol_coa.create_id(),
|
||||||
secret=self.secret, **args)
|
dict=self.dict,
|
||||||
|
secret=self.secret,
|
||||||
|
**args,
|
||||||
|
)
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
# noinspection PyPep8Naming
|
||||||
# noinspection PyShadowingBuiltins
|
# noinspection PyShadowingBuiltins
|
||||||
def CreatePacket(self, id, **args):
|
def CreatePacket(self, id, **args):
|
||||||
if not id:
|
if not id:
|
||||||
raise Exception('Missing mandatory packet id')
|
raise Exception("Missing mandatory packet id")
|
||||||
|
|
||||||
return Packet(id=id, dict=self.dict,
|
return Packet(id=id, dict=self.dict, secret=self.secret, **args)
|
||||||
secret=self.secret, **args)
|
|
||||||
|
|
||||||
# noinspection PyPep8Naming
|
# noinspection PyPep8Naming
|
||||||
def SendPacket(self, pkt):
|
def SendPacket(self, pkt):
|
||||||
@@ -378,23 +420,23 @@ class ClientAsync:
|
|||||||
|
|
||||||
if isinstance(pkt, AuthPacket):
|
if isinstance(pkt, AuthPacket):
|
||||||
if not self.protocol_auth:
|
if not self.protocol_auth:
|
||||||
raise Exception('Transport not initialized')
|
raise Exception("Transport not initialized")
|
||||||
|
|
||||||
self.protocol_auth.send_packet(pkt, ans)
|
self.protocol_auth.send_packet(pkt, ans)
|
||||||
|
|
||||||
elif isinstance(pkt, AcctPacket):
|
elif isinstance(pkt, AcctPacket):
|
||||||
if not self.protocol_acct:
|
if not self.protocol_acct:
|
||||||
raise Exception('Transport not initialized')
|
raise Exception("Transport not initialized")
|
||||||
|
|
||||||
self.protocol_acct.send_packet(pkt, ans)
|
self.protocol_acct.send_packet(pkt, ans)
|
||||||
|
|
||||||
elif isinstance(pkt, CoAPacket):
|
elif isinstance(pkt, CoAPacket):
|
||||||
if not self.protocol_coa:
|
if not self.protocol_coa:
|
||||||
raise Exception('Transport not initialized')
|
raise Exception("Transport not initialized")
|
||||||
|
|
||||||
self.protocol_coa.send_packet(pkt, ans)
|
self.protocol_coa.send_packet(pkt, ans)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Unsupported packet')
|
raise Exception("Unsupported packet")
|
||||||
|
|
||||||
return ans
|
return ans
|
||||||
483
src/pyrad3/dictionary.py
Normal file
483
src/pyrad3/dictionary.py
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
"""RADIUS Dictionary.
|
||||||
|
|
||||||
|
Classes and Types to parse and represent a RADIUS dictionary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import IntEnum, Enum, auto
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from os.path import dirname, isabs, join, normpath
|
||||||
|
from typing import Dict, Generator, IO, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
INTEGER_TYPES = {
|
||||||
|
"byte": (0, 255),
|
||||||
|
"short": (0, 2 ** 16 - 1),
|
||||||
|
"signed": (-(2 ** 31), 2 ** 31 - 1),
|
||||||
|
"integer": (0, 2 ** 32 - 1),
|
||||||
|
"integer64": (0, 2 ** 64 - 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Datatype(Enum):
|
||||||
|
"""Possible Datatypes for ATTRIBUTES"""
|
||||||
|
|
||||||
|
string = auto()
|
||||||
|
octets = auto()
|
||||||
|
date = auto()
|
||||||
|
abinary = auto()
|
||||||
|
byte = auto()
|
||||||
|
short = auto()
|
||||||
|
integer = auto()
|
||||||
|
signed = auto()
|
||||||
|
integer64 = auto()
|
||||||
|
ipaddr = auto()
|
||||||
|
ipv4prefix = auto()
|
||||||
|
ipv6addr = auto()
|
||||||
|
ipv6prefix = auto()
|
||||||
|
comboip = auto()
|
||||||
|
ifid = auto()
|
||||||
|
ether = auto()
|
||||||
|
concat = auto()
|
||||||
|
tlv = auto()
|
||||||
|
extended = auto()
|
||||||
|
longextended = auto()
|
||||||
|
evs = auto()
|
||||||
|
|
||||||
|
|
||||||
|
class ParseError(Exception):
|
||||||
|
"""RADIUS Dictionary Parser Error"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
msg: str = None,
|
||||||
|
line: Optional[int] = None,
|
||||||
|
**data,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.msg = msg
|
||||||
|
self.file = filename
|
||||||
|
self.line = line
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
line = f"({self.line}" if self.line is not None else ""
|
||||||
|
return f"{self.file}{line}: ParseError: {self.msg}"
|
||||||
|
|
||||||
|
|
||||||
|
class Encrypt(IntEnum):
|
||||||
|
"""Enum for different RADIUS Encryption types."""
|
||||||
|
|
||||||
|
NoEncrpytion = 0
|
||||||
|
RadiusCrypt = 1
|
||||||
|
SaltCrypt = 2
|
||||||
|
AscendCrypt = 3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Attribute: # pylint: disable=too-many-instance-attributes
|
||||||
|
"""RADIUS Attribute definition"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
code: int
|
||||||
|
datatype: Datatype
|
||||||
|
has_tag: bool = False
|
||||||
|
encrypt: Encrypt = Encrypt(0)
|
||||||
|
is_sub_attr: bool = False
|
||||||
|
# vendor = Dictionary
|
||||||
|
values: Dict[Union[int, str], Union[int, str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Vendor:
|
||||||
|
"""Representation of a vendor"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
code: int
|
||||||
|
tlength: int
|
||||||
|
llength: int
|
||||||
|
continuation: bool
|
||||||
|
attrs: Dict[Union[int, Tuple[int, ...]], Attribute]
|
||||||
|
|
||||||
|
|
||||||
|
def dict_parser(
|
||||||
|
filename: str, rad_dict: IO
|
||||||
|
) -> Generator[Tuple[int, List[str]], None, None]:
|
||||||
|
"""Tokenstream of RADIUS Dictionary files
|
||||||
|
|
||||||
|
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
|
||||||
|
and "FILE_CLOSED" tokens will be emitted.
|
||||||
|
"""
|
||||||
|
yield (-1, ["FILE_OPENED", filename])
|
||||||
|
for line_num, line in enumerate(rad_dict.readlines()):
|
||||||
|
tokens = line.split("#", 1)[0].strip().split()
|
||||||
|
if tokens:
|
||||||
|
first_tok = tokens[0] = tokens[0].upper()
|
||||||
|
if first_tok == "$INCLUDE":
|
||||||
|
try:
|
||||||
|
inner_filename = tokens[1]
|
||||||
|
except IndexError:
|
||||||
|
raise ParseError(
|
||||||
|
filename, "$INCLUDE is missing a filename", line_num,
|
||||||
|
)
|
||||||
|
if not isabs(tokens[1]):
|
||||||
|
path = dirname(filename)
|
||||||
|
inner_filename = normpath(join(path, inner_filename))
|
||||||
|
yield from dict_loader(inner_filename)
|
||||||
|
yield (line_num, tokens)
|
||||||
|
yield (-1, ["FILE_CLOSED"])
|
||||||
|
|
||||||
|
|
||||||
|
def dict_loader(filename: str) -> Generator[Tuple[int, List[str]], None, None]:
|
||||||
|
"""Tokenstream of RADIUS Dictionary files
|
||||||
|
|
||||||
|
Additionally to the "regular" (Free)RADIUS Dictionary tokens "FILE_OPENED"
|
||||||
|
and "FILE_CLOSED" tokens will be emitted.
|
||||||
|
"""
|
||||||
|
with open(filename, "r") as rad_dict:
|
||||||
|
yield from dict_parser(filename, rad_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_number(num: str) -> int:
|
||||||
|
"""Parse a number from (Free)RADIUS dictionaries
|
||||||
|
|
||||||
|
Numbers can be either decimal, octal, or hexadecimal.
|
||||||
|
"""
|
||||||
|
if num.startswith("0x"):
|
||||||
|
return int(num, 16)
|
||||||
|
if num.startswith("0o"):
|
||||||
|
return int(num, 8)
|
||||||
|
return int(num)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_attribute_code(attr_code: str) -> List[int]:
|
||||||
|
"""Parse attribute codes from (Free)RADIUS dictionaries
|
||||||
|
|
||||||
|
Codes can be either decimal, octal, or hexadecimal.
|
||||||
|
TLV typed can
|
||||||
|
"""
|
||||||
|
codes = []
|
||||||
|
for code in attr_code.split("."):
|
||||||
|
codes.append(_parse_number(code))
|
||||||
|
return codes
|
||||||
|
|
||||||
|
|
||||||
|
class Dictionary:
|
||||||
|
"""(Free)RADIUS Dictionary.
|
||||||
|
|
||||||
|
#TODO: Better documentation
|
||||||
|
This dictionary can "contain" multiple dictionaries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# there must be some nicer way to unittest this...
|
||||||
|
def __init__(self, dictionary: str, __dictio: Optional[IO] = None):
|
||||||
|
self.vendor: Dict[int, Vendor] = {}
|
||||||
|
self.vendor_lookup_id_by_name: Dict[str, int] = {}
|
||||||
|
self.attrindex: Dict[Union[int, str], Attribute] = {}
|
||||||
|
self.rfc_vendor = Vendor("RFC", 0, 1, 1, False, {})
|
||||||
|
self.cur_vendor = self.rfc_vendor
|
||||||
|
if __dictio is not None:
|
||||||
|
loader = dict_parser(dictionary, __dictio)
|
||||||
|
else:
|
||||||
|
loader = dict_loader(dictionary)
|
||||||
|
self.read_dictionary(loader)
|
||||||
|
|
||||||
|
def read_dictionary(
|
||||||
|
self, reader: Generator[Tuple[int, List[str]], None, None]
|
||||||
|
):
|
||||||
|
"""Read and parse a (Free)RADIUS dictionary."""
|
||||||
|
self.filestack = []
|
||||||
|
for line_num, tokens in reader:
|
||||||
|
key = tokens[0]
|
||||||
|
if key == "ATTRIBUTE":
|
||||||
|
self._parse_attribute(tokens, line_num)
|
||||||
|
elif key == "VALUE":
|
||||||
|
self._parse_value(tokens, line_num)
|
||||||
|
elif key == "FILE_OPENED":
|
||||||
|
LOG.info("Parsing file: %s", tokens[1])
|
||||||
|
if tokens[1] in self.filestack:
|
||||||
|
raise ParseError(
|
||||||
|
self.filestack[-1], "Include recursion detected"
|
||||||
|
)
|
||||||
|
self.filestack.append(tokens[1])
|
||||||
|
elif key == "FILE_CLOSED":
|
||||||
|
filename = self.filestack.pop()
|
||||||
|
LOG.info("Finished parsing file: %s", filename)
|
||||||
|
elif key == "VENDOR":
|
||||||
|
self._parse_vendor(tokens, line_num)
|
||||||
|
elif key == "BEGIN-VENDOR":
|
||||||
|
self._parse_begin_vendor(tokens, line_num)
|
||||||
|
elif key == "END-VENDOR":
|
||||||
|
self._parse_end_vendor(tokens, line_num)
|
||||||
|
elif key == "BEGIN-TLV":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"BEGIN-TLV is deprecated and not supported by pyrad3"
|
||||||
|
)
|
||||||
|
elif key == "END-TLV":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"END-TLV is deprecated and not supported by pyrad3"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ParseError(
|
||||||
|
self.filestack[-1], f"Invalid Token key {key}", line_num
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_vendor(self, tokens: Sequence[str], line_num: int):
|
||||||
|
"""Parse the vendor definition"""
|
||||||
|
filename = self.filestack[-1]
|
||||||
|
if len(tokens) not in {3, 4}:
|
||||||
|
raise ParseError(
|
||||||
|
filename, "Incorrect number of tokens for vendor statement"
|
||||||
|
)
|
||||||
|
vendor_name = tokens[1]
|
||||||
|
vendor_id = int(tokens[2], 0)
|
||||||
|
continuation = False
|
||||||
|
|
||||||
|
# Parse optional vendor specification
|
||||||
|
try:
|
||||||
|
vendor_format = tokens[3].split("=")
|
||||||
|
if vendor_format[0] != "format":
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"Unknown option {vendor_format[0]} for vendor definition",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
vendor_format = vendor_format[1].split(",")
|
||||||
|
t_len, l_len = (int(a) for a in vendor_format[:2])
|
||||||
|
if t_len not in {1, 2, 4}:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f'Invalid type length definition "{t_len}" for vendor {vendor_name}',
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
if l_len not in {0, 1, 2}:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f'Invalid length definition "{l_len}" for vendor {vendor_name}',
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if vendor_format[2] == "c":
|
||||||
|
if not vendor_name.upper() == "WIMAX":
|
||||||
|
# Not sure why, but FreeRADIUS has this limit,
|
||||||
|
# so we just do the same cause they know better than me
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"continuation-bit is only supported for WiMAX",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
continuation = True
|
||||||
|
except IndexError:
|
||||||
|
pass
|
||||||
|
except ValueError:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"Syntax error in specification for vendor {vendor_name}",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
except IndexError:
|
||||||
|
# no format definition
|
||||||
|
t_len, l_len = 1, 1
|
||||||
|
|
||||||
|
vendor = Vendor(vendor_name, vendor_id, t_len, l_len, continuation, {})
|
||||||
|
self.vendor_lookup_id_by_name[vendor_name] = vendor_id
|
||||||
|
self.vendor[vendor_id] = vendor
|
||||||
|
|
||||||
|
def _parse_begin_vendor(self, tokens: Sequence[str], line_num: int):
|
||||||
|
"""Parse the BEGIN-VENDOR line of (Free)RADIUS dictionaries."""
|
||||||
|
filename = self.filestack[-1]
|
||||||
|
if self.cur_vendor != self.rfc_vendor:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"vendor-begin sections are not allowed to be nested",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
if len(tokens) != 2:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"Incorrect number of tokens for begin-vendor statement",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
vendor_id = self.vendor_lookup_id_by_name[tokens[1]]
|
||||||
|
self.cur_vendor = self.vendor[vendor_id]
|
||||||
|
except KeyError:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"Unknown vendor {tokens[1]} in begin-vendor statement",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_end_vendor(self, tokens: Sequence[str], line_num: int):
|
||||||
|
"""Parse the END-VENDOR line of (Free)RADIUS dictionaries."""
|
||||||
|
filename = self.filestack[-1]
|
||||||
|
if len(tokens) != 2:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"Incorrect number of tokens for end-vendor statement",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
if self.cur_vendor.name != tokens[1]:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"Closing non-opened vendor {tokens[1]} in end-vendor statement",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
self.cur_vendor = self.rfc_vendor
|
||||||
|
|
||||||
|
def _parse_attribute_flags(
|
||||||
|
self, tokens: Sequence[str], line_num: int
|
||||||
|
) -> Tuple[bool, Encrypt]:
|
||||||
|
"""Parse Attribute flags of (Free)RADIUS dictionaries."""
|
||||||
|
filename = self.filestack[-1]
|
||||||
|
has_tag = False
|
||||||
|
encrypt = Encrypt.NoEncrpytion
|
||||||
|
|
||||||
|
try:
|
||||||
|
flags = [flag.split("=") for flag in tokens[4].split(",")]
|
||||||
|
except IndexError:
|
||||||
|
return False, Encrypt.NoEncrpytion
|
||||||
|
|
||||||
|
for flag in flags:
|
||||||
|
flag_len = len(flag)
|
||||||
|
if flag == 1:
|
||||||
|
value = None
|
||||||
|
elif flag_len == 2:
|
||||||
|
value = flag[1]
|
||||||
|
else:
|
||||||
|
raise ParseError(
|
||||||
|
filename, f"Incorrect attribute flag {flag}", line_num
|
||||||
|
)
|
||||||
|
key = flag[0]
|
||||||
|
|
||||||
|
if key == "has_tag":
|
||||||
|
has_tag = True
|
||||||
|
elif key == "encrypt":
|
||||||
|
try:
|
||||||
|
encrypt = Encrypt(int(value)) # type: ignore
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"Illegal attribute encryption {value}",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ParseError(
|
||||||
|
filename, "Unknown attribute flag {key}", line_num
|
||||||
|
)
|
||||||
|
|
||||||
|
return has_tag, encrypt
|
||||||
|
|
||||||
|
def _parse_attribute(self, tokens: Sequence[str], line_num: int):
|
||||||
|
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
|
||||||
|
filename = self.filestack[-1]
|
||||||
|
if not len(tokens) in {4, 5}:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"Incorrect number of tokens for attribute definition",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
has_tag, encrypt = self._parse_attribute_flags(tokens, line_num)
|
||||||
|
name, code, datatype = tokens[1:4]
|
||||||
|
|
||||||
|
if datatype == "concat" and self.cur_vendor != self.rfc_vendor:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
'vendor attributes are not allowed to have the datatype "concat"',
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
codes = _parse_attribute_code(code)
|
||||||
|
except ValueError:
|
||||||
|
raise ParseError(
|
||||||
|
filename, f'invalid attribute code {code}""', line_num
|
||||||
|
)
|
||||||
|
|
||||||
|
for code in codes:
|
||||||
|
tlength = self.cur_vendor.tlength
|
||||||
|
if 2 ** (8 * tlength) <= code:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"attribute code is too big, must be smaller than 2**{tlength}",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
if code < 0:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"negative attribute codes are not allowed",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Do we some explicit handling of tlvs?
|
||||||
|
# if len(codes) > 1:
|
||||||
|
# self._parse_attribute_tlv(codes, line_num)
|
||||||
|
# else:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
base_datatype = datatype.split("[")[0].replace("-", "")
|
||||||
|
try:
|
||||||
|
attribute_type = Datatype[base_datatype]
|
||||||
|
except KeyError:
|
||||||
|
raise ParseError(filename, f"Illegal type: {datatype}", line_num)
|
||||||
|
|
||||||
|
attribute = Attribute(
|
||||||
|
name,
|
||||||
|
codes[-1],
|
||||||
|
attribute_type,
|
||||||
|
has_tag,
|
||||||
|
encrypt,
|
||||||
|
len(codes) > 1,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
|
||||||
|
self.cur_vendor.attrs[attrcode] = attribute
|
||||||
|
|
||||||
|
if self.cur_vendor != self.rfc_vendor:
|
||||||
|
codes = tuple([26] + codes)
|
||||||
|
attrcode = codes[0] if len(codes) == 1 else tuple(codes)
|
||||||
|
self.attrindex[attrcode] = attribute
|
||||||
|
self.attrindex[name] = attribute
|
||||||
|
|
||||||
|
def _parse_value(self, tokens: Sequence[str], line_num: int):
|
||||||
|
"""Parse an ATTRIBUTE line of (Free)RADIUS dictionaries."""
|
||||||
|
filename = self.filestack[-1]
|
||||||
|
if len(tokens) != 4:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
"Incorrect number of tokens for VALUE definition",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
(attr_name, key, value) = tokens[1:]
|
||||||
|
value = _parse_number(value)
|
||||||
|
|
||||||
|
attribute = self.attrindex[attr_name]
|
||||||
|
try:
|
||||||
|
datatype = str(attribute.datatype).split(".")[1]
|
||||||
|
lmin, lmax = INTEGER_TYPES[datatype]
|
||||||
|
if value < lmin or value > lmax:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"VALUE {key}({value}) is not in the limit of type {datatype}",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
raise ParseError(
|
||||||
|
filename,
|
||||||
|
f"only attributes with integer typed datatypes can have"
|
||||||
|
f"value definitions {attribute.datatype}",
|
||||||
|
line_num,
|
||||||
|
)
|
||||||
|
attribute.values[value] = key
|
||||||
|
attribute.values[key] = value
|
||||||
46
src/pyrad3/host.py
Normal file
46
src/pyrad3/host.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
"""Interface Class for RADIUS Clients and Servers"""
|
||||||
|
|
||||||
|
from pyrad3.dictionary import Dictionary
|
||||||
|
from pyrad3 import packet
|
||||||
|
|
||||||
|
|
||||||
|
class Host: # pylint: disable=too-many-arguments
|
||||||
|
"""Interface Class for RADIUS Clients and Servers"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
secret: bytes,
|
||||||
|
radius_dict: Dictionary,
|
||||||
|
authport: int = 1812,
|
||||||
|
acctport: int = 1813,
|
||||||
|
coaport: int = 3799,
|
||||||
|
timeout: float = 30,
|
||||||
|
retries: int = 3,
|
||||||
|
):
|
||||||
|
self.secret = secret
|
||||||
|
self.dictionary = radius_dict
|
||||||
|
|
||||||
|
self.authport = authport
|
||||||
|
self.acctport = acctport
|
||||||
|
self.coaport = coaport
|
||||||
|
|
||||||
|
self.timeout = timeout
|
||||||
|
self.retries = retries
|
||||||
|
|
||||||
|
def create_packet(self, **kwargs):
|
||||||
|
"""Create a generic RADIUS Packet"""
|
||||||
|
return packet.Packet(self, **kwargs)
|
||||||
|
|
||||||
|
def create_auth_packet(self, **kwargs):
|
||||||
|
"""Create an Authentictaion packet (request per default)"""
|
||||||
|
return packet.AuthPacket(self, **kwargs)
|
||||||
|
|
||||||
|
def create_acct_packet(self, **kwargs):
|
||||||
|
"""Create an Accounting packet (request per default)"""
|
||||||
|
return packet.AcctPacket(self, **kwargs)
|
||||||
|
|
||||||
|
def create_coa_packet(self, **kwargs):
|
||||||
|
"""Create an CoA packet (requset per default)"""
|
||||||
|
return packet.CoAPacket(self, **kwargs)
|
||||||
305
src/pyrad3/packet.py
Normal file
305
src/pyrad3/packet.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from enum import IntEnum
|
||||||
|
from secrets import token_bytes
|
||||||
|
from typing import Any, Dict, Optional, Sequence
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from pyrad3.host import Host
|
||||||
|
from pyrad3.utils import (
|
||||||
|
PacketError,
|
||||||
|
Attribute,
|
||||||
|
parse_header,
|
||||||
|
parse_attributes,
|
||||||
|
calculate_authenticator,
|
||||||
|
validate_pap_password,
|
||||||
|
validate_chap_password,
|
||||||
|
)
|
||||||
|
|
||||||
|
HMAC = hmac.new
|
||||||
|
|
||||||
|
|
||||||
|
# Packet codes
|
||||||
|
class Code(IntEnum):
|
||||||
|
AccessRequest = 1
|
||||||
|
AccessAccept = 2
|
||||||
|
AccessReject = 3
|
||||||
|
AccountingRequest = 4
|
||||||
|
AccountingResponse = 5
|
||||||
|
AccessChallenge = 11
|
||||||
|
StatusServer = 12
|
||||||
|
StatusClient = 13
|
||||||
|
DisconnectRequest = 40
|
||||||
|
DisconnectACK = 41
|
||||||
|
DisconnectNAK = 42
|
||||||
|
CoARequest = 43
|
||||||
|
CoAACK = 44
|
||||||
|
CoANAK = 45
|
||||||
|
|
||||||
|
|
||||||
|
class AuthError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Packet(OrderedDict):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: Host,
|
||||||
|
code: Code,
|
||||||
|
radius_id: int,
|
||||||
|
*,
|
||||||
|
request: "Packet" = None,
|
||||||
|
**attributes
|
||||||
|
):
|
||||||
|
super().__init__(**attributes)
|
||||||
|
self.code = code
|
||||||
|
self.id = radius_id
|
||||||
|
self.host = host
|
||||||
|
self.request = request
|
||||||
|
self.ordered_attributes: Sequence[Attribute] = []
|
||||||
|
self.raw_packet: Optional[bytearray] = None
|
||||||
|
self.authenticator: Optional[bytes] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_raw(host: Host, raw_packet: bytearray) -> "Packet":
|
||||||
|
(code, radius_id, _length, authenticator) = parse_header(raw_packet)
|
||||||
|
|
||||||
|
ordered_attrs = parse_attributes(host.dictionary, raw_packet)
|
||||||
|
|
||||||
|
# Can we do better than Any with type hinting?
|
||||||
|
attrs: Dict[str, Any] = {}
|
||||||
|
for attr in ordered_attrs:
|
||||||
|
try:
|
||||||
|
attrs[attr.name].append(attr.value)
|
||||||
|
except KeyError:
|
||||||
|
attrs[attr.name] = [attr.value]
|
||||||
|
|
||||||
|
parsed_packet = Packet(host, code, radius_id, **attrs)
|
||||||
|
parsed_packet.authenticator = authenticator
|
||||||
|
parsed_packet.raw_packet = raw_packet
|
||||||
|
parsed_packet.ordered_attributes = ordered_attrs
|
||||||
|
return parsed_packet
|
||||||
|
|
||||||
|
def from_raw_reply(self, raw_packet: bytearray) -> "Packet":
|
||||||
|
self.verify_reply(raw_packet)
|
||||||
|
reply = Packet.from_raw(self.host, raw_packet)
|
||||||
|
reply.request = self
|
||||||
|
try:
|
||||||
|
if not reply.validate_message_authenticator():
|
||||||
|
raise PacketError("Packet has a wrong message authenticator")
|
||||||
|
except KeyError:
|
||||||
|
if "EAP-Message" in reply:
|
||||||
|
raise PacketError("Packet is missing a message authenticator")
|
||||||
|
return reply
|
||||||
|
|
||||||
|
def send(self):
|
||||||
|
"""Send the packet to the Client/Server.
|
||||||
|
"""
|
||||||
|
self.host._send_packet(self)
|
||||||
|
|
||||||
|
def verify_reply(self, raw_reply: bytes):
|
||||||
|
"""Verify the reply to this packet.
|
||||||
|
"""
|
||||||
|
if self.id != raw_reply[1]:
|
||||||
|
raise PacketError("Response has a wrong id")
|
||||||
|
|
||||||
|
# self.authenticator MUST be set, this packet got send so by definitation
|
||||||
|
# self.authenticator will not be non, but bytes
|
||||||
|
radius_hash = calculate_authenticator(
|
||||||
|
self.host.secret,
|
||||||
|
self.authenticator, # type: ignore
|
||||||
|
raw_reply,
|
||||||
|
)
|
||||||
|
|
||||||
|
if radius_hash != raw_reply[4:20]:
|
||||||
|
raise PacketError("Reply Packet has a wrong authenticator")
|
||||||
|
|
||||||
|
def validate_message_authenticator(self):
|
||||||
|
message_authenticator = self["Message-Authenticator"]
|
||||||
|
if isinstance(list, message_authenticator):
|
||||||
|
# There are multiple Message Authenticators, but a packet MUST NOT have
|
||||||
|
# more than one
|
||||||
|
return False
|
||||||
|
ma_attribute = self.find_first_attribute("Message-Authenticator")
|
||||||
|
generated = self._generate_message_authenticator(ma_attribute)
|
||||||
|
return message_authenticator == generated
|
||||||
|
|
||||||
|
def _generate_message_authenticator(self, ma_attr: Attribute):
|
||||||
|
assert self.authenticator is not None
|
||||||
|
assert self.request is not None
|
||||||
|
assert self.request.authenticator is not None
|
||||||
|
assert self.raw_packet is not None
|
||||||
|
|
||||||
|
# The message authenticator must be treated as 16 * \00
|
||||||
|
start_pos = ma_attr.pos + 2
|
||||||
|
end_pos = start_pos + 16
|
||||||
|
original_ma: bytes = ma_attr.value
|
||||||
|
self.raw_packet[start_pos:end_pos] = 16 * b"\00"
|
||||||
|
|
||||||
|
hmac_builder = HMAC(self.host.secret, digestmod=hashlib.md5)
|
||||||
|
hmac_builder.update(self.raw_packet)
|
||||||
|
|
||||||
|
if self.code in (Code.AccessRequest, Code.StatusServer):
|
||||||
|
hmac_builder.update(self.authenticator)
|
||||||
|
elif self.code in (
|
||||||
|
Code.AccessAccept,
|
||||||
|
Code.AccessChallenge,
|
||||||
|
Code.AccessReject,
|
||||||
|
):
|
||||||
|
hmac_builder.update(self.request.authenticator)
|
||||||
|
else:
|
||||||
|
hmac_builder.update(16 * b"\00")
|
||||||
|
|
||||||
|
hmac_builder.update(self.raw_packet[20:])
|
||||||
|
self.raw_packet[start_pos:end_pos] = original_ma
|
||||||
|
return hmac_builder.digest()
|
||||||
|
|
||||||
|
def add_message_authenticator(self):
|
||||||
|
self._encode_packet()
|
||||||
|
self._generate_message_authenticator(self)
|
||||||
|
try:
|
||||||
|
# quick lookup before we iterate over the whole packet
|
||||||
|
_ = self["Message-Authenticator"]
|
||||||
|
attr = self.find_first_attribute("Message-Authenticator")
|
||||||
|
except KeyError:
|
||||||
|
self["Message-Authenticator"] = 16 * b"\00"
|
||||||
|
attr = self.ordered_attributes[-1]
|
||||||
|
generated = self._generate_message_authenticator(attr)
|
||||||
|
self[attr.pos + 2 :] = generated
|
||||||
|
|
||||||
|
def refresh_message_authenticator(self):
|
||||||
|
self.add_message_authenticator()
|
||||||
|
|
||||||
|
def find_first_attribute(self, attr_type_name: str) -> Attribute:
|
||||||
|
for attr in self.ordered_attributes:
|
||||||
|
if attr.type == attr_type_name:
|
||||||
|
return attr.type
|
||||||
|
raise KeyError
|
||||||
|
|
||||||
|
def _encode_packet(self):
|
||||||
|
self.raw_packet = None
|
||||||
|
|
||||||
|
|
||||||
|
class AuthPacket(Packet):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: Host,
|
||||||
|
radius_id: int,
|
||||||
|
auth_type,
|
||||||
|
*,
|
||||||
|
code: Code = Code.AccessRequest,
|
||||||
|
request: Optional[Packet] = None,
|
||||||
|
**attributes
|
||||||
|
):
|
||||||
|
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||||
|
self.auth_type = auth_type
|
||||||
|
if code == Code.AccessRequest:
|
||||||
|
self.authenticator = token_bytes(16)
|
||||||
|
|
||||||
|
def create_accept(self, **attributes):
|
||||||
|
return AuthPacket(
|
||||||
|
self.host,
|
||||||
|
self.id,
|
||||||
|
self.auth_type,
|
||||||
|
request=self,
|
||||||
|
code=Code.AccessAccept,
|
||||||
|
**attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_reject(self, **attributes):
|
||||||
|
return AuthPacket(
|
||||||
|
self.host,
|
||||||
|
self.id,
|
||||||
|
self.auth_type,
|
||||||
|
request=self,
|
||||||
|
code=Code.AccessReject,
|
||||||
|
**attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_challange(self, **attributes):
|
||||||
|
return AuthPacket(
|
||||||
|
self.host,
|
||||||
|
self.id,
|
||||||
|
self.auth_type,
|
||||||
|
request=self,
|
||||||
|
code=Code.AccessChallenge,
|
||||||
|
**attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_password(self, password: bytes) -> bool:
|
||||||
|
try:
|
||||||
|
return self.validate_pap(password)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
# Will throw KeyError if no chap password exists
|
||||||
|
return self.validate_chap(password)
|
||||||
|
|
||||||
|
def validate_pap(self, password: bytes) -> bool:
|
||||||
|
packet_password = self["User-Password"]
|
||||||
|
return validate_pap_password(
|
||||||
|
self.host.secret,
|
||||||
|
self.authenticator, # type: ignore
|
||||||
|
packet_password,
|
||||||
|
password,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_chap(self, password: bytes) -> bool:
|
||||||
|
packet_password = self["Chap-Password"]
|
||||||
|
chap_id = packet_password[:1]
|
||||||
|
chap_password = packet_password[1:]
|
||||||
|
try:
|
||||||
|
challenge = self["Chap-Challenge"]
|
||||||
|
except KeyError:
|
||||||
|
challenge = self.authenticator
|
||||||
|
return validate_chap_password(
|
||||||
|
chap_id, challenge, chap_password, password, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AcctPacket(Packet):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: Host,
|
||||||
|
radius_id: int,
|
||||||
|
*,
|
||||||
|
code: Code = Code.AccountingRequest,
|
||||||
|
request: Optional[Packet] = None,
|
||||||
|
**attributes
|
||||||
|
):
|
||||||
|
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||||
|
|
||||||
|
def create_response(self, **attributes):
|
||||||
|
return AcctPacket(
|
||||||
|
self.host,
|
||||||
|
self.id,
|
||||||
|
code=Code.AccountingResponse,
|
||||||
|
request=self,
|
||||||
|
**attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CoAPacket(Packet):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: Host,
|
||||||
|
radius_id: int,
|
||||||
|
*,
|
||||||
|
code: Code = Code.CoARequest,
|
||||||
|
request: Optional[Packet] = None,
|
||||||
|
**attributes
|
||||||
|
):
|
||||||
|
super().__init__(host, code, radius_id, request=request, **attributes)
|
||||||
|
|
||||||
|
def create_ack(self, **attributes):
|
||||||
|
return CoAPacket(
|
||||||
|
self.host, self.id, code=Code.CoAACK, request=self, **attributes
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_nack(self, **attributes):
|
||||||
|
return CoAPacket(
|
||||||
|
self.host, self.id, code=Code.CoANAK, request=self, **attributes
|
||||||
|
)
|
||||||
@@ -26,7 +26,8 @@ class Proxy(Server):
|
|||||||
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
|
self._fdmap[self._proxyfd.fileno()] = self._proxyfd
|
||||||
self._poll.register(
|
self._poll.register(
|
||||||
self._proxyfd.fileno(),
|
self._proxyfd.fileno(),
|
||||||
(select.POLLIN | select.POLLPRI | select.POLLERR))
|
(select.POLLIN | select.POLLPRI | select.POLLERR),
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_proxy_packet(self, pkt):
|
def _handle_proxy_packet(self, pkt):
|
||||||
"""Process a packet received on the reply socket.
|
"""Process a packet received on the reply socket.
|
||||||
@@ -38,12 +39,15 @@ class Proxy(Server):
|
|||||||
:type pkt: Packet class instance
|
:type pkt: Packet class instance
|
||||||
"""
|
"""
|
||||||
if pkt.source[0] not in self.hosts:
|
if pkt.source[0] not in self.hosts:
|
||||||
raise ServerPacketError('Received packet from unknown host')
|
raise ServerPacketError("Received packet from unknown host")
|
||||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||||
|
|
||||||
if pkt.code not in [packet.AccessAccept, packet.AccessReject,
|
if pkt.code not in [
|
||||||
packet.AccountingResponse]:
|
packet.AccessAccept,
|
||||||
raise ServerPacketError('Received non-response on proxy socket')
|
packet.AccessReject,
|
||||||
|
packet.AccountingResponse,
|
||||||
|
]:
|
||||||
|
raise ServerPacketError("Received non-response on proxy socket")
|
||||||
|
|
||||||
def _process_input(self, fd):
|
def _process_input(self, fd):
|
||||||
"""Process available data.
|
"""Process available data.
|
||||||
@@ -62,7 +66,8 @@ class Proxy(Server):
|
|||||||
"""
|
"""
|
||||||
if fd.fileno() == self._proxyfd.fileno():
|
if fd.fileno() == self._proxyfd.fileno():
|
||||||
pkt = self._grab_packet(
|
pkt = self._grab_packet(
|
||||||
lambda data, s=self: s.CreatePacket(packet=data), fd)
|
lambda data, s=self: s.CreatePacket(packet=data), fd
|
||||||
|
)
|
||||||
self._handle_proxy_packet(pkt)
|
self._handle_proxy_packet(pkt)
|
||||||
else:
|
else:
|
||||||
Server._process_input(self, fd)
|
Server._process_input(self, fd)
|
||||||
@@ -9,13 +9,15 @@ from pyrad import host
|
|||||||
from pyrad import packet
|
from pyrad import packet
|
||||||
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger('pyrad')
|
LOGGER = logging.getLogger("pyrad")
|
||||||
|
|
||||||
|
|
||||||
class RemoteHost:
|
class RemoteHost:
|
||||||
"""Remote RADIUS capable host we can talk to."""
|
"""Remote RADIUS capable host we can talk to."""
|
||||||
|
|
||||||
def __init__(self, address, secret, name, authport=1812, acctport=1813, coaport=3799):
|
def __init__(
|
||||||
|
self, address, secret, name, authport=1812, acctport=1813, coaport=3799
|
||||||
|
):
|
||||||
"""Constructor.
|
"""Constructor.
|
||||||
|
|
||||||
:param address: IP address
|
:param address: IP address
|
||||||
@@ -62,10 +64,21 @@ class Server(host.Host):
|
|||||||
:cvar MaxPacketSize: maximum size of a RADIUS packet
|
:cvar MaxPacketSize: maximum size of a RADIUS packet
|
||||||
:type MaxPacketSize: integer
|
:type MaxPacketSize: integer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MaxPacketSize = 4096
|
MaxPacketSize = 4096
|
||||||
|
|
||||||
def __init__(self, addresses=[], authport=1812, acctport=1813, coaport=3799,
|
def __init__(
|
||||||
hosts=None, dict=None, auth_enabled=True, acct_enabled=True, coa_enabled=False):
|
self,
|
||||||
|
addresses=[],
|
||||||
|
authport=1812,
|
||||||
|
acctport=1813,
|
||||||
|
coaport=3799,
|
||||||
|
hosts=None,
|
||||||
|
dict=None,
|
||||||
|
auth_enabled=True,
|
||||||
|
acct_enabled=True,
|
||||||
|
coa_enabled=False,
|
||||||
|
):
|
||||||
"""Constructor.
|
"""Constructor.
|
||||||
|
|
||||||
:param addresses: IP addresses to listen on
|
:param addresses: IP addresses to listen on
|
||||||
@@ -114,7 +127,7 @@ class Server(host.Host):
|
|||||||
"""
|
"""
|
||||||
results = set()
|
results = set()
|
||||||
try:
|
try:
|
||||||
tmp = socket.getaddrinfo(addr, 'www')
|
tmp = socket.getaddrinfo(addr, "www")
|
||||||
except socket.gaierror:
|
except socket.gaierror:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -199,10 +212,10 @@ class Server(host.Host):
|
|||||||
"""
|
"""
|
||||||
if pkt.source[0] in self.hosts:
|
if pkt.source[0] in self.hosts:
|
||||||
pkt.secret = self.hosts[pkt.source[0]].secret
|
pkt.secret = self.hosts[pkt.source[0]].secret
|
||||||
elif '0.0.0.0' in self.hosts:
|
elif "0.0.0.0" in self.hosts:
|
||||||
pkt.secret = self.hosts['0.0.0.0'].secret
|
pkt.secret = self.hosts["0.0.0.0"].secret
|
||||||
else:
|
else:
|
||||||
raise ServerPacketError('Received packet from unknown host')
|
raise ServerPacketError("Received packet from unknown host")
|
||||||
|
|
||||||
def _handle_auth_packet(self, pkt):
|
def _handle_auth_packet(self, pkt):
|
||||||
"""Process a packet received on the authentication port.
|
"""Process a packet received on the authentication port.
|
||||||
@@ -216,7 +229,8 @@ class Server(host.Host):
|
|||||||
self._add_secret(pkt)
|
self._add_secret(pkt)
|
||||||
if pkt.code != packet.AccessRequest:
|
if pkt.code != packet.AccessRequest:
|
||||||
raise ServerPacketError(
|
raise ServerPacketError(
|
||||||
'Received non-authentication packet on authentication port')
|
"Received non-authentication packet on authentication port"
|
||||||
|
)
|
||||||
self.HandleAuthPacket(pkt)
|
self.HandleAuthPacket(pkt)
|
||||||
|
|
||||||
def _handle_acct_packet(self, pkt):
|
def _handle_acct_packet(self, pkt):
|
||||||
@@ -229,10 +243,13 @@ class Server(host.Host):
|
|||||||
:type pkt: Packet class instance
|
:type pkt: Packet class instance
|
||||||
"""
|
"""
|
||||||
self._add_secret(pkt)
|
self._add_secret(pkt)
|
||||||
if pkt.code not in [packet.AccountingRequest,
|
if pkt.code not in [
|
||||||
packet.AccountingResponse]:
|
packet.AccountingRequest,
|
||||||
|
packet.AccountingResponse,
|
||||||
|
]:
|
||||||
raise ServerPacketError(
|
raise ServerPacketError(
|
||||||
'Received non-accounting packet on accounting port')
|
"Received non-accounting packet on accounting port"
|
||||||
|
)
|
||||||
self.HandleAcctPacket(pkt)
|
self.HandleAcctPacket(pkt)
|
||||||
|
|
||||||
def _handle_coa_packet(self, pkt):
|
def _handle_coa_packet(self, pkt):
|
||||||
@@ -251,7 +268,7 @@ class Server(host.Host):
|
|||||||
elif pkt.code == packet.DisconnectRequest:
|
elif pkt.code == packet.DisconnectRequest:
|
||||||
self.HandleDisconnectPacket(pkt)
|
self.HandleDisconnectPacket(pkt)
|
||||||
else:
|
else:
|
||||||
raise ServerPacketError('Received non-coa packet on coa port')
|
raise ServerPacketError("Received non-coa packet on coa port")
|
||||||
|
|
||||||
def _grab_packet(self, pktgen, fd):
|
def _grab_packet(self, pktgen, fd):
|
||||||
"""Read a packet from a network connection.
|
"""Read a packet from a network connection.
|
||||||
@@ -273,7 +290,9 @@ class Server(host.Host):
|
|||||||
"""
|
"""
|
||||||
for fd in self.authfds + self.acctfds + self.coafds:
|
for fd in self.authfds + self.acctfds + self.coafds:
|
||||||
self._fdmap[fd.fileno()] = fd
|
self._fdmap[fd.fileno()] = fd
|
||||||
self._poll.register(fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR)
|
self._poll.register(
|
||||||
|
fd.fileno(), select.POLLIN | select.POLLPRI | select.POLLERR
|
||||||
|
)
|
||||||
if self.auth_enabled:
|
if self.auth_enabled:
|
||||||
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
|
self._realauthfds = list(map(lambda x: x.fileno(), self.authfds))
|
||||||
if self.acct_enabled:
|
if self.acct_enabled:
|
||||||
@@ -307,16 +326,22 @@ class Server(host.Host):
|
|||||||
:type fd: socket class instance
|
:type fd: socket class instance
|
||||||
"""
|
"""
|
||||||
if self.auth_enabled and fd.fileno() in self._realauthfds:
|
if self.auth_enabled and fd.fileno() in self._realauthfds:
|
||||||
pkt = self._grab_packet(lambda data, s=self: s.CreateAuthPacket(packet=data), fd)
|
pkt = self._grab_packet(
|
||||||
|
lambda data, s=self: s.CreateAuthPacket(packet=data), fd
|
||||||
|
)
|
||||||
self._handle_auth_packet(pkt)
|
self._handle_auth_packet(pkt)
|
||||||
elif self.acct_enabled and fd.fileno() in self._realacctfds:
|
elif self.acct_enabled and fd.fileno() in self._realacctfds:
|
||||||
pkt = self._grab_packet(lambda data, s=self: s.CreateAcctPacket(packet=data), fd)
|
pkt = self._grab_packet(
|
||||||
|
lambda data, s=self: s.CreateAcctPacket(packet=data), fd
|
||||||
|
)
|
||||||
self._handle_acct_packet(pkt)
|
self._handle_acct_packet(pkt)
|
||||||
elif self.coa_enabled:
|
elif self.coa_enabled:
|
||||||
pkt = self._grab_packet(lambda data, s=self: s.CreateCoAPacket(packet=data), fd)
|
pkt = self._grab_packet(
|
||||||
|
lambda data, s=self: s.CreateCoAPacket(packet=data), fd
|
||||||
|
)
|
||||||
self._handle_coa_packet(pkt)
|
self._handle_coa_packet(pkt)
|
||||||
else:
|
else:
|
||||||
raise ServerPacketError('Received packet for unknown handler')
|
raise ServerPacketError("Received packet for unknown handler")
|
||||||
|
|
||||||
def Run(self):
|
def Run(self):
|
||||||
"""Main loop.
|
"""Main loop.
|
||||||
@@ -335,8 +360,8 @@ class Server(host.Host):
|
|||||||
fdo = self._fdmap[fd]
|
fdo = self._fdmap[fd]
|
||||||
self._process_input(fdo)
|
self._process_input(fdo)
|
||||||
except ServerPacketError as err:
|
except ServerPacketError as err:
|
||||||
LOGGER.info('Dropping packet: %s', err)
|
LOGGER.info("Dropping packet: %s", err)
|
||||||
except packet.PacketError as err:
|
except packet.PacketError as err:
|
||||||
LOGGER.info('Received a broken packet: %s', err)
|
LOGGER.info("Received a broken packet: %s", err)
|
||||||
else:
|
else:
|
||||||
LOGGER.error('Unexpected event in server main loop')
|
LOGGER.error("Unexpected event in server main loop")
|
||||||
@@ -9,26 +9,37 @@ from abc import abstractmethod, ABCMeta
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pyrad.packet import (
|
from pyrad.packet import (
|
||||||
Packet, AccessAccept, AccessReject,
|
Packet,
|
||||||
AccountingRequest, AccountingResponse,
|
AccessAccept,
|
||||||
DisconnectACK, DisconnectNAK, DisconnectRequest, CoARequest,
|
AccessReject,
|
||||||
CoAACK, CoANAK, AccessRequest, AuthPacket, AcctPacket, CoAPacket,
|
AccountingRequest,
|
||||||
PacketError
|
AccountingResponse,
|
||||||
|
DisconnectACK,
|
||||||
|
DisconnectNAK,
|
||||||
|
DisconnectRequest,
|
||||||
|
CoARequest,
|
||||||
|
CoAACK,
|
||||||
|
CoANAK,
|
||||||
|
AccessRequest,
|
||||||
|
AuthPacket,
|
||||||
|
AcctPacket,
|
||||||
|
CoAPacket,
|
||||||
|
PacketError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pyrad.server import ServerPacketError
|
from pyrad.server import ServerPacketError
|
||||||
|
|
||||||
|
|
||||||
class ServerType(Enum):
|
class ServerType(Enum):
|
||||||
Auth = 'Authentication'
|
Auth = "Authentication"
|
||||||
Acct = 'Accounting'
|
Acct = "Accounting"
|
||||||
Coa = 'Coa'
|
Coa = "Coa"
|
||||||
|
|
||||||
|
|
||||||
class DatagramProtocolServer(asyncio.Protocol):
|
class DatagramProtocolServer(asyncio.Protocol):
|
||||||
|
def __init__(
|
||||||
def __init__(self, ip, port, logger, server, server_type, hosts,
|
self, ip, port, logger, server, server_type, hosts, request_callback
|
||||||
request_callback):
|
):
|
||||||
self.transport = None
|
self.transport = None
|
||||||
self.ip = ip
|
self.ip = ip
|
||||||
self.port = port
|
self.port = port
|
||||||
@@ -40,94 +51,149 @@ class DatagramProtocolServer(asyncio.Protocol):
|
|||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
self.logger.info('[%s:%d] Transport created', self.ip, self.port)
|
self.logger.info("[%s:%d] Transport created", self.ip, self.port)
|
||||||
|
|
||||||
def connection_lost(self, exc):
|
def connection_lost(self, exc):
|
||||||
if exc:
|
if exc:
|
||||||
self.logger.warn('[%s:%d] Connection lost: %s', self.ip, self.port, str(exc))
|
self.logger.warn(
|
||||||
|
"[%s:%d] Connection lost: %s", self.ip, self.port, str(exc)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.info('[%s:%d] Transport closed', self.ip, self.port)
|
self.logger.info("[%s:%d] Transport closed", self.ip, self.port)
|
||||||
|
|
||||||
def send_response(self, reply, addr):
|
def send_response(self, reply, addr):
|
||||||
self.transport.sendto(reply.ReplyPacket(), addr)
|
self.transport.sendto(reply.ReplyPacket(), addr)
|
||||||
|
|
||||||
def datagram_received(self, data, addr):
|
def datagram_received(self, data, addr):
|
||||||
self.logger.debug('[%s:%d] Received %d bytes from %s', self.ip, self.port, len(data), addr)
|
self.logger.debug(
|
||||||
|
"[%s:%d] Received %d bytes from %s",
|
||||||
|
self.ip,
|
||||||
|
self.port,
|
||||||
|
len(data),
|
||||||
|
addr,
|
||||||
|
)
|
||||||
|
|
||||||
receive_date = datetime.utcnow()
|
receive_date = datetime.utcnow()
|
||||||
|
|
||||||
if addr[0] in self.hosts:
|
if addr[0] in self.hosts:
|
||||||
remote_host = self.hosts[addr[0]]
|
remote_host = self.hosts[addr[0]]
|
||||||
elif '0.0.0.0' in self.hosts:
|
elif "0.0.0.0" in self.hosts:
|
||||||
remote_host = self.hosts['0.0.0.0']
|
remote_host = self.hosts["0.0.0.0"]
|
||||||
else:
|
else:
|
||||||
self.logger.warn('[%s:%d] Drop package from unknown source %s', self.ip, self.port, addr)
|
self.logger.warn(
|
||||||
|
"[%s:%d] Drop package from unknown source %s",
|
||||||
|
self.ip,
|
||||||
|
self.port,
|
||||||
|
addr,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.logger.debug('[%s:%d] Received from %s packet: %s', self.ip, self.port, addr, data.hex())
|
self.logger.debug(
|
||||||
|
"[%s:%d] Received from %s packet: %s",
|
||||||
|
self.ip,
|
||||||
|
self.port,
|
||||||
|
addr,
|
||||||
|
data.hex(),
|
||||||
|
)
|
||||||
req = Packet(packet=data, dict=self.server.dict)
|
req = Packet(packet=data, dict=self.server.dict)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self.logger.error('[%s:%d] Error on decode packet: %s', self.ip, self.port, exc)
|
self.logger.error(
|
||||||
|
"[%s:%d] Error on decode packet: %s", self.ip, self.port, exc
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if req.code in (AccountingResponse, AccessAccept, AccessReject, CoANAK, CoAACK, DisconnectNAK, DisconnectACK):
|
if req.code in (
|
||||||
raise ServerPacketError(f'Invalid response packet {req.code}')
|
AccountingResponse,
|
||||||
|
AccessAccept,
|
||||||
|
AccessReject,
|
||||||
|
CoANAK,
|
||||||
|
CoAACK,
|
||||||
|
DisconnectNAK,
|
||||||
|
DisconnectACK,
|
||||||
|
):
|
||||||
|
raise ServerPacketError(f"Invalid response packet {req.code}")
|
||||||
|
|
||||||
elif self.server_type == ServerType.Auth:
|
elif self.server_type == ServerType.Auth:
|
||||||
if req.code != AccessRequest:
|
if req.code != AccessRequest:
|
||||||
raise ServerPacketError('Received non-auth packet on auth port')
|
raise ServerPacketError(
|
||||||
req = AuthPacket(secret=remote_host.secret,
|
"Received non-auth packet on auth port"
|
||||||
dict=self.server.dict,
|
)
|
||||||
packet=data)
|
req = AuthPacket(
|
||||||
|
secret=remote_host.secret,
|
||||||
|
dict=self.server.dict,
|
||||||
|
packet=data,
|
||||||
|
)
|
||||||
if self.server.enable_pkt_verify:
|
if self.server.enable_pkt_verify:
|
||||||
if req.VerifyAuthRequest():
|
if req.VerifyAuthRequest():
|
||||||
raise PacketError('Packet verification failed')
|
raise PacketError("Packet verification failed")
|
||||||
|
|
||||||
elif self.server_type == ServerType.Coa:
|
elif self.server_type == ServerType.Coa:
|
||||||
if req.code != DisconnectRequest and req.code != CoARequest:
|
if req.code != DisconnectRequest and req.code != CoARequest:
|
||||||
raise ServerPacketError('Received non-coa packet on coa port')
|
raise ServerPacketError(
|
||||||
req = CoAPacket(secret=remote_host.secret,
|
"Received non-coa packet on coa port"
|
||||||
dict=self.server.dict,
|
)
|
||||||
packet=data)
|
req = CoAPacket(
|
||||||
|
secret=remote_host.secret,
|
||||||
|
dict=self.server.dict,
|
||||||
|
packet=data,
|
||||||
|
)
|
||||||
if self.server.enable_pkt_verify:
|
if self.server.enable_pkt_verify:
|
||||||
if req.VerifyCoARequest():
|
if req.VerifyCoARequest():
|
||||||
raise PacketError('Packet verification failed')
|
raise PacketError("Packet verification failed")
|
||||||
|
|
||||||
elif self.server_type == ServerType.Acct:
|
elif self.server_type == ServerType.Acct:
|
||||||
|
|
||||||
if req.code != AccountingRequest:
|
if req.code != AccountingRequest:
|
||||||
raise ServerPacketError('Received non-acct packet on acct port')
|
raise ServerPacketError(
|
||||||
req = AcctPacket(secret=remote_host.secret,
|
"Received non-acct packet on acct port"
|
||||||
dict=self.server.dict,
|
)
|
||||||
packet=data)
|
req = AcctPacket(
|
||||||
|
secret=remote_host.secret,
|
||||||
|
dict=self.server.dict,
|
||||||
|
packet=data,
|
||||||
|
)
|
||||||
if self.server.enable_pkt_verify:
|
if self.server.enable_pkt_verify:
|
||||||
if req.VerifyAcctRequest():
|
if req.VerifyAcctRequest():
|
||||||
raise PacketError('Packet verification failed')
|
raise PacketError("Packet verification failed")
|
||||||
|
|
||||||
# Call request callback
|
# Call request callback
|
||||||
self.request_callback(self, req, addr)
|
self.request_callback(self, req, addr)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if self.server.debug:
|
if self.server.debug:
|
||||||
self.logger.exception('[%s:%d] Error for packet from %s', self.ip, self.port, addr)
|
self.logger.exception(
|
||||||
|
"[%s:%d] Error for packet from %s", self.ip, self.port, addr
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.error('[%s:%d] Error for packet from %s: %s', self.ip, self.port, addr, exc)
|
self.logger.error(
|
||||||
|
"[%s:%d] Error for packet from %s: %s",
|
||||||
|
self.ip,
|
||||||
|
self.port,
|
||||||
|
addr,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
process_date = datetime.utcnow()
|
process_date = datetime.utcnow()
|
||||||
self.logger.debug('[%s:%d] Request from %s processed in %d ms', self.ip, self.port, addr, (process_date-receive_date).microseconds/1000)
|
self.logger.debug(
|
||||||
|
"[%s:%d] Request from %s processed in %d ms",
|
||||||
|
self.ip,
|
||||||
|
self.port,
|
||||||
|
addr,
|
||||||
|
(process_date - receive_date).microseconds / 1000,
|
||||||
|
)
|
||||||
|
|
||||||
def error_received(self, exc):
|
def error_received(self, exc):
|
||||||
self.logger.error('[%s:%d] Error received: %s', self.ip, self.port, exc)
|
self.logger.error("[%s:%d] Error received: %s", self.ip, self.port, exc)
|
||||||
|
|
||||||
async def close_transport(self):
|
async def close_transport(self):
|
||||||
if self.transport:
|
if self.transport:
|
||||||
self.logger.debug('[%s:%d] Close transport...', self.ip, self.port)
|
self.logger.debug("[%s:%d] Close transport...", self.ip, self.port)
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
self.transport = None
|
self.transport = None
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'DatagramProtocolServer(ip={self.ip}, port={self.port})'
|
return f"DatagramProtocolServer(ip={self.ip}, port={self.port})"
|
||||||
|
|
||||||
# Used as protocol_factory
|
# Used as protocol_factory
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
@@ -135,12 +201,18 @@ class DatagramProtocolServer(asyncio.Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class ServerAsync(metaclass=ABCMeta):
|
class ServerAsync(metaclass=ABCMeta):
|
||||||
|
def __init__(
|
||||||
def __init__(self, auth_port=1812, acct_port=1813,
|
self,
|
||||||
coa_port=3799, hosts=None, dictionary=None,
|
auth_port=1812,
|
||||||
loop=None, logger_name='pyrad',
|
acct_port=1813,
|
||||||
enable_pkt_verify=False,
|
coa_port=3799,
|
||||||
debug=False):
|
hosts=None,
|
||||||
|
dictionary=None,
|
||||||
|
loop=None,
|
||||||
|
logger_name="pyrad",
|
||||||
|
enable_pkt_verify=False,
|
||||||
|
debug=False,
|
||||||
|
):
|
||||||
|
|
||||||
if not loop:
|
if not loop:
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
@@ -174,20 +246,35 @@ class ServerAsync(metaclass=ABCMeta):
|
|||||||
self.handle_acct_packet(protocol, req, addr)
|
self.handle_acct_packet(protocol, req, addr)
|
||||||
elif protocol.server_type == ServerType.Auth:
|
elif protocol.server_type == ServerType.Auth:
|
||||||
self.handle_auth_packet(protocol, req, addr)
|
self.handle_auth_packet(protocol, req, addr)
|
||||||
elif protocol.server_type == ServerType.Coa and \
|
elif (
|
||||||
req.code == CoARequest:
|
protocol.server_type == ServerType.Coa
|
||||||
|
and req.code == CoARequest
|
||||||
|
):
|
||||||
self.handle_coa_packet(protocol, req, addr)
|
self.handle_coa_packet(protocol, req, addr)
|
||||||
elif protocol.server_type == ServerType.Coa and \
|
elif (
|
||||||
req.code == DisconnectRequest:
|
protocol.server_type == ServerType.Coa
|
||||||
|
and req.code == DisconnectRequest
|
||||||
|
):
|
||||||
self.handle_disconnect_packet(protocol, req, addr)
|
self.handle_disconnect_packet(protocol, req, addr)
|
||||||
else:
|
else:
|
||||||
self.logger.error('[%s:%s] Unexpected request found', protocol.ip, protocol.port)
|
self.logger.error(
|
||||||
|
"[%s:%s] Unexpected request found",
|
||||||
|
protocol.ip,
|
||||||
|
protocol.port,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.logger.exception('[%s:%s] Unexpected error', protocol.ip, protocol.port)
|
self.logger.exception(
|
||||||
|
"[%s:%s] Unexpected error", protocol.ip, protocol.port
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.logger.error('[%s:%s] Unexpected error: %s', protocol.ip, protocol.port, exc)
|
self.logger.error(
|
||||||
|
"[%s:%s] Unexpected error: %s",
|
||||||
|
protocol.ip,
|
||||||
|
protocol.port,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
def __is_present_proto__(self, ip, port):
|
def __is_present_proto__(self, ip, port):
|
||||||
if port == self.auth_port:
|
if port == self.auth_port:
|
||||||
@@ -217,87 +304,92 @@ class ServerAsync(metaclass=ABCMeta):
|
|||||||
reply = pkt.CreateReply(**attributes)
|
reply = pkt.CreateReply(**attributes)
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
async def initialize_transports(self, enable_acct=False,
|
async def initialize_transports(
|
||||||
enable_auth=False, enable_coa=False,
|
self,
|
||||||
addresses=None):
|
enable_acct=False,
|
||||||
|
enable_auth=False,
|
||||||
|
enable_coa=False,
|
||||||
|
addresses=None,
|
||||||
|
):
|
||||||
|
|
||||||
task_list = []
|
task_list = []
|
||||||
|
|
||||||
if not enable_acct and not enable_auth and not enable_coa:
|
if not enable_acct and not enable_auth and not enable_coa:
|
||||||
raise Exception('No transports selected')
|
raise Exception("No transports selected")
|
||||||
if not addresses or len(addresses) == 0:
|
if not addresses or len(addresses) == 0:
|
||||||
addresses = ['127.0.0.1']
|
addresses = ["127.0.0.1"]
|
||||||
|
|
||||||
# noinspection SpellCheckingInspection
|
# noinspection SpellCheckingInspection
|
||||||
for addr in addresses:
|
for addr in addresses:
|
||||||
|
|
||||||
if enable_acct and not self.__is_present_proto__(addr, self.acct_port):
|
if enable_acct and not self.__is_present_proto__(
|
||||||
|
addr, self.acct_port
|
||||||
|
):
|
||||||
protocol_acct = DatagramProtocolServer(
|
protocol_acct = DatagramProtocolServer(
|
||||||
addr,
|
addr,
|
||||||
self.acct_port,
|
self.acct_port,
|
||||||
self.logger, self,
|
self.logger,
|
||||||
|
self,
|
||||||
ServerType.Acct,
|
ServerType.Acct,
|
||||||
self.hosts,
|
self.hosts,
|
||||||
self.__request_handler__
|
self.__request_handler__,
|
||||||
)
|
)
|
||||||
|
|
||||||
bind_addr = (addr, self.acct_port)
|
bind_addr = (addr, self.acct_port)
|
||||||
acct_connect = self.loop.create_datagram_endpoint(
|
acct_connect = self.loop.create_datagram_endpoint(
|
||||||
protocol_acct,
|
protocol_acct, reuse_port=True, local_addr=bind_addr
|
||||||
reuse_port=True,
|
|
||||||
local_addr=bind_addr
|
|
||||||
)
|
)
|
||||||
self.acct_protocols.append(protocol_acct)
|
self.acct_protocols.append(protocol_acct)
|
||||||
task_list.append(acct_connect)
|
task_list.append(acct_connect)
|
||||||
|
|
||||||
if enable_auth and not self.__is_present_proto__(addr, self.auth_port):
|
if enable_auth and not self.__is_present_proto__(
|
||||||
|
addr, self.auth_port
|
||||||
|
):
|
||||||
protocol_auth = DatagramProtocolServer(
|
protocol_auth = DatagramProtocolServer(
|
||||||
addr,
|
addr,
|
||||||
self.auth_port,
|
self.auth_port,
|
||||||
self.logger, self,
|
self.logger,
|
||||||
|
self,
|
||||||
ServerType.Auth,
|
ServerType.Auth,
|
||||||
self.hosts,
|
self.hosts,
|
||||||
self.__request_handler__
|
self.__request_handler__,
|
||||||
)
|
)
|
||||||
bind_addr = (addr, self.auth_port)
|
bind_addr = (addr, self.auth_port)
|
||||||
|
|
||||||
auth_connect = self.loop.create_datagram_endpoint(
|
auth_connect = self.loop.create_datagram_endpoint(
|
||||||
protocol_auth,
|
protocol_auth, reuse_port=True, local_addr=bind_addr
|
||||||
reuse_port=True,
|
|
||||||
local_addr=bind_addr
|
|
||||||
)
|
)
|
||||||
self.auth_protocols.append(protocol_auth)
|
self.auth_protocols.append(protocol_auth)
|
||||||
task_list.append(auth_connect)
|
task_list.append(auth_connect)
|
||||||
|
|
||||||
if enable_coa and not self.__is_present_proto__(addr, self.coa_port):
|
if enable_coa and not self.__is_present_proto__(
|
||||||
|
addr, self.coa_port
|
||||||
|
):
|
||||||
protocol_coa = DatagramProtocolServer(
|
protocol_coa = DatagramProtocolServer(
|
||||||
addr,
|
addr,
|
||||||
self.coa_port,
|
self.coa_port,
|
||||||
self.logger, self,
|
self.logger,
|
||||||
|
self,
|
||||||
ServerType.Coa,
|
ServerType.Coa,
|
||||||
self.hosts,
|
self.hosts,
|
||||||
self.__request_handler__
|
self.__request_handler__,
|
||||||
)
|
)
|
||||||
bind_addr = (addr, self.coa_port)
|
bind_addr = (addr, self.coa_port)
|
||||||
|
|
||||||
coa_connect = self.loop.create_datagram_endpoint(
|
coa_connect = self.loop.create_datagram_endpoint(
|
||||||
protocol_coa,
|
protocol_coa, reuse_port=True, local_addr=bind_addr
|
||||||
reuse_port=True,
|
|
||||||
local_addr=bind_addr
|
|
||||||
)
|
)
|
||||||
self.coa_protocols.append(protocol_coa)
|
self.coa_protocols.append(protocol_coa)
|
||||||
task_list.append(coa_connect)
|
task_list.append(coa_connect)
|
||||||
|
|
||||||
await asyncio.ensure_future(
|
await asyncio.ensure_future(
|
||||||
asyncio.gather(
|
asyncio.gather(*task_list, return_exceptions=False,), loop=self.loop
|
||||||
*task_list,
|
|
||||||
return_exceptions=False,
|
|
||||||
),
|
|
||||||
loop=self.loop
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# noinspection SpellCheckingInspection
|
# noinspection SpellCheckingInspection
|
||||||
async def deinitialize_transports(self, deinit_coa=True, deinit_auth=True, deinit_acct=True):
|
async def deinitialize_transports(
|
||||||
|
self, deinit_coa=True, deinit_auth=True, deinit_acct=True
|
||||||
|
):
|
||||||
|
|
||||||
if deinit_coa:
|
if deinit_coa:
|
||||||
for proto in self.coa_protocols:
|
for proto in self.coa_protocols:
|
||||||
243
src/pyrad3/tools.py
Normal file
243
src/pyrad3/tools.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
"""Collections of functions to en- and decode RADIUS Attributes"""
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
from ipaddress import IPv4Address, IPv6Address, IPv6Network, ip_network, ip_address
|
||||||
|
|
||||||
|
import struct
|
||||||
|
|
||||||
|
|
||||||
|
def encode_string(string: str) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type string"""
|
||||||
|
if len(string) > 253:
|
||||||
|
raise ValueError("Can only encode strings of <= 253 characters")
|
||||||
|
if isinstance(string, str):
|
||||||
|
return string.encode("utf-8")
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def encode_octets(string: bytes) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type octet"""
|
||||||
|
if len(string) > 253:
|
||||||
|
raise ValueError("Can only encode strings of <= 253 characters")
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def encode_address(addr: Union[str, IPv4Address]) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type ipaddr"""
|
||||||
|
return IPv4Address(addr).packed
|
||||||
|
|
||||||
|
|
||||||
|
def encode_ipv6_prefix(addr: Union[str, IPv6Network]) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type ipv6prefix"""
|
||||||
|
address = IPv6Network(addr)
|
||||||
|
return struct.pack("2B", *[0, address.prefixlen]) + address.network_address.packed
|
||||||
|
|
||||||
|
|
||||||
|
def encode_ipv6_address(addr: Union[str, IPv6Address]) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type ipv6addr"""
|
||||||
|
return IPv6Address(addr).packed
|
||||||
|
|
||||||
|
|
||||||
|
def encode_combo_ip(addr: Union[str, IPv4Address, IPv6Address]) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type combo-ip"""
|
||||||
|
return ip_address(addr).packed
|
||||||
|
|
||||||
|
|
||||||
|
def encode_ascend_binary(string: str) -> bytes:
|
||||||
|
"""
|
||||||
|
struct_format: List of type=value pairs sperated by spaces.
|
||||||
|
|
||||||
|
Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'
|
||||||
|
|
||||||
|
Type:
|
||||||
|
family ipv4(default) or ipv6
|
||||||
|
action discard(default) or accept
|
||||||
|
direction in(default) or out
|
||||||
|
src source prefix (default ignore)
|
||||||
|
dst destination prefix (default ignore)
|
||||||
|
proto protocol number / next-header number (default ignore)
|
||||||
|
sport source port (default ignore)
|
||||||
|
dport destination port (default ignore)
|
||||||
|
sportq source port qualifier (default 0)
|
||||||
|
dportq destination port qualifier (default 0)
|
||||||
|
|
||||||
|
Source/Destination Port Qualifier:
|
||||||
|
0 no compare
|
||||||
|
1 less than
|
||||||
|
2 equal to
|
||||||
|
3 greater than
|
||||||
|
4 not equal to
|
||||||
|
"""
|
||||||
|
|
||||||
|
terms = {
|
||||||
|
"family": b"\x01",
|
||||||
|
"action": b"\x00",
|
||||||
|
"direction": b"\x01",
|
||||||
|
"src": b"\x00\x00\x00\x00",
|
||||||
|
"dst": b"\x00\x00\x00\x00",
|
||||||
|
"srcl": b"\x00",
|
||||||
|
"dstl": b"\x00",
|
||||||
|
"proto": b"\x00",
|
||||||
|
"sport": b"\x00\x00",
|
||||||
|
"dport": b"\x00\x00",
|
||||||
|
"sportq": b"\x00",
|
||||||
|
"dportq": b"\x00",
|
||||||
|
}
|
||||||
|
|
||||||
|
for term in string.split(" "):
|
||||||
|
key, value = term.split("=")
|
||||||
|
if key == "family" and value == "ipv6":
|
||||||
|
terms[key] = b"\x03"
|
||||||
|
if terms["src"] == b"\x00\x00\x00\x00":
|
||||||
|
terms["src"] = 16 * b"\x00"
|
||||||
|
if terms["dst"] == b"\x00\x00\x00\x00":
|
||||||
|
terms["dst"] = 16 * b"\x00"
|
||||||
|
elif key == "action" and value == "accept":
|
||||||
|
terms[key] = b"\x01"
|
||||||
|
elif key == "direction" and value == "out":
|
||||||
|
terms[key] = b"\x00"
|
||||||
|
elif key in ("src", "dst"):
|
||||||
|
address = ip_network(value)
|
||||||
|
terms[key] = address.network_address.packed
|
||||||
|
terms[key + "l"] = struct.pack("B", address.prefixlen)
|
||||||
|
elif key in ("sport", "dport"):
|
||||||
|
terms[key] = struct.pack("!H", int(value))
|
||||||
|
elif key in ("sportq", "dportq", "proto"):
|
||||||
|
terms[key] = struct.pack("B", int(value))
|
||||||
|
|
||||||
|
trailer = 8 * b"\x00"
|
||||||
|
|
||||||
|
result = b"".join(
|
||||||
|
(
|
||||||
|
terms["family"],
|
||||||
|
terms["action"],
|
||||||
|
terms["direction"],
|
||||||
|
b"\x00",
|
||||||
|
terms["src"],
|
||||||
|
terms["dst"],
|
||||||
|
terms["srcl"],
|
||||||
|
terms["dstl"],
|
||||||
|
terms["proto"],
|
||||||
|
b"\x00",
|
||||||
|
terms["sport"],
|
||||||
|
terms["dport"],
|
||||||
|
terms["sportq"],
|
||||||
|
terms["dportq"],
|
||||||
|
b"\x00\x00",
|
||||||
|
trailer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def encode_integer(num: Union[int, str], struct_format="!I") -> bytes:
|
||||||
|
"""Encode a RADIUS value of some type integer"""
|
||||||
|
return struct.pack(struct_format, int(num))
|
||||||
|
|
||||||
|
|
||||||
|
def encode_date(num: Union[int, str]) -> bytes:
|
||||||
|
"""Encode a RADIUS value of type date"""
|
||||||
|
return struct.pack("!I", int(num))
|
||||||
|
|
||||||
|
|
||||||
|
def decode_string(string: bytes) -> Union[str, bytes]:
|
||||||
|
"""Decode a RADIUS value of type string"""
|
||||||
|
try:
|
||||||
|
return string.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def decode_octets(string: bytes) -> bytes:
|
||||||
|
"""Decode a RADIUS value of type octet"""
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def decode_address(addr: bytes) -> IPv4Address:
|
||||||
|
"""Decode a RADIUS value of type ipaddr"""
|
||||||
|
return IPv4Address(addr)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_ipv6_prefix(addr: bytes) -> IPv6Network:
|
||||||
|
"""Decode a RADIUS value of type ipv6prefix"""
|
||||||
|
addr = addr + b"\x00" * (18 - len(addr))
|
||||||
|
prefix = addr[:2]
|
||||||
|
addr = addr[2:]
|
||||||
|
return IPv6Network((prefix, addr))
|
||||||
|
|
||||||
|
|
||||||
|
def decode_ipv6_address(addr: bytes) -> IPv6Address:
|
||||||
|
"""Decode a RADIUS value of type ipv6addr"""
|
||||||
|
addr = addr + b"\x00" * (16 - len(addr))
|
||||||
|
return IPv6Address(addr)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_combo_ip(addr: bytes) -> Union[IPv4Address, IPv6Address]:
|
||||||
|
"""Decode a RADIUS value of type combo-ip"""
|
||||||
|
return ip_address(addr).packed
|
||||||
|
|
||||||
|
|
||||||
|
def decode_ascend_binary(string):
|
||||||
|
"""Decode a RADIUS value of type abinary"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def decode_integer(num: bytes, struct_format="!I") -> int:
|
||||||
|
"""Decode a RADIUS value of some integer type"""
|
||||||
|
return (struct.unpack(struct_format, num))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def decode_date(num):
|
||||||
|
"""Decode a RADIUS value of type date"""
|
||||||
|
return (struct.unpack("!I", num))[0]
|
||||||
|
|
||||||
|
|
||||||
|
ENCODE_MAP = {
|
||||||
|
"string": encode_string,
|
||||||
|
"octets": encode_octets,
|
||||||
|
"integer": encode_integer,
|
||||||
|
"ipaddr": encode_address,
|
||||||
|
"ipv6prefix": encode_ipv6_prefix,
|
||||||
|
"ipv6addr": encode_ipv6_address,
|
||||||
|
"abinary": encode_ascend_binary,
|
||||||
|
"signed": lambda value: encode_integer(value, "!i"),
|
||||||
|
"short": lambda value: encode_integer(value, "!H"),
|
||||||
|
"byte": lambda value: encode_integer(value, "!B"),
|
||||||
|
"integer64": lambda value: encode_integer(value, '!Q'),
|
||||||
|
"date": encode_date,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def encode_attr(datatype, value):
|
||||||
|
"""Encode a RADIUS attribute"""
|
||||||
|
try:
|
||||||
|
return ENCODE_MAP[datatype](value)
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"Unknown attribute type {datatype}")
|
||||||
|
|
||||||
|
|
||||||
|
DECODE_MAP = {
|
||||||
|
"string": decode_string,
|
||||||
|
"octets": decode_octets,
|
||||||
|
"integer": decode_integer,
|
||||||
|
"ipaddr": decode_address,
|
||||||
|
"ipv6prefix": decode_ipv6_prefix,
|
||||||
|
"ipv6addr": decode_ipv6_address,
|
||||||
|
"abinary": decode_ascend_binary,
|
||||||
|
"signed": lambda value: decode_integer(value, "!i"),
|
||||||
|
"short": lambda value: decode_integer(value, "!H"),
|
||||||
|
"byte": lambda value: decode_integer(value, "!B"),
|
||||||
|
"integer64": lambda value: decode_integer(value, "!Q"),
|
||||||
|
"date": decode_date,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_attr(datatype, value):
|
||||||
|
"""Decode a RADIUS attribute"""
|
||||||
|
try:
|
||||||
|
return DECODE_MAP[datatype](value)
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"Unknown attribute type {datatype}")
|
||||||
234
src/pyrad3/utils.py
Normal file
234
src/pyrad3/utils.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
"""Collection of functions to deal with RADIUS packet en- and decoding."""
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
import struct
|
||||||
|
|
||||||
|
from pyrad3.dictionary import Dictionary
|
||||||
|
|
||||||
|
RANDOM_GENERATOR = secrets.SystemRandom()
|
||||||
|
MD5 = hashlib.md5
|
||||||
|
|
||||||
|
|
||||||
|
class PacketError(Exception):
|
||||||
|
"""Exception for Invalid Packets"""
|
||||||
|
|
||||||
|
|
||||||
|
Header = namedtuple("Header", ["code", "radius_id", "length", "authenticator"])
|
||||||
|
Attribute = namedtuple("Attribute", ["name", "pos", "type", "length", "value"])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_header(raw_packet: bytes) -> Header:
|
||||||
|
"""Parse the Header of a RADIUS Packet."""
|
||||||
|
try:
|
||||||
|
header = struct.unpack("!BBH16s", raw_packet)
|
||||||
|
except struct.error:
|
||||||
|
raise PacketError("Packet header is corrupt")
|
||||||
|
|
||||||
|
length = header[3]
|
||||||
|
if len(raw_packet) != length:
|
||||||
|
raise PacketError(
|
||||||
|
f"RADIUS Packet ({len(raw_packet)}) has an invalid length ({length})"
|
||||||
|
)
|
||||||
|
if length > 4096:
|
||||||
|
raise PacketError(f"Packet length is too big ({length})")
|
||||||
|
return Header(*header)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_attributes(
|
||||||
|
rad_dict: Dictionary, raw_packet: bytes
|
||||||
|
) -> List[Attribute]:
|
||||||
|
"""Parse the Attributes in a RADIUS Packet.
|
||||||
|
|
||||||
|
This function skips the Header. The Header must be parsed and verified
|
||||||
|
separately.
|
||||||
|
"""
|
||||||
|
attributes = []
|
||||||
|
packet = raw_packet[20:]
|
||||||
|
|
||||||
|
while packet:
|
||||||
|
try:
|
||||||
|
(key, length) = struct.unpack("!BB", packet[0:2])
|
||||||
|
except struct.error:
|
||||||
|
raise PacketError("Attribute header is corrupt")
|
||||||
|
if length < 2:
|
||||||
|
raise PacketError(f"Attribute length ({length}) is too small")
|
||||||
|
|
||||||
|
value = packet[2:length]
|
||||||
|
offset = len(raw_packet) - len(packet) + length
|
||||||
|
if key == 26:
|
||||||
|
try:
|
||||||
|
attributes.extend(
|
||||||
|
parse_vendor_attributes(rad_dict, offset, value)
|
||||||
|
)
|
||||||
|
except (PacketError, IndexError):
|
||||||
|
attributes.append(
|
||||||
|
Attribute(
|
||||||
|
name="Unknown-Vendor-Attribute",
|
||||||
|
pos=offset,
|
||||||
|
type="octets",
|
||||||
|
length=int(packet[1]),
|
||||||
|
value=packet[2:],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
key = parse_key(rad_dict, key)
|
||||||
|
attributes.extend(parse_value(rad_dict, key, offset, value))
|
||||||
|
packet = packet[length:]
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
|
||||||
|
|
||||||
|
def parse_vendor_attributes(
|
||||||
|
rad_dict: Dictionary, offset: int, vendor_value: bytes
|
||||||
|
) -> List[Attribute]:
|
||||||
|
"""Parse A Vendor Attribute"""
|
||||||
|
if len(vendor_value) < 4:
|
||||||
|
raise PacketError
|
||||||
|
vendor_id = int.from_bytes(vendor_value[:4], "big")
|
||||||
|
vendor_dict = rad_dict.vendor_by_id[vendor_id]
|
||||||
|
vendor_name = vendor_dict.name
|
||||||
|
|
||||||
|
attributes = []
|
||||||
|
vendor_tlv = vendor_value[4:]
|
||||||
|
while vendor_tlv:
|
||||||
|
try:
|
||||||
|
(key, length) = struct.unpack("!BB", vendor_tlv[0:2])
|
||||||
|
except struct.error:
|
||||||
|
attribute = [
|
||||||
|
Attribute(
|
||||||
|
name=f"Unknown-{vendor_name}-Attribute",
|
||||||
|
pos=offset - len(vendor_value),
|
||||||
|
type="octets",
|
||||||
|
length=len(vendor_value) - 4,
|
||||||
|
value=vendor_value,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
offset = offset - len(vendor_tlv) + length
|
||||||
|
key = parse_key(vendor_dict, key)
|
||||||
|
attribute = parse_value(vendor_dict, key, offset, vendor_tlv)
|
||||||
|
attributes.extend(attribute)
|
||||||
|
vendor_tlv = vendor_tlv[length:]
|
||||||
|
return attributes
|
||||||
|
|
||||||
|
|
||||||
|
def parse_key(rad_dict: Dictionary, key_id: int) -> Union[str, int]:
|
||||||
|
"""Parse the key in the Dictionary Context"""
|
||||||
|
try:
|
||||||
|
return rad_dict.attrs[key_id].name
|
||||||
|
except KeyError:
|
||||||
|
return key_id
|
||||||
|
|
||||||
|
|
||||||
|
def parse_value(
|
||||||
|
rad_dict: Dictionary, key: Union[str, int], offset: int, raw_value: bytes
|
||||||
|
) -> List[Attribute]:
|
||||||
|
"""Parse the Value in the given Key/Dictionary Context"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_authenticator(
|
||||||
|
secret: bytes, authenticator: bytes, raw_packet: bytes
|
||||||
|
) -> bytes:
|
||||||
|
"""Calculate the Authenticator for the RADIUS Packet"""
|
||||||
|
return MD5(
|
||||||
|
raw_packet[0:4] + authenticator + raw_packet[20:] + secret
|
||||||
|
).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_pap_password(
|
||||||
|
secret: bytes,
|
||||||
|
authenticator: bytes,
|
||||||
|
obfuscated_password: bytes,
|
||||||
|
plaintext_password: bytes,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if the plaintext and the RADIUS passwords match
|
||||||
|
This function does not "decrypt" the received password.
|
||||||
|
"""
|
||||||
|
obf_pass = password_encode(secret, authenticator, plaintext_password)
|
||||||
|
return obfuscated_password == obf_pass
|
||||||
|
|
||||||
|
|
||||||
|
def password_encode(
|
||||||
|
secret: bytes, authenticator: bytes, password: bytes
|
||||||
|
) -> bytes:
|
||||||
|
"""Obfuscate the plaintext Password for RADIUS"""
|
||||||
|
password += b"\x00" * (16 - (len(password) % 16))
|
||||||
|
return obfuscation_algorithm(secret, authenticator, password)
|
||||||
|
|
||||||
|
|
||||||
|
def password_decode(
|
||||||
|
secret: bytes, authenticator: bytes, obfuscated_password: bytes
|
||||||
|
) -> str:
|
||||||
|
"""Reverse the RADIUS obfuscation on a given password
|
||||||
|
|
||||||
|
The password password is padded with \\x00 to a 16 byte boundary. The padding will
|
||||||
|
be removed by this function.
|
||||||
|
If the original password had some trailing \\x00 it will get lost. Therefore it is
|
||||||
|
not recommended to use (trailing) \\x00 in passwords.
|
||||||
|
"""
|
||||||
|
deobfuscated = obfuscation_algorithm(
|
||||||
|
secret, authenticator, obfuscated_password
|
||||||
|
)
|
||||||
|
return deobfuscated.rstrip(b"\x00").decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def obfuscation_algorithm(
|
||||||
|
secret: bytes, authenticator: bytes, password: bytes
|
||||||
|
) -> bytes:
|
||||||
|
"""Obfuscate the plaintext password.
|
||||||
|
|
||||||
|
This function does not deal with the padding (which the
|
||||||
|
RADIUS Protocol requires.)
|
||||||
|
The User has to pad the password themself, or better use
|
||||||
|
the `password_encode` or `password_decode` function.
|
||||||
|
"""
|
||||||
|
result = b""
|
||||||
|
buf = password
|
||||||
|
last = authenticator
|
||||||
|
|
||||||
|
while buf:
|
||||||
|
cur_hash = MD5(secret + last).digest()
|
||||||
|
for cbuf, chash in zip(buf, cur_hash):
|
||||||
|
result += bytes([cbuf ^ chash])
|
||||||
|
(last, buf) = (buf[:16], buf[16:])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def validate_chap_password(
|
||||||
|
chap_id: bytes,
|
||||||
|
challenge: bytes,
|
||||||
|
chap_password: bytes,
|
||||||
|
plaintext_password: bytes,
|
||||||
|
) -> bool:
|
||||||
|
"""Validate the CHAP password against the given plaintext password"""
|
||||||
|
return (
|
||||||
|
chap_password == MD5(chap_id + plaintext_password + challenge).digest()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def salt_encrypt(secret: bytes, authenticator: bytes, value: bytes) -> bytes:
|
||||||
|
"""Salt Encrypt the given value"""
|
||||||
|
# The highest bit MUST be 1
|
||||||
|
random_value = RANDOM_GENERATOR.randrange(32768, 65535)
|
||||||
|
salt = struct.pack("!H", random_value)
|
||||||
|
|
||||||
|
salted_auth = authenticator + salt
|
||||||
|
|
||||||
|
return obfuscation_algorithm(secret, salted_auth, value)
|
||||||
|
|
||||||
|
|
||||||
|
def salt_decrypt(
|
||||||
|
secret: bytes, authenticator: bytes, salt: bytes, encrypted_value: bytes
|
||||||
|
) -> bytes:
|
||||||
|
"""Decrypt the given value"""
|
||||||
|
salted_auth = authenticator + salt
|
||||||
|
return obfuscation_algorithm(secret, salted_auth, encrypted_value)
|
||||||
1
tests/dictionaries/mutual_recursive
Normal file
1
tests/dictionaries/mutual_recursive
Normal file
@@ -0,0 +1 @@
|
|||||||
|
$INCLUDE ./other_mutual_recursive
|
||||||
1
tests/dictionaries/other_mutual_recursive
Normal file
1
tests/dictionaries/other_mutual_recursive
Normal file
@@ -0,0 +1 @@
|
|||||||
|
$INCLUDE mutual_recursive
|
||||||
1
tests/dictionaries/self_recursive
Normal file
1
tests/dictionaries/self_recursive
Normal file
@@ -0,0 +1 @@
|
|||||||
|
$INCLUDE ./self_recursive
|
||||||
190
tests/test_dictionary.py
Normal file
190
tests/test_dictionary.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
# Copyright 2020 Istvan Ruzman
|
||||||
|
# SPDX-License-Identifier: MIT OR Apache-2.0
|
||||||
|
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pyrad3.dictionary import Dictionary, ParseError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"filename", ["dictionaries/self_recursive", "dictionaries/mutual_recursive"]
|
||||||
|
)
|
||||||
|
def test_dictionary_recursion(filename):
|
||||||
|
with pytest.raises(ParseError):
|
||||||
|
Dictionary("tests/" + filename)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"line",
|
||||||
|
[
|
||||||
|
"$INCLUDE",
|
||||||
|
"BEGIN-VENDOR",
|
||||||
|
"END-VENDOR",
|
||||||
|
"VENDOR",
|
||||||
|
"VENDOR NAME",
|
||||||
|
"ATTRIBUTE",
|
||||||
|
"ATTRIBUTE NAME",
|
||||||
|
"VALUE",
|
||||||
|
"VALUE ATTRNAME",
|
||||||
|
"VALUE ATTRNAME VALUENAME",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_lines_missing_tokens(line):
|
||||||
|
dictionary = StringIO(line)
|
||||||
|
with pytest.raises(ParseError):
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"vendor",
|
||||||
|
[
|
||||||
|
"VENDOR test 1234",
|
||||||
|
"VENDOR test 1234 format=1,1",
|
||||||
|
"VENDOR test 1234 format=2,2",
|
||||||
|
"VENDOR test 1234 format=1,2",
|
||||||
|
"VENDOR test 1234 format=4,2",
|
||||||
|
"VENDOR test 1234 format=4,0",
|
||||||
|
"VENDOR WiMAX 1234 format=1,1,c",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_valid_vendor_definitions(vendor):
|
||||||
|
dictionary = StringIO(vendor)
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"vendor",
|
||||||
|
[
|
||||||
|
"VENDOR test 1234 1,1",
|
||||||
|
"VENDOR test 1234 format=3,1",
|
||||||
|
"VENDOR test 1234 format=2",
|
||||||
|
"VENDOR test 1234 format=1,2,c",
|
||||||
|
"VENDOR test 1234 format=1,9",
|
||||||
|
"VENDOR test 1234 format=4,4 suffix",
|
||||||
|
"VENDOR test 1234 format=a,b suffix",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_invalid_vendor_definitions(vendor):
|
||||||
|
dictionary = StringIO(vendor)
|
||||||
|
with pytest.raises(ParseError):
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"number",
|
||||||
|
[
|
||||||
|
"ATTRIBUTE NAME 0x01 byte",
|
||||||
|
"ATTRIBUTE NAME 0x0001 byte",
|
||||||
|
"ATTRIBUTE NAME 0o123 byte",
|
||||||
|
"ATTRIBUTE NAME 5 byte",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_valid_attribute_numbers(number):
|
||||||
|
dictionary = StringIO(number)
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"number",
|
||||||
|
[
|
||||||
|
"ATTRIBUTE NAME 1234 byte",
|
||||||
|
"ATTRIBUTE NAME ABCD byte",
|
||||||
|
"ATTRIBUTE NAME -1 byte",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_invalid_attribute_numbers(number):
|
||||||
|
dictionary = StringIO(number)
|
||||||
|
with pytest.raises(ParseError):
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("type_length", [1, 2, 4])
|
||||||
|
def test_attribute_number_limits(type_length):
|
||||||
|
too_big = 2 ** (8 * type_length)
|
||||||
|
max_value = too_big - 1
|
||||||
|
dictionary = StringIO(
|
||||||
|
f"VENDOR TEST 1234 format={type_length},1\n"
|
||||||
|
"BEGIN-VENDOR TEST\n"
|
||||||
|
f"ATTRIBUTE TEST {max_value} byte\n"
|
||||||
|
"END-VENDOR TEST\n"
|
||||||
|
)
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
dictionary = StringIO(
|
||||||
|
f"VENDOR TEST 1234 format={type_length},1\n"
|
||||||
|
"BEGIN-VENDOR TEST\n"
|
||||||
|
f"ATTRIBUTE TEST {too_big} byte\n"
|
||||||
|
"END-VENDOR TEST\n"
|
||||||
|
)
|
||||||
|
with pytest.raises(ParseError):
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value_definition",
|
||||||
|
[
|
||||||
|
"VALUE TEST-ATTRIBUTE TEST-VALUE 1",
|
||||||
|
"VALUE TEST-ATTRIBUTE TEST-VALUE 0x1",
|
||||||
|
"VALUE TEST-ATTRIBUTE TEST-VALUE 0o1",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_value_definition(value_definition):
|
||||||
|
dictionary = StringIO(
|
||||||
|
"\n".join(["ATTRIBUTE TEST-ATTRIBUTE 1 byte", value_definition])
|
||||||
|
)
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value_num, attr_type",
|
||||||
|
[
|
||||||
|
(0, "byte"),
|
||||||
|
(255, "byte"),
|
||||||
|
(0, "short"),
|
||||||
|
(2 ** 16 - 1, "short"),
|
||||||
|
(0, "integer"),
|
||||||
|
(2 ** 32 - 1, "integer"),
|
||||||
|
((-(2 ** 31)), "signed"),
|
||||||
|
(2 ** 31 - 1, "signed"),
|
||||||
|
(0, "integer64"),
|
||||||
|
(2 ** 64 - 1, "integer64"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_value_number_within_limit(value_num, attr_type):
|
||||||
|
dictionary = StringIO(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
f"ATTRIBUTE TEST-ATTRIBUTE 1 {attr_type}",
|
||||||
|
f"VALUE TEST-ATTRIBUTE TEST-VALUE {value_num}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
Dictionary("", dictionary)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value_num, attr_type",
|
||||||
|
[
|
||||||
|
(-1, "byte"),
|
||||||
|
(256, "byte"),
|
||||||
|
(-1, "short"),
|
||||||
|
(2 ** 16, "short"),
|
||||||
|
(-1, "integer"),
|
||||||
|
(2 ** 32, "integer"),
|
||||||
|
(2 ** 31, "signed"),
|
||||||
|
((-(2 ** 31)) - 1, "signed"),
|
||||||
|
(-1, "integer64"),
|
||||||
|
(2 ** 64, "integer64"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_value_number_out_of_limit(value_num, attr_type):
|
||||||
|
dictionary = StringIO(
|
||||||
|
"\n".join(
|
||||||
|
[
|
||||||
|
f"ATTRIBUTE TEST-ATTRIBUTE 1 {attr_type}",
|
||||||
|
f"VALUE TEST-ATTRIBUTE TEST-VALUE {value_num}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
with pytest.raises(ParseError):
|
||||||
|
Dictionary("", dictionary)
|
||||||
Reference in New Issue
Block a user