Skip to content

Commit ebf4528

Browse files
author
Amit Lasry
committed
store redis data encoded and decode on fetching when client requires
1 parent 05d391e commit ebf4528

File tree

1 file changed

+61
-35
lines changed

1 file changed

+61
-35
lines changed

mockredis/client.py

Lines changed: 61 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division
22
from collections import defaultdict
3+
from collections.abc import ValuesView, KeysView, ItemsView
34
from copy import deepcopy
45
from itertools import chain
56
from datetime import datetime, timedelta
@@ -175,7 +176,7 @@ def keys(self, pattern='*'):
175176
for key in self.redis.keys():
176177
decoded_key = key if isinstance(key, unicode) else key.decode('utf-8')
177178
if regex.match(decoded_key):
178-
keys.append(key)
179+
keys.append(decoded_key)
179180

180181
return keys
181182

@@ -301,11 +302,32 @@ def _rename(self, old_key, new_key, nx=False):
301302
def dbsize(self):
302303
return len(self.redis.keys())
303304

304-
# String Functions #
305+
def _decode(self, value):
306+
if value is None:
307+
return None
308+
309+
if self.decode_responses:
310+
if isinstance(value, (list, tuple, set)):
311+
value = type(value)(self._decode(v) for v in value)
312+
# dict.keys()
313+
elif isinstance(value, KeysView):
314+
value = set(self._decode(v) for v in value)
315+
# dict.values()
316+
elif isinstance(value, ValuesView):
317+
value = list(self._decode(v) for v in value)
318+
# dict.items()
319+
elif isinstance(value, ItemsView):
320+
value = list((self._decode(k),self._decode(v)) for k,v in value)
321+
elif isinstance(value, dict):
322+
value = type(value)((self._decode(k), self._decode(v)) for k,v in value.items())
323+
elif isinstance(value, (newbytes, bytes)):
324+
value = value.decode('utf-8', 'strict')
325+
326+
return value
305327

306328
def get(self, key):
307329
key = self._encode(key)
308-
return self.redis.get(key)
330+
return self._decode(self.redis.get(key))
309331

310332
def __getitem__(self, name):
311333
"""
@@ -518,12 +540,12 @@ def hget(self, hashkey, attribute):
518540
"""Emulate hget."""
519541

520542
redis_hash = self._get_hash(hashkey, 'HGET')
521-
return redis_hash.get(self._encode(attribute))
543+
return self._decode(redis_hash.get(self._encode(attribute)))
522544

523545
def hgetall(self, hashkey):
524546
"""Emulate hgetall."""
525547

526-
redis_hash = self._get_hash(hashkey, 'HGETALL')
548+
redis_hash = self._get_hash(hashkey, 'HGETALL', decode=True)
527549
return dict(redis_hash)
528550

529551
def hdel(self, hashkey, *keys):
@@ -559,7 +581,7 @@ def hmget(self, hashkey, keys, *args):
559581

560582
redis_hash = self._get_hash(hashkey, 'HMGET')
561583
attributes = self._list_or_args(keys, args)
562-
return [redis_hash.get(self._encode(attribute)) for attribute in attributes]
584+
return [self._decode(redis_hash.get(self._encode(attribute))) for attribute in attributes]
563585

564586
def hset(self, hashkey, attribute, value):
565587
"""Emulate hset."""
@@ -595,29 +617,29 @@ def _hincrby(self, hashkey, attribute, command, type_, increment):
595617
"""Shared hincrby and hincrbyfloat routine"""
596618
redis_hash = self._get_hash(hashkey, command, create=True)
597619
attribute = self._encode(attribute)
598-
previous_value = type_(redis_hash.get(attribute, '0'))
620+
previous_value = type_(self._decode(redis_hash.get(attribute, '0')))
599621
redis_hash[attribute] = self._encode(previous_value + increment)
600622
return type_(redis_hash[attribute])
601623

602624
def hkeys(self, hashkey):
603625
"""Emulate hkeys."""
604626

605627
redis_hash = self._get_hash(hashkey, 'HKEYS')
606-
return redis_hash.keys()
628+
return self._decode(redis_hash.keys())
607629

608630
def hvals(self, hashkey):
609631
"""Emulate hvals."""
610632

611633
redis_hash = self._get_hash(hashkey, 'HVALS')
612-
return redis_hash.values()
634+
return self._decode(redis_hash.values())
613635

614636
# List Functions #
615637

616638
def lrange(self, key, start, stop):
617639
"""Emulate lrange."""
618640
redis_list = self._get_list(key, 'LRANGE')
619641
start, stop = self._translate_range(len(redis_list), start, stop)
620-
return redis_list[start:stop + 1]
642+
return self._decode(redis_list[start:stop + 1])
621643

622644
def lindex(self, key, index):
623645
"""Emulate lindex."""
@@ -628,7 +650,7 @@ def lindex(self, key, index):
628650
return None
629651

630652
try:
631-
return redis_list[index]
653+
return self._decode(redis_list[index])
632654
except (IndexError):
633655
# Redis returns nil if the index doesn't exist
634656
return None
@@ -668,7 +690,7 @@ def _pop_first_available(self, pop_func, keys):
668690
for key in keys:
669691
val = pop_func(key)
670692
if val:
671-
return self._encode(key), val
693+
return self._decode(key), self._decode(val)
672694
return None, None
673695

674696
def blpop(self, keys, timeout=0):
@@ -715,7 +737,7 @@ def rpop(self, key):
715737
value = redis_list.pop()
716738
if len(redis_list) == 0:
717739
self.delete(key)
718-
return value
740+
return self._decode(value)
719741
except (IndexError):
720742
# Redis returns nil if popping from an empty list
721743
return None
@@ -1025,7 +1047,7 @@ def sismember(self, name, value):
10251047

10261048
def smembers(self, name):
10271049
"""Emulate smembers."""
1028-
return self._get_set(name, 'SMEMBERS').copy()
1050+
return self._get_set(name, 'SMEMBERS', decode=True).copy()
10291051

10301052
def smove(self, src, dst, value):
10311053
"""Emulate smove."""
@@ -1050,11 +1072,11 @@ def spop(self, name):
10501072
redis_set.remove(member)
10511073
if len(redis_set) == 0:
10521074
self.delete(name)
1053-
return member
1075+
return self._decode(member)
10541076

10551077
def srandmember(self, name, number=None):
10561078
"""Emulate srandmember."""
1057-
redis_set = self._get_set(name, 'SRANDMEMBER')
1079+
redis_set = self._get_set(name, 'SRANDMEMBER',decode=True)
10581080
if not redis_set:
10591081
return None if number is None else []
10601082
if number is None:
@@ -1149,7 +1171,7 @@ def zrange(self, name, start, end, desc=False, withscores=False,
11491171

11501172
start, end = self._translate_range(len(zset), start, end)
11511173

1152-
func = self._range_func(withscores, score_cast_func)
1174+
func = self._range_func(withscores, score_cast_func, decode_value_func=self._decode)
11531175
return [func(item) for item in zset.range(start, end, desc)]
11541176

11551177
def zrangebyscore(self, name, min, max, start=None, num=None,
@@ -1162,7 +1184,7 @@ def zrangebyscore(self, name, min, max, start=None, num=None,
11621184
if not zset:
11631185
return []
11641186

1165-
func = self._range_func(withscores, score_cast_func)
1187+
func = self._range_func(withscores, score_cast_func, decode_value_func=self._decode)
11661188
include_start, min = self._score_inclusive(min)
11671189
include_end, max = self._score_inclusive(max)
11681190
scorerange = zset.scorerange(min, max, start_inclusive=include_start, end_inclusive=include_end) # noqa
@@ -1234,7 +1256,7 @@ def zrevrangebyscore(self, name, max, min, start=None, num=None,
12341256
if not zset:
12351257
return []
12361258

1237-
func = self._range_func(withscores, score_cast_func)
1259+
func = self._range_func(withscores, score_cast_func, decode_value_func=self._decode)
12381260
include_start, min = self._score_inclusive(min)
12391261
include_end, max = self._score_inclusive(max)
12401262

@@ -1420,40 +1442,46 @@ def publish(self, channel, message):
14201442

14211443
# Internal #
14221444

1423-
def _get_list(self, key, operation, create=False):
1445+
def _get_list(self, key, operation, create=False, decode=False):
14241446
"""
14251447
Get (and maybe create) a list by name.
14261448
"""
1427-
return self._get_by_type(key, operation, create, b'list', [])
1449+
return self._get_by_type(key, operation, create, b'list', [], decode=decode)
14281450

1429-
def _get_set(self, key, operation, create=False):
1451+
def _get_set(self, key, operation, create=False, decode=False):
14301452
"""
14311453
Get (and maybe create) a set by name.
14321454
"""
1433-
return self._get_by_type(key, operation, create, b'set', set())
1455+
return self._get_by_type(key, operation, create, b'set', set(), decode=decode)
14341456

1435-
def _get_hash(self, name, operation, create=False):
1457+
def _get_hash(self, name, operation, create=False, decode=False):
14361458
"""
14371459
Get (and maybe create) a hash by name.
14381460
"""
1439-
return self._get_by_type(name, operation, create, b'hash', {})
1461+
return self._get_by_type(name, operation, create, b'hash', {}, decode=decode)
14401462

1441-
def _get_zset(self, name, operation, create=False):
1463+
def _get_zset(self, name, operation, create=False, decode=False):
14421464
"""
14431465
Get (and maybe create) a sorted set by name.
14441466
"""
1445-
return self._get_by_type(name, operation, create, b'zset', SortedSet(), return_default=False) # noqa
1467+
return self._get_by_type(name, operation, create, b'zset', SortedSet(), return_default=False, decode=decode) # noqa
14461468

1447-
def _get_by_type(self, key, operation, create, type_, default, return_default=True):
1469+
def _get_by_type(self, key, operation, create, type_, default, return_default=True, decode=False):
14481470
"""
14491471
Get (and maybe create) a redis data structure by name and type.
14501472
"""
14511473
key = self._encode(key)
14521474
if self.type(key) in [type_, b'none']:
14531475
if create:
1454-
return self.redis.setdefault(key, default)
1476+
val = self.redis.setdefault(key, default)
1477+
if decode:
1478+
val = self._decode(val)
1479+
return val
14551480
else:
1456-
return self.redis.get(key, default if return_default else None)
1481+
val = self.redis.get(key, default if return_default else None)
1482+
if decode:
1483+
val = self._decode(val)
1484+
return val
14571485

14581486
raise TypeError("{} requires a {}".format(operation, type_))
14591487

@@ -1479,14 +1507,14 @@ def _translate_limit(self, len_, start, num):
14791507
return 0, 0
14801508
return min(start, len_), num
14811509

1482-
def _range_func(self, withscores, score_cast_func):
1510+
def _range_func(self, withscores, score_cast_func, decode_value_func=lambda x: x):
14831511
"""
14841512
Return a suitable function from (score, member)
14851513
"""
14861514
if withscores:
1487-
return lambda score_member: (score_member[1], score_cast_func(self._encode(score_member[0]))) # noqa
1515+
return lambda score_member: (decode_value_func(score_member[1]), score_cast_func(self._encode(score_member[0]))) # noqa
14881516
else:
1489-
return lambda score_member: score_member[1]
1517+
return lambda score_member: decode_value_func(score_member[1])
14901518

14911519
def _aggregate_func(self, aggregate):
14921520
"""
@@ -1545,8 +1573,6 @@ def _encode(self, value):
15451573
else:
15461574
value = value.encode('utf-8', 'strict')
15471575

1548-
if self.decode_responses:
1549-
return value.decode('utf-8', 'strict')
15501576
return value
15511577

15521578
def _log(self, level, msg):

0 commit comments

Comments
 (0)