Issue
I don't know whether this kind of lock is actually called a time lock, but I need something for the following scenario: I'm making a lot of concurrent requests with aiohttp, and at some point the server may return 429 Too Many Requests. When that happens, I have to pause all subsequent requests for some time.
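To make the scenario concrete, this is roughly how the requesting code would use such a lock (fetch, the session argument and the 1-second fallback are just illustrative):

import aiohttp

async def fetch(session, url, lock):
    async with lock:  # wait here if a previous 429 has paused everything
        async with session.get(url) as resp:
            if resp.status == 429:
                # honour Retry-After if the server sends it (assuming seconds),
                # otherwise back off for 1 second
                lock.lock_for(float(resp.headers.get('Retry-After', 1)))
                return None
            return await resp.text()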
I came up with the following solution:
import asyncio
import time


class TimeLock:
    def __init__(self, *, loop=None):
        self._locked = False
        self._locked_at = None
        self._time_lock = None
        self._unlock_task = None
        self._num_waiters = 0
        if loop is not None:
            self._loop = loop
        else:
            self._loop = asyncio.get_event_loop()

    def __repr__(self):
        state = f'locked at {self.locked_at}' if self._locked else 'unlocked'
        return f'[{state}] {self._num_waiters} waiters'

    @property
    def locked(self):
        return self._locked

    @property
    def locked_at(self):
        return self._locked_at

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # in this time lock there is nothing to do when it's released
        return

    async def acquire(self):
        if not self._locked:
            return True
        try:
            print('waiting for lock to be released')
            self._num_waiters += 1
            await self._time_lock
            self._num_waiters -= 1
            print('done, returning now')
        except asyncio.CancelledError:
            if self._locked:
                raise
        return True

    def lock_for(self, delay, lock_more=False):
        print(f'locking for {delay}')
        if self._locked:
            if not lock_more:
                # if we don't want to increase the lock time, we just exit when
                # the lock is already in a locked state
                print('already locked, nothing to do')
                return
            print('already locked, but canceling old unlock task')
            self._unlock_task.cancel()
        self._locked = True
        self._locked_at = time.time()
        self._time_lock = self._loop.create_future()
        self._unlock_task = self._loop.create_task(self.unlock_in(delay))
        print('locked')

    async def unlock_in(self, delay):
        print('unlocking started')
        await asyncio.sleep(delay)
        self._locked = False
        self._locked_at = None
        self._unlock_task = None
        self._time_lock.set_result(True)
        print('unlocked')
I am testing the lock with this code:
import asyncio

from ares.http import TimeLock


async def run(lock, i):
    async with lock:
        print(lock)
        print(i)
        if i in (3, 6, 9):
            lock.lock_for(2)


if __name__ == '__main__':
    lock = TimeLock()
    tasks = []
    loop = asyncio.get_event_loop()
    for i in range(10):
        tasks.append(run(lock, i))
    loop.run_until_complete(asyncio.gather(*tasks))
    print(lock)
The code produces the following output, which seems to be consistent with what I want from the above scenario:
[unlocked] 0 waiters
0
[unlocked] 0 waiters
1
[unlocked] 0 waiters
2
[unlocked] 0 waiters
3
locking for 2
locked
waiting for lock to be released
waiting for lock to be released
waiting for lock to be released
waiting for lock to be released
waiting for lock to be released
waiting for lock to be released
unlocking started
unlocked
done, returning now
[unlocked] 5 waiters
4
done, returning now
[unlocked] 4 waiters
5
done, returning now
[unlocked] 3 waiters
6
locking for 2
locked
done, returning now
[locked at 1559496296.7109463] 2 waiters
7
done, returning now
[locked at 1559496296.7109463] 1 waiters
8
done, returning now
[locked at 1559496296.7109463] 0 waiters
9
locking for 2
already locked, nothing to do
unlocking started
[locked at 1559496296.7109463] 0 waiters
Is this the proper way to implement this kind of synchronization primitive? I am also not sure about the thread-safety of this code; I don't have much experience with threads and asyncio.
Solution
I didn't test your code, but the idea seems fine. You should worry about thread-safety only if you're going to use the same lock object from different threads. As Jimmy Engelbrecht already noted, asyncio runs in a single thread, so you usually don't have to worry about the thread-safety of its primitives.
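If some other thread ever does need to trigger the pause, the safe pattern is to hand the call over to the event loop rather than calling the method directly; a minimal sketch using your synchronous lock_for:

def lock_from_another_thread(loop, lock, delay):
    # loop is the asyncio event loop the lock lives in; call_soon_threadsafe
    # is the supported way to schedule a callback on a loop running in a
    # different thread, so the lock is only ever touched from that loop.
    loop.call_soon_threadsafe(lock.lock_for, delay)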
Here are a few more thoughts:
- I'm not sure about the terminology, but it seems this primitive should be called a semaphore
- Instead of implementing it from scratch, you can inherit from, or simply build on, the existing primitives
- You can delegate the tracking of when to pause to the semaphore itself, instead of doing it in client code
This code snippet shows the idea:
import asyncio


class PausingSemaphore:
    def __init__(self, should_pause, pause_for_seconds):
        self.should_pause = should_pause
        self.pause_for_seconds = pause_for_seconds
        self._is_paused = False
        self._resume = asyncio.Event()

    async def __aenter__(self):
        await self.check_paused()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # returning None re-raises the exception in the caller
        if self.should_pause(exc):
            self.pause()

    async def check_paused(self):
        if self._is_paused:
            await self._resume.wait()

    def pause(self):
        if not self._is_paused:
            self._is_paused = True
            # reset the event so waiters actually block on repeated pauses
            self._resume.clear()
            asyncio.get_running_loop().call_later(
                self.pause_for_seconds,
                self.unpause
            )

    def unpause(self):
        self._is_paused = False
        self._resume.set()
Let's test it:
import asyncio

import aiohttp


def should_pause(exc):
    return (
        type(exc) is aiohttp.ClientResponseError
        and exc.status == 429
    )


pausing_sem = None
regular_sem = None


async def request(url):
    async with regular_sem:
        async with pausing_sem:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(url, raise_for_status=True) as resp:
                        print('Done!')
            except aiohttp.ClientResponseError:
                print('Too many requests!')
                raise


async def main():
    global pausing_sem
    global regular_sem

    pausing_sem = PausingSemaphore(should_pause, 5)
    regular_sem = asyncio.Semaphore(3)

    await asyncio.gather(
        *[
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/status/429'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
            request('http://httpbin.org/get'),
        ],
        return_exceptions=True
    )


if __name__ == '__main__':
    asyncio.run(main())
P.S. Didn't test this code much!
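One possible refinement (a rough sketch, not tested either): when the 429 response carries a Retry-After header, the pause length can be taken from the exception instead of the fixed pause_for_seconds, since the aiohttp.ClientResponseError raised by raise_for_status carries the response headers:

class RetryAfterSemaphore(PausingSemaphore):
    async def __aexit__(self, exc_type, exc, tb):
        if self.should_pause(exc):
            retry_after = (exc.headers or {}).get('Retry-After')
            if retry_after is not None:
                try:
                    # Retry-After is usually a number of seconds; it can also
                    # be an HTTP date, in which case keep the configured default.
                    self.pause_for_seconds = float(retry_after)
                except ValueError:
                    pass
            self.pause()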
Answered By - Mikhail Gerasimov