Bugfix : iptogeo : _send_request that must handles timeout errors (empty packets) was not used

This commit is contained in:
Grégory Soutadé 2020-05-11 10:03:07 +02:00
parent afc6f02181
commit 6f9622bb91
1 changed files with 18 additions and 20 deletions

View File

@ -53,11 +53,12 @@ class IPToGeo(object):
self._remote_port = remote_port
self._timeout = timeout
self._family = family
self._nb_requests_sent = 0
self._create_socket()
self._nb_requests_sent = self.MAX_REQUESTS # Force socket creation
self._socket = None
def _create_socket(self):
if self._socket:
self._socket.close()
self._socket = socket.socket(self._family, socket.SOCK_STREAM)
if not self._timeout is None:
self._socket.settimeout(self._timeout)
@ -112,19 +113,24 @@ class IPToGeo(object):
return (ip_res, '%c%c%c%c' % (cc0, cc1, cc2, cc3))
def _send_request(self, packet):
def _send_request(self, packet, second_chance=True):
self._nb_requests_sent += 1
if self._nb_requests_sent == self.MAX_REQUESTS:
self.close()
if self._nb_requests_sent >= self.MAX_REQUESTS:
self._create_socket()
self._nb_requests_sent = 0
try:
self._socket.send(packet)
except IOError, e:
# Give another chance (we may have been disconnected due to timeout)
self._create_socket()
self._socket.send(packet)
packet = self._socket.recv(IPToGeo.PACKET_SIZE)
if not packet:
raise socket.timeout
return packet
except socket.timeout, e:
if second_chance:
self._nb_requests_sent = self.MAX_REQUESTS
return self._send_request(packet, False)
else:
raise e
def ip_to_geo(self, ip):
ip_type = IPToGeo.IPV4
if ip.find('.') >= 0:
@ -140,15 +146,7 @@ class IPToGeo(object):
raise Exception('Bad IP %s' % (ip))
packet = self._create_request(splitted_ip, ip_type)
try:
self._socket.send(packet)
except IOError, e:
# Give another chance (we may have been disconnected due to timeout)
self._create_socket()
self._socket.send(packet)
packet = self._socket.recv(IPToGeo.PACKET_SIZE)
if not packet:
raise IPToGeoException('Error, empty packet')
packet = self._send_request(packet)
(ip, country_code) = self._check_request(packet)
if country_code:
# convert to string