From 208d70d36fc11547212c1058b992530ee77c65e2 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 4 Jan 2021 21:31:40 -0800
Subject: [PATCH] fix bug with default empty memories for txl

---
 setup.py                         | 2 +-
 x_transformers/x_transformers.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 617c4fa5..178fc54a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.7.1',
+  version = '0.7.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index c4144dbb..5d91ccba 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -437,7 +437,7 @@ def __init__(
             layer_types = default_block * depth

         self.layer_types = layer_types
-        self.default_mems = ([None] * len(list(filter(equals('a'), layer_types))))
+        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         for layer_type in self.layer_types:
             if layer_type == 'a':
@@ -472,7 +472,7 @@ def forward(
         prev_attn = None
         prev_cross_attn = None

-        mems = mems.copy() if exists(mems) else self.default_mems
+        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers

        for ind, (layer_type, (norm, block)) in enumerate(zip(self.layer_types, self.layers)):
            is_last = ind == (len(self.layers) - 1)
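
Note on the fix: the removed line built a single `[None] * num_attn_layers` list once
at construction and handed that same object to every forward call in which the caller
did not pass `mems`. Because the `.copy()` guard only fires when `mems` is supplied,
any in-place mutation of that shared default leaked across calls. The patch stores
only the count and builds a fresh list per call. A minimal standalone sketch of the
aliasing bug, assuming the forward pass consumes `mems` in place; the class names and
the `pop(0)` mutation are illustrative stand-ins, not the library's actual code:

    class Buggy:
        def __init__(self, num_attn_layers):
            # one list object, created at init and shared by every call
            self.default_mems = [None] * num_attn_layers

        def forward(self, mems=None):
            # .copy() only protects caller-supplied mems; the default is aliased
            mems = mems.copy() if mems is not None else self.default_mems
            mems.pop(0)  # stand-in for consuming a layer's memory in place
            return mems

    class Fixed:
        def __init__(self, num_attn_layers):
            # store only the count (as the patch does) ...
            self.num_attn_layers = num_attn_layers

        def forward(self, mems=None):
            # ... and build a fresh throwaway list on every call
            mems = mems.copy() if mems is not None else [None] * self.num_attn_layers
            mems.pop(0)  # mutation no longer touches shared state
            return mems

    buggy, fixed = Buggy(3), Fixed(3)
    buggy.forward()
    assert len(buggy.default_mems) == 2   # the shared default silently shrank
    assert len(fixed.forward()) == 2      # fresh list each call: length 3 before the pop
    assert len(fixed.forward()) == 2      # and again on the next call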