1
0
mirror of https://github.com/checktheroads/hyperglass synced 2024-05-11 05:55:08 +00:00
2023-04-13 23:05:05 -04:00

187 lines
6.4 KiB
Python

"""Interact with redis for state management."""
# Standard Library
import pickle
import typing as t
from types import TracebackType
from typing import overload
from datetime import datetime, timedelta
# Project
from hyperglass.log import log
from hyperglass.exceptions.private import StateError
if t.TYPE_CHECKING:
# Third Party
from redis import Redis
from redis.client import Pipeline
class RedisManager:
    """Convenience wrapper for managing a redis session.

    Every key is prefixed with ``namespace`` (see ``key``) and every value is
    serialized with ``pickle`` before being written to Redis.
    """

    # Redis client every command is sent to (a pipeline inside ``pipeline()``).
    instance: "Redis"
    # Dot-separated prefix joined onto every key.
    namespace: str

    def __init__(self, instance: "Redis", namespace: str) -> None:
        """Set up Redis connection and add configuration objects."""
        self.instance = instance
        self.namespace = namespace

    def __repr__(self) -> str:
        """Alias repr to Redis instance's repr."""
        return repr(self.instance)

    def __str__(self) -> str:
        """String-friendly redis manager."""
        return repr(self)

    def _key_join(self, *keys: str) -> str:
        """Join the namespace and ``keys`` into a single dot-separated key.

        Each key is split on ``.`` and repeated parts are dropped (first
        occurrence wins), so a key that already carries the namespace prefix
        is not prefixed twice. Note this also collapses any other repeated
        part, e.g. key parts ``a.b.a`` contribute only ``a.b``.
        """
        key_in_parts = (k for key in keys for k in key.split("."))
        # dict.fromkeys de-duplicates while preserving insertion order.
        key_parts = list(dict.fromkeys((*self.namespace.split("."), *key_in_parts)))
        return ".".join(key_parts)

    def key(self, key: t.Union[str, t.Iterable[str]]) -> str:
        """Format keys with state namespace.

        Arguments:
            key: A single (optionally dot-separated) key, or any iterable of
                key parts (list, tuple, generator, iterator, ...).

        Returns:
            The namespaced key.
        """
        # A str is itself iterable, so it must be checked first; anything else
        # is treated as a collection of key parts.
        if isinstance(key, str):
            return self._key_join(key)
        return self._key_join(*key)

    def check(self) -> bool:
        """Ensure the redis instance is running and reachable.

        Raises:
            RuntimeError: If ``ping()`` returns ``False``.
        """
        result = self.instance.ping()
        if result is False:
            raise RuntimeError(
                "Redis instance {!r} is not running or reachable".format(self.instance)
            )
        return result

    def delete(self, key: t.Union[str, t.Sequence[str]]) -> None:
        """Delete a key and value from the cache."""
        self.instance.delete(self.key(key))

    def expire(
        self,
        key: t.Union[str, t.Sequence[str]],
        *,
        expire_in: t.Optional[t.Union[timedelta, int]] = None,
        expire_at: t.Optional[t.Union[datetime, int]] = None,
    ) -> None:
        """Expire a cache key, either at a time, or in a number of seconds.

        Arguments:
            key: Key (or key parts) to expire.
            expire_in: Relative expiry, as a ``timedelta`` or a number of seconds.
            expire_at: Absolute expiry, passed to ``Redis.expireat``; takes
                precedence over ``expire_in`` when both are given.

        If no at or in time is specified, the key is deleted.
        """
        key = self.key(key)
        if isinstance(expire_at, (datetime, int)):
            self.instance.expireat(key, expire_at)
            return
        if isinstance(expire_in, (timedelta, int)):
            self.instance.expire(key, expire_in)
            return
        self.instance.delete(key)

    def get(
        self,
        key: t.Union[str, t.Sequence[str]],
        *,
        raise_if_none: bool = False,
        value_if_none: t.Any = None,
    ) -> t.Union[None, t.Any]:
        """Get and decode a value from the cache.

        Arguments:
            key: Key (or key parts) to read.
            raise_if_none: Raise instead of returning when the key is missing.
            value_if_none: Value returned when the key is missing.

        Raises:
            StateError: If the key is missing and ``raise_if_none`` is True.
        """
        name = self.key(key)
        value: t.Optional[bytes] = self.instance.get(name)
        if isinstance(value, bytes):
            # NOTE: pickle.loads can execute arbitrary code; this is only safe
            # while the Redis store is reachable by trusted writers alone.
            return pickle.loads(value)  # noqa
        if raise_if_none is True:
            raise StateError("'{key}' ('{name}') does not exist in Redis store", key=key, name=name)
        return value_if_none

    def set(self, key: t.Union[str, t.Sequence[str]], value: t.Any) -> None:
        """Add an object to the cache (pickled)."""
        name = self.key(key)
        self.instance.set(name, pickle.dumps(value))

    @overload
    def get_map(self, key: str, item: str) -> t.Any:
        """Get a single value from a Redis hash map (dict)."""

    @overload
    def get_map(self, key: str, item: None = None) -> t.Optional[t.Dict[str, t.Any]]:
        """Get an entire Redis hash map (dict)."""

    def get_map(self, key: str, item: t.Optional[str] = None) -> t.Any:
        """Get a Redis hash map or hash map value.

        Arguments:
            key: Key (or key parts) of the hash map.
            item: Field to read. When omitted, the whole map is returned.

        Returns:
            With ``item``: the unpickled field value, or ``None`` if absent.
            Without ``item``: a ``dict`` of field name to unpickled value, or
            ``None`` if the map is empty or missing.
        """
        name = self.key(key)
        if isinstance(item, str):
            value = self.instance.hget(name, item)
            if isinstance(value, bytes):
                return pickle.loads(value)  # noqa
            return None
        fields = self.instance.hgetall(name)
        # hgetall returns a dict (not bytes); anything else (e.g. a queued
        # pipeline object) has no readable result yet.
        if not isinstance(fields, dict):
            return None
        result = {
            (field.decode() if isinstance(field, bytes) else field): pickle.loads(raw)  # noqa
            for field, raw in fields.items()
            if isinstance(raw, bytes)
        }
        return result or None

    def set_map_item(self, key: str, item: str, value: t.Any) -> None:
        """Add a value (pickled) to a hash map (dict)."""
        name = self.key(key)
        self.instance.hset(name, item, pickle.dumps(value))

    def pipeline(self) -> "RedisManager":
        """Enter a Redis Pipeline, but expose all the custom interaction methods.

        Use as a context manager. Commands issued through the returned object
        are queued on ``Redis.pipeline()`` and executed when the ``with`` block
        exits normally. If the block raises, the queued commands are discarded
        instead of executed, the error is logged, and the exception propagates.
        Calling ``pipeline()`` on the returned object raises ``AttributeError``.
        """

        class RedisManagerPipeline(RedisManager):
            """Copy of RedisManager, but uses `Redis.pipeline` as the `instance`."""

            parent: "Redis"
            instance: "Pipeline"

            def __init__(
                pipeline_self,  # noqa: N805 Avoid `self` namespace conflict
                *,
                parent: "Redis",
                instance: "Pipeline",
                namespace: str,
            ) -> None:
                pipeline_self.parent = parent
                super().__init__(instance=instance, namespace=namespace)

            def pipeline(pipeline_self, *_: t.Any, **__: t.Any) -> t.NoReturn:  # noqa: N805
                """Ensure pipeline is never called from within pipeline."""
                raise AttributeError("Cannot access pipeline from pipeline")

            def __enter__(
                pipeline_self: "RedisManagerPipeline",  # noqa: N805 Avoid `self` namespace conflict
            ) -> "RedisManagerPipeline":
                return pipeline_self

            def __exit__(
                pipeline_self: "RedisManagerPipeline",  # noqa: N805 Avoid `self` namespace conflict
                exc_type: t.Optional[t.Type[BaseException]] = None,
                exc_value: t.Optional[BaseException] = None,
                _: t.Optional[TracebackType] = None,
            ) -> None:
                if exc_type is None:
                    pipeline_self.instance.execute()
                    return
                log.error(
                    "Error in pipeline {!r} from parent instance {!r}:\n{!s}",
                    pipeline_self,
                    pipeline_self.parent,
                    exc_value,
                )
                # Drop the partially-built batch rather than committing it.
                pipeline_self.instance.reset()

        return RedisManagerPipeline(
            parent=self.instance,
            instance=self.instance.pipeline(),
            namespace=self.namespace,
        )