Handle timeouts & max requests in Python test class

This commit is contained in:
Grégory Soutadé 2020-05-11 10:07:10 +02:00
parent 13ef026003
commit ea4bfda214
1 changed files with 27 additions and 13 deletions

View File

@ -45,16 +45,20 @@ class IPToGeo(object):
4 : 'Bad IP version',
5 : 'Unsupported IP version',
6 : 'IP not found'}
MAX_REQUESTS = 50
def __init__(self, remote_addr='127.0.0.1', remote_port=53333, timeout=None, family=socket.AF_INET):
self._remote_addr = remote_addr
self._remote_port = remote_port
self._timeout = timeout
self._family = family
self._create_socket()
self._nb_requests_sent = self.MAX_REQUESTS # Force socket creation
self._socket = None
def _create_socket(self):
if self._socket:
self._socket.close()
self._socket = socket.socket(self._family, socket.SOCK_STREAM)
if not self._timeout is None:
self._socket.settimeout(self._timeout)
@ -108,7 +112,25 @@ class IPToGeo(object):
(cc0, cc1, cc2, cc3) = struct.unpack_from('BBBB', packet, 7*4)
return (ip_res, '%c%c%c%c' % (cc0, cc1, cc2, cc3))
def _send_request(self, packet, second_chance=True):
self._nb_requests_sent += 1
if self._nb_requests_sent >= self.MAX_REQUESTS:
self._create_socket()
self._nb_requests_sent = 0
try:
self._socket.send(packet)
packet = self._socket.recv(IPToGeo.PACKET_SIZE)
if not packet:
raise socket.timeout
return packet
except socket.timeout, e:
if second_chance:
self._nb_requests_sent = self.MAX_REQUESTS
return self._send_request(packet, False)
else:
raise e
def ip_to_geo(self, ip):
ip_type = IPToGeo.IPV4
if ip.find('.') >= 0:
@ -124,15 +146,7 @@ class IPToGeo(object):
raise Exception('Bad IP %s' % (ip))
packet = self._create_request(splitted_ip, ip_type)
try:
self._socket.send(packet)
except IOError, e:
# Give another chance (we may have been disconnected due to timeout)
self._create_socket()
self._socket.send(packet)
packet = self._socket.recv(IPToGeo.PACKET_SIZE)
if not packet:
raise IPToGeoException('Error, empty packet')
packet = self._send_request(packet)
(ip, country_code) = self._check_request(packet)
if country_code:
# convert to string