# -*- coding: utf-8 -*-
import datetime  # retained so the module namespace stays backward-compatible
import logging
import time

logger = logging.getLogger(__name__)


class TokenBucket:
    """Token-bucket rate limiter.

    The bucket starts full with ``max_token_num`` tokens and is refilled
    continuously at ``tokens_per_sec`` tokens per second, never exceeding
    ``max_token_num``. Each successful call to :meth:`try_decrease_token`
    consumes exactly one token.

    Special case: when ``tokens_per_sec <= 0`` no refill ever happens and the
    bucket degenerates into a constant switch -- every request is allowed if
    ``max_token_num >= 1`` and every request is rejected otherwise.

    Elapsed time is measured with ``time.monotonic()`` so that wall-clock
    adjustments (NTP steps, DST changes, manual clock changes) can neither
    grant an unearned burst of tokens nor stall the refill.
    """

    def __init__(self, tokens_per_sec: float, max_token_num: float) -> None:
        """Create a bucket that starts full.

        Args:
            tokens_per_sec: Refill rate in tokens per second. Values ``<= 0``
                disable rate limiting (see the class docstring).
            max_token_num: Bucket capacity, i.e. the largest burst allowed.

        Raises:
            TypeError, ValueError: If an argument cannot be converted with
                ``float()``.
        """
        self._tokens_per_sec = float(tokens_per_sec)
        self._max_token_num = float(max_token_num)
        self._stored_token_num = self._max_token_num
        # Monotonic timestamp (seconds) of the last refill computation.
        self._last_update_time = time.monotonic()

        if self._tokens_per_sec <= 0.0 and self._max_token_num >= 1.0:
            # Log the converted float: the raw argument may be e.g. a numeric
            # string, which '%f' cannot format.
            logger.warning('TokenBucket token_per_sec=%f <= 0, rate has no limit',
                           self._tokens_per_sec)

    def try_decrease_token(self) -> bool:
        """Try to consume one token.

        Returns:
            ``True`` if a token was available and has been consumed,
            ``False`` if the caller should be throttled.
        """
        if self._tokens_per_sec <= 0.0:
            # No refill configured: unlimited when the capacity is at least
            # one token, completely blocked when it is below one.
            return self._max_token_num >= 1.0

        now = time.monotonic()
        # A monotonic clock never goes backwards; the clamp is a cheap guard
        # in case the stored timestamp was ever set ahead of the clock.
        elapsed = max(0.0, now - self._last_update_time)
        self._last_update_time = now

        # Refill for the elapsed interval, capped at the bucket capacity.
        self._stored_token_num = min(
            self._stored_token_num + elapsed * self._tokens_per_sec,
            self._max_token_num,
        )

        if self._stored_token_num < 1.0:
            return False
        self._stored_token_num -= 1.0
        return True