From 4f8a4f334c59588223b6f1f24b707d7e8d5fe08c Mon Sep 17 00:00:00 2001 From: Vincent Nguyen Date: Thu, 7 Dec 2023 19:40:08 +0100 Subject: [PATCH] fix onmt converter (#1581) --- python/ctranslate2/converters/opennmt_py.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/ctranslate2/converters/opennmt_py.py b/python/ctranslate2/converters/opennmt_py.py index cbc40d7ea..ccd9ad417 100644 --- a/python/ctranslate2/converters/opennmt_py.py +++ b/python/ctranslate2/converters/opennmt_py.py @@ -104,9 +104,10 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd activation_fn = getattr(opt, "pos_ffn_activation_fn", "relu") num_heads = getattr(opt, "heads", 8) num_kv = getattr(opt, "num_kv", 0) - if num_kv == num_heads: + if num_kv == num_heads or num_kv == 0: num_kv = None rotary_dim = 0 if with_rotary else None + rotary_interleave = getattr(opt, "rotary_interleave", True) ffn_glu = activation_fn == "silu" sliding_window = getattr(opt, "sliding_window", 0) @@ -119,7 +120,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd alibi=with_alibi, rms_norm=opt.layer_norm == "rms", rotary_dim=rotary_dim, - rotary_interleave=True, + rotary_interleave=rotary_interleave, multi_query_attention=getattr(opt, "multiquery", False), num_heads_kv=num_kv, sliding_window=sliding_window, @@ -329,7 +330,7 @@ def set_linear(spec, variables, scope): spec.weight = _get_variable(variables, "%s.weight" % scope) bias = variables.get("%s.bias" % scope) if bias is not None: - spec.bias = bias.numpy() + spec.bias = bias def set_embeddings(spec, variables, scope): @@ -341,7 +342,7 @@ def set_position_encodings(spec, variables, scope): def _get_variable(variables, name): - return variables[name].numpy() + return variables[name] def main():