🏷️sec_key_value
KVStore is a place for sharing data. Think of it as a single object shared across different devices (GPUs and computers). Each device can push data into it and pull data out of it.
Let us consider a simple example: initializing an (int, NDArray) key-value pair in the store and then pulling the value out:
#@tab mxnet
from mxnet import np, npx, kv
npx.set_np()
#@tab mxnet
np.ones((2,3))
#@tab mxnet
help(kv)
#@tab mxnet
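kv = kv.create('local')  # Create a local KVStore (this rebinds the name `kv` from the module to the store).
shape = (2, 3)
kv.init(3, np.ones(shape) * 2)  # Store the value 2 under key 3.
a = np.zeros(shape)  # Preallocated output buffer.
kv.pull(3, out=a)
print(a)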
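The following pure-NumPy sketch makes the init/pull semantics concrete without requiring an MXNet installation. `ToyKVStore` is a hypothetical class written only for illustration. It is not MXNet's implementation and models a single process with no devices or networking.

```python
import numpy as np


class ToyKVStore:
    """Hypothetical single-process stand-in for a key-value store."""

    def __init__(self):
        self._store = {}  # key -> privately owned array

    def init(self, key, value):
        # Copy so that later changes to the caller's array do not leak in.
        self._store[key] = np.array(value, copy=True)

    def pull(self, key, out):
        # Write the stored value into the caller's buffer in place.
        out[...] = self._store[key]


store = ToyKVStore()
shape = (2, 3)
store.init(3, np.ones(shape) * 2)
a = np.zeros(shape)
store.pull(3, out=a)
print(a)  # every entry is 2.0
```

Because `pull` writes into `out` in place, the destination must already exist with a matching shape. This is why the MXNet cell preallocates `a` with `np.zeros(shape)`.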
For any key that has been initialized, you can push a new value with the same shape to the key:
#@tab mxnet
kv.push(3, np.ones(shape) * 8)  # Replace the value stored under key 3.
kv.pull(3, out=a)  # Pull out the value.
print(a.asnumpy())
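This sketch extends the hypothetical `ToyKVStore` (again, not MXNet's code) with a `push` that follows the default ASSIGN behavior. It also enforces the two preconditions stated above: the key must be initialized first, and the pushed value must have the same shape. The error types used here are assumptions of the toy; MXNet reports these problems in its own way.

```python
import numpy as np


class ToyKVStore:
    """Hypothetical store whose push overwrites the value (ASSIGN)."""

    def __init__(self):
        self._store = {}

    def init(self, key, value):
        self._store[key] = np.array(value, copy=True)

    def push(self, key, value):
        stored = self._store[key]  # raises KeyError if key was never initialized
        if value.shape != stored.shape:
            raise ValueError(f"shape {value.shape} != stored shape {stored.shape}")
        stored[...] = value  # ASSIGN: replace the old contents

    def pull(self, key, out):
        out[...] = self._store[key]


shape = (2, 3)
store = ToyKVStore()
store.init(3, np.ones(shape) * 2)
store.push(3, np.ones(shape) * 8)
a = np.zeros(shape)
store.pull(3, out=a)
print(a)  # every entry is 8.0
```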
The data being pushed can be stored on any device. Furthermore, you can push multiple values into the same key. In that case, KVStore first sums all of these values and then pushes the aggregated value. Here we only demonstrate pushing a list of values on CPU contexts. Note that summation only happens if the value list has more than one element.
#@tab mxnet
devices = [npx.cpu(i) for i in range(4)]
b = [np.ones(shape, ctx=device) for device in devices]
kv.push(3, b)  # The four values are summed before being stored.
kv.pull(3, out=a)
print(a)
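The aggregation rule can be sketched as follows. This uses the hypothetical `ToyKVStore` in plain NumPy, with ordinary arrays standing in for values that would live on four devices. A list push is summed first, and the sum is then stored with the default ASSIGN rule.

```python
import numpy as np


class ToyKVStore:
    """Hypothetical store: a list push is summed, then assigned."""

    def __init__(self):
        self._store = {}

    def init(self, key, value):
        self._store[key] = np.array(value, copy=True)

    def push(self, key, value):
        values = value if isinstance(value, list) else [value]
        total = values[0].copy()  # copy so the caller's arrays stay untouched
        for v in values[1:]:  # summation only happens for lists longer than one
            total += v
        self._store[key][...] = total  # default ASSIGN behavior

    def pull(self, key, out):
        out[...] = self._store[key]


shape = (2, 3)
store = ToyKVStore()
store.init(3, np.zeros(shape))
b = [np.ones(shape) for _ in range(4)]  # stand-ins for four devices' values
store.push(3, b)
a = np.zeros(shape)
store.pull(3, out=a)
print(a)  # every entry is 4.0 = 1 + 1 + 1 + 1
```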
For each push, KVStore combines the pushed value with the stored value using an updater. The default updater is ASSIGN, which overwrites the stored value with the pushed one. You can replace the default updater to control how data is merged:
#@tab mxnet
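def update(key, value, stored):
    # Custom updater, invoked on each push with the (aggregated) pushed value.
    print(f'update on key: {key}')
    stored += value * 2  # Modify the stored array in place.
kv._set_updater(update)  # `_set_updater` is a private (underscore) method.
kv.pull(3, out=a)  # Updaters only run on push, so the value is unchanged here.
print(a)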
#@tab mxnet
kv.push(3, np.ones(shape))
kv.pull(3, out=a)
print(a)
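The arithmetic of the custom updater can be checked with the hypothetical `ToyKVStore` below. `set_updater` is an illustrative name for this toy and is not an MXNet method. As the text describes, the toy assumes that a list push is summed before the updater runs, and the updater is called only on push, never on pull. Starting from the value 4 left by the earlier list push, one push of ones gives 4 + 2 × 1 = 6.

```python
import numpy as np


class ToyKVStore:
    """Hypothetical store with a pluggable updater (default: ASSIGN)."""

    def __init__(self):
        self._store = {}
        self._updater = self._assign

    @staticmethod
    def _assign(key, value, stored):
        stored[...] = value

    def set_updater(self, updater):  # hypothetical name, not MXNet's method
        self._updater = updater

    def init(self, key, value):
        self._store[key] = np.array(value, copy=True)

    def push(self, key, value):
        values = value if isinstance(value, list) else [value]
        total = values[0].copy()
        for v in values[1:]:
            total += v
        # The updater sees the aggregate once per push, not each device value.
        self._updater(key, total, self._store[key])

    def pull(self, key, out):
        out[...] = self._store[key]


calls = []


def update(key, value, stored):
    calls.append(key)  # record which key triggered the updater
    stored += value * 2  # same rule as `update` in the MXNet cell


shape = (2, 3)
store = ToyKVStore()
store.init(3, np.full(shape, 4.0))  # the value left by the earlier list push
store.set_updater(update)
a = np.zeros(shape)
store.pull(3, out=a)  # pull does not call the updater: still 4.0
before = a.copy()
store.push(3, np.ones(shape))  # 4 + 2 * 1 = 6
store.pull(3, out=a)
print(a)
```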
You have already seen how to pull a single key-value pair. Just as you can push a list of values into one key, you can pull the value onto several devices with a single call:
#@tab mxnet
b = [np.ones(shape, ctx=device) for device in devices]
kv.pull(3, out=b)
print(b[1])
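Pulling onto several devices is a broadcast: every output buffer receives its own copy of the stored value. The hypothetical `ToyKVStore` sketch below models this with plain NumPy arrays standing in for per-device buffers. It is an illustration, not MXNet's implementation.

```python
import numpy as np


class ToyKVStore:
    """Hypothetical store whose pull can fill several buffers at once."""

    def __init__(self):
        self._store = {}

    def init(self, key, value):
        self._store[key] = np.array(value, copy=True)

    def pull(self, key, out):
        outs = out if isinstance(out, list) else [out]
        for buf in outs:  # broadcast: every buffer gets its own copy
            buf[...] = self._store[key]


shape = (2, 3)
store = ToyKVStore()
store.init(3, np.full(shape, 6.0))
b = [np.ones(shape) for _ in range(4)]  # stand-ins for four devices' buffers
store.pull(3, out=b)
print(b[1])  # every entry is 6.0
```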
All operations introduced so far involve a single key. KVStore also provides an interface for a list of key-value pairs.
For a single device:
#@tab mxnet
keys = [5, 7, 9]
kv.init(keys, [np.ones(shape)]*len(keys))
kv.push(keys, [np.ones(shape)]*len(keys))
b = [np.zeros(shape) for _ in keys]  # One independent output buffer per key.
kv.pull(keys, out=b)
print(b[1])
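When building output buffers for a list of keys, note that Python's `[x] * n` repeats a *reference* to the same object. It does not create `n` copies. With `[np.zeros(shape)] * len(keys)`, every key would be pulled into the same memory, so the last pull would silently overwrite the others. The plain NumPy demonstration below (no KVStore involved) shows the difference between this pattern and a list comprehension.

```python
import numpy as np

shape = (2, 3)
n = 3

aliased = [np.zeros(shape)] * n                  # n references to ONE array
separate = [np.zeros(shape) for _ in range(n)]   # n independent arrays

aliased[0] += 1   # in-place update through the first reference
separate[0] += 1

print(aliased[2][0, 0])   # 1.0: the "other" entries changed too
print(separate[2][0, 0])  # 0.0: unaffected
```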
For multiple devices:
#@tab mxnet
# One independent list of per-device buffers for each key.
b = [[np.ones(shape, ctx=device) for device in devices] for _ in keys]
kv.push(keys, b)
kv.pull(keys, out=b)
print(b[1][1])
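The expected values in the multi-key, multi-device case can be traced with the hypothetical `ToyKVStore` below. It is plain NumPy, not MXNet, and hard-codes the custom updater `stored += 2 * value`. After the single-device cell, each key holds 1 + 2 × 1 = 3. Pushing four ones per key then aggregates to 4, so each key ends at 3 + 2 × 4 = 11.

```python
import numpy as np


class ToyKVStore:
    """Hypothetical store accepting lists of keys; updater: stored += 2 * value."""

    def __init__(self):
        self._store = {}

    def init(self, keys, values):
        for k, v in zip(keys, values):
            self._store[k] = np.array(v, copy=True)

    def push(self, keys, values):
        # values[i] is the list of per-device arrays for keys[i].
        for k, per_device in zip(keys, values):
            total = per_device[0].copy()
            for v in per_device[1:]:
                total += v
            self._store[k] += total * 2  # mirrors the custom updater

    def pull(self, keys, out):
        # out[i] is the list of per-device buffers for keys[i].
        for k, per_device in zip(keys, out):
            for buf in per_device:
                buf[...] = self._store[k]


shape = (2, 3)
keys = [5, 7, 9]
n_devices = 4
store = ToyKVStore()
store.init(keys, [np.full(shape, 3.0) for _ in keys])  # value after the single-device push
b = [[np.ones(shape) for _ in range(n_devices)] for _ in keys]
store.push(keys, b)  # each key: 3 + 2 * (1 + 1 + 1 + 1) = 11
store.pull(keys, out=b)
print(b[1][1])
```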