Standalone Product Key Memory module in PyTorch - for augmenting Transformer models

lucidrains/product-key-memory

Product Key Memory

Standalone Product Key Memory module for augmenting Transformer models

Install

$ pip install product-key-memory

Usage

Replace the feedforward layers in a Transformer with the following module:

import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 4,
    dim_head = 256,       # keep at 256 for best results
    num_keys = 256,       # number of subkeys; the number of values will be num_keys ** 2
    topk = 32,            # number of top-scoring subkeys to select
    use_evonorm = True    # experimental: PKM usually needs decent batch sizes with batchnorm to work well; this uses EvoNorm-S0 for batch-independent normalization
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()
values = pkm(x, input_mask = mask) # (1, 1024, 512)
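
The `num_keys` and `topk` comments above rely on the product-key trick from the cited paper: a full key is a pair of subkeys, so `num_keys` subkeys per half address `num_keys ** 2` values, and the best pairs can be found without scoring every pair. The following is a minimal pure-Python sketch of that search, not this library's implementation. The function names and the flat index convention `i * num_keys + j` are assumptions made for illustration.

```python
import heapq
import random
from itertools import product


def product_key_topk(scores_a, scores_b, topk):
    """Top-k (score, value_index) pairs over all subkey pairs, found cheaply.

    scores_a / scores_b hold the query's score against each subkey in the
    first / second half. A full key's score is the sum of its two halves.
    """
    num_keys = len(scores_a)
    # Step 1: top-k subkeys within each half (work grows with num_keys).
    top_a = heapq.nlargest(topk, range(num_keys), key=scores_a.__getitem__)
    top_b = heapq.nlargest(topk, range(num_keys), key=scores_b.__getitem__)
    # Step 2: score only the topk * topk candidate pairs.
    # The flat index i * num_keys + j is an illustrative convention.
    candidates = [
        (scores_a[i] + scores_b[j], i * num_keys + j)
        for i, j in product(top_a, top_b)
    ]
    # Step 3: keep the best topk candidates.
    return heapq.nlargest(topk, candidates)


def brute_force_topk(scores_a, scores_b, topk):
    """Reference search that scores all num_keys ** 2 pairs."""
    num_keys = len(scores_a)
    all_pairs = [
        (scores_a[i] + scores_b[j], i * num_keys + j)
        for i in range(num_keys)
        for j in range(num_keys)
    ]
    return heapq.nlargest(topk, all_pairs)


random.seed(0)
num_keys, topk = 256, 32
scores_a = [random.gauss(0, 1) for _ in range(num_keys)]
scores_b = [random.gauss(0, 1) for _ in range(num_keys)]

fast = product_key_topk(scores_a, scores_b, topk)
slow = brute_force_topk(scores_a, scores_b, topk)
print(fast == slow)                      # both searches agree
print(num_keys ** 2, "values searched via", topk * topk, "candidates")
```

With the settings from the usage example, the candidate set has 32 * 32 = 1024 pairs instead of 256 * 256 = 65536, and it still contains the exact top 32, because any pair in the overall top-k must use a subkey from the top-k of each half.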

Learning Rates

To give the value parameters of the product key memory a different learning rate from the rest of the network, use the following helper function.

from torch.optim import Adam
from product_key_memory import fetch_pkm_value_parameters

# `model` is your root module containing one or more PKM layers
# this helper finds all the PKM modules in it and separates their embedding bag (value) weight parameters from the remaining parameters
pkm_parameters, other_parameters = fetch_pkm_value_parameters(model)

optim = Adam([
    {'params': other_parameters},
    {'params': pkm_parameters, 'lr': 1e-2}
], lr=1e-3)

Appreciation

Special thanks go to Aran for encouraging me to look into this, and to Madison May for his educational blog post, which helped me understand this better.

Citations

@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
