142 lines
4.6 KiB
Python
142 lines
4.6 KiB
Python
"""Google Cloud Storage result store backend for Celery."""
|
|
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from os import getpid
from threading import RLock

from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import dictfilter
from kombu.utils.url import url_to_parts

from celery.exceptions import ImproperlyConfigured

from .base import KeyValueStoreBackend

try:
    import requests
    from google.cloud import storage
    from google.cloud.storage import Client
    from google.cloud.storage.retry import DEFAULT_RETRY
except ImportError:
    storage = None
|
|
|
|
# Public names exported by this module (e.g. for ``from ... import *``).
__all__ = ('GCSBackend',)
|
|
|
|
|
|
class GCSBackend(KeyValueStoreBackend):
    """Google Cloud Storage task result backend.

    Results are stored as blobs in bucket ``gcs_bucket`` of project
    ``gcs_project``, optionally under the prefix ``gcs_base_path``.  When
    ``gcs_ttl`` (seconds) is set, every written blob gets a ``custom_time``
    of *now + ttl*, and construction fails unless the bucket has at least one
    ``Delete`` lifecycle rule (the rule's condition is not inspected --
    presumably it should be ``daysSinceCustomTime``; confirm on the bucket).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._lock = RLock()
        self._pid = getpid()
        self._client = None

        # Check the optional dependency *before* touching any name it
        # provides: when the module-level import failed only ``storage`` is
        # bound (to None), so reading DEFAULT_RETRY first would raise a
        # NameError instead of this descriptive configuration error.
        if not storage:
            raise ImproperlyConfigured(
                'You must install google-cloud-storage to use gcs backend'
            )
        self._retry_policy = DEFAULT_RETRY

        conf = self.app.conf
        if self.url:
            # Non-None settings parsed from the backend URL are written into
            # the app configuration, overriding existing values.
            url_params = self._params_from_url()
            conf.update(**dictfilter(url_params))

        self.bucket_name = conf.get('gcs_bucket')
        if not self.bucket_name:
            raise ImproperlyConfigured(
                'Missing bucket name: specify gcs_bucket to use gcs backend'
            )
        self.project = conf.get('gcs_project')
        if not self.project:
            raise ImproperlyConfigured(
                'Missing project: specify gcs_project to use gcs backend'
            )
        self.base_path = conf.get('gcs_base_path', '').strip('/')
        self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10))
        self.ttl = float(conf.get('gcs_ttl') or 0)
        if self.ttl < 0:
            raise ImproperlyConfigured(
                f'Invalid ttl: {self.ttl} must be greater than or equal to 0'
            )
        elif self.ttl and not self._is_bucket_lifecycle_rule_exists():
            raise ImproperlyConfigured(
                f'Missing lifecycle rule to use gcs backend with ttl on '
                f'bucket: {self.bucket_name}'
            )

    def get(self, key):
        """Return the stored bytes for *key*, or ``None`` if no blob exists."""
        key = bytes_to_str(key)
        blob = self._get_blob(key)
        try:
            return blob.download_as_bytes(retry=self._retry_policy)
        except storage.blob.NotFound:
            return None

    def set(self, key, value):
        """Upload *value* under *key*.

        When a TTL is configured, the blob's ``custom_time`` is set to the
        current UTC time plus ``self.ttl`` seconds before uploading.
        """
        key = bytes_to_str(key)
        blob = self._get_blob(key)
        if self.ttl:
            # Timezone-aware UTC: ``datetime.utcnow()`` is deprecated since
            # Python 3.12; the serialized RFC3339 timestamp is unchanged.
            blob.custom_time = (
                datetime.now(timezone.utc) + timedelta(seconds=self.ttl)
            )
        blob.upload_from_string(value, retry=self._retry_policy)

    def delete(self, key):
        """Delete the blob stored under *key*; a missing blob is not an error."""
        key = bytes_to_str(key)
        blob = self._get_blob(key)
        try:
            blob.delete(retry=self._retry_policy)
        except storage.blob.NotFound:
            # Already absent (never written, deleted concurrently, or a
            # retried request whose first attempt succeeded).  Trying the
            # delete directly avoids the extra ``exists()`` round trip and
            # its check-then-act race.
            pass

    def mget(self, keys):
        """Fetch all *keys* concurrently; results keep the order of *keys*."""
        # Match the worker count to the HTTP connection-pool size configured
        # in ``client`` so threads do not outnumber pooled connections.
        max_workers = max(1, self._threadpool_maxsize)
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            return list(pool.map(self.get, keys))

    @property
    def client(self):
        """Return a storage client for the current process.

        The client is created lazily and cached.  The cache is tied to the
        PID that created it, so after a fork the child builds its own client
        instead of reusing the parent's.
        """
        # Creating a client is expensive; serialize creation across threads.
        with self._lock:
            if self._client and self._pid == getpid():
                return self._client
            self._client = Client(project=self.project)
            self._pid = getpid()

            # Size the HTTPS connection pools (API and auth sessions) to
            # ``gcs_threadpool_maxsize``.
            adapter = requests.adapters.HTTPAdapter(
                pool_connections=self._threadpool_maxsize,
                pool_maxsize=self._threadpool_maxsize,
                max_retries=3,
            )
            client_http = self._client._http
            client_http.mount("https://", adapter)
            client_http._auth_request.session.mount("https://", adapter)

            return self._client

    @property
    def bucket(self):
        """Return a handle to the configured bucket."""
        return self.client.bucket(self.bucket_name)

    def _get_blob(self, key):
        """Return the blob handle for *key*, prefixed by ``base_path`` if set."""
        key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key
        return self.bucket.blob(key_bucket_path)

    def _is_bucket_lifecycle_rule_exists(self):
        """Return True if the bucket has at least one ``Delete`` lifecycle rule."""
        bucket = self.bucket
        bucket.reload()  # refresh bucket metadata, including lifecycle rules
        return any(
            rule['action']['type'] == 'Delete'
            for rule in bucket.lifecycle_rules
        )

    def _params_from_url(self):
        """Translate the backend URL into ``gcs_*`` settings.

        The URL host becomes ``gcs_bucket`` and its path ``gcs_base_path``.
        Every query parameter is passed through unchanged and overrides
        either of those on a key clash.
        """
        url_parts = url_to_parts(self.url)
        return {
            'gcs_bucket': url_parts.hostname,
            'gcs_base_path': url_parts.path,
            **url_parts.query,
        }
|