Skip to content

Commit

Permalink
Added support for HMM/fertility priors
Browse files Browse the repository at this point in the history
  • Loading branch information
robertostling committed Oct 2, 2018
1 parent 309c565 commit 746f30a
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 126 deletions.
115 changes: 87 additions & 28 deletions align.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main():
help='Filename to write reverse direction alignments to')
parser.add_argument(
'-p', '--priors', dest='priors_filename', type=str, metavar='filename',
help='Filename of lexical priors')
help='File to read priors from')

args = parser.parse_args()

Expand Down Expand Up @@ -101,25 +101,45 @@ def main():
args.priors_filename,
file=sys.stderr, flush=True)

priors_list = []
priors_list = [] # list of (srcword, trgword, alpha)
ferf_priors = [] # list of (wordform, fert, alpha)
ferr_priors = [] # list of (wordform, fert, alpha)
hmmf_priors = {} # dict of jump: alpha
hmmr_priors = {} # dict of jump: alpha
with open(args.priors_filename, 'r', encoding='utf-8') as f:
# Five line types are valid (fields are tab-separated):
#
# LEX srcword trgword alpha | lexical prior
# HMMF jump alpha | target-side HMM prior
# HMMR jump alpha | source-side HMM prior
# FERF srcword fert alpha | source-side fertility prior
# FERR trgword fert alpha | target-side fertility prior
for i, line in enumerate(f):
fields = line.rstrip('\n').split('\t')
if len(fields) != 3:
print('ERROR: priors file %s line %d contains %d '
'tab-separated fields but should have 3' % (
args.priors_filename, i+1, len(fields)),
file=sys.stderr, flush=True)
sys.exit(1)
try:
alpha = float(fields[2])
alpha = float(fields[-1])
except ValueError:
print('ERROR: priors file %s line %d contains alpha '
'value of "%s" which is not numeric' % (
args.priors_filename, i+1, fields[2]),
file=sys.stderr, flush=True)
sys.exit(1)
priors_list.append((fields[0], fields[1], alpha))

if fields[0] == 'LEX' and len(fields) == 4:
priors_list.append((fields[1], fields[2], alpha))
elif fields[0] == 'HMMF' and len(fields) == 3:
hmmf_priors[int(fields[1])] = alpha
elif fields[0] == 'HMMR' and len(fields) == 3:
hmmr_priors[int(fields[1])] = alpha
elif fields[0] == 'FERF' and len(fields) == 4:
ferf_priors.append((fields[1], int(fields[2]), alpha))
elif fields[0] == 'FERR' and len(fields) == 4:
ferr_priors.append((fields[1], int(fields[2]), alpha))
else:
print('ERROR: priors file %s line %d is invalid ' % (
args.priors_filename, i+1),
file=sys.stderr, flush=True)
sys.exit(1)

if args.joint_filename:
if args.verbose:
Expand Down Expand Up @@ -164,47 +184,86 @@ def main():
trg_sents = None
trg_text = None

def get_src_index(src_word):
    """Look up the shifted vocabulary index of a source-side word form.

    The word is lowercased, then cut down to its first
    ``args.source_prefix_len`` characters and/or its last
    ``args.source_suffix_len`` characters (each step is skipped when the
    corresponding setting is 0) before being looked up in ``src_index``.

    Returns the value stored in ``src_index`` plus one, or None when the
    normalized form is not present. The caller reserves index 0 for the
    '<NULL>' token, which is why found indices are shifted up by one.
    """
    key = src_word.lower()
    # Prefix truncation is applied first; the suffix is then taken from
    # the already-truncated form.
    if args.source_prefix_len != 0:
        key = key[:args.source_prefix_len]
    if args.source_suffix_len != 0:
        key = key[-args.source_suffix_len:]
    position = src_index.get(key)
    return None if position is None else position + 1

def get_trg_index(trg_word):
    """Look up the shifted vocabulary index of a target-side word form.

    The word is lowercased, then cut down to its first
    ``args.target_prefix_len`` characters and/or its last
    ``args.target_suffix_len`` characters (each step is skipped when the
    corresponding setting is 0) before being looked up in ``trg_index``.

    Returns the value stored in ``trg_index`` plus one, or None when the
    normalized form is not present. The caller reserves index 0 for the
    '<NULL>' token, which is why found indices are shifted up by one.
    """
    key = trg_word.lower()
    # Prefix truncation is applied first; the suffix is then taken from
    # the already-truncated form.
    if args.target_prefix_len != 0:
        key = key[:args.target_prefix_len]
    if args.target_suffix_len != 0:
        key = key[-args.target_suffix_len:]
    position = trg_index.get(key)
    return None if position is None else position + 1


if args.priors_filename:
priors_indexed = {}
for src_word, trg_word, alpha in priors_list:
if src_word == '<NULL>':
e = 0
else:
src_word = src_word.lower()
if args.source_prefix_len != 0:
src_word = src_word[:args.source_prefix_len]
if args.source_suffix_len != 0:
src_word = src_word[-args.source_suffix_len:]
e = src_index.get(src_word)
if e is not None:
e = e + 1
e = get_src_index(src_word)

if trg_word == '<NULL>':
f = 0
else:
trg_word = trg_word.lower()
if args.target_prefix_len != 0:
trg_word = trg_word[:args.target_prefix_len]
if args.target_suffix_len != 0:
trg_word = trg_word[-args.target_suffix_len:]
f = trg_index.get(trg_word)
if f is not None:
f = f + 1
f = get_trg_index(trg_word)

if (e is not None) and (f is not None):
priors_indexed[(e,f)] = priors_indexed.get((e,f), 0.0) \
+ alpha

ferf_indexed = {}
for src_word, fert, alpha in ferf_priors:
e = get_src_index(src_word)
if e is not None:
ferf_indexed[(e, fert)] = \
ferf_indexed.get((e, fert), 0.0) + alpha

ferr_indexed = {}
for trg_word, fert, alpha in ferr_priors:
f = get_trg_index(trg_word)
if f is not None:
ferr_indexed[(f, fert)] = \
ferr_indexed.get((f, fert), 0.0) + alpha

if args.verbose:
print('%d (of %d) pairs of lexical priors used' % (
len(priors_indexed), len(priors_list)),
file=sys.stderr)
priorsf = NamedTemporaryFile('w', encoding='utf-8')
print('%d %d %d' % (
len(src_index)+1, len(trg_index)+1, len(priors_indexed)),
print('%d %d %d %d %d %d %d' % (
len(src_index)+1, len(trg_index)+1, len(priors_indexed),
len(hmmf_priors), len(hmmr_priors),
len(ferf_indexed), len(ferr_indexed)),
file=priorsf)

for (e, f), alpha in sorted(priors_indexed.items()):
print('%d %d %g' % (e, f, alpha), file=priorsf)

for jump, alpha in sorted(hmmf_priors.items()):
print('%d %g' % (jump, alpha), file=priorsf)

for jump, alpha in sorted(hmmr_priors.items()):
print('%d %g' % (jump, alpha), file=priorsf)

for (e, fert), alpha in sorted(ferf_indexed.items()):
print('%d %d %g' % (e, fert, alpha), file=priorsf)

for (f, fert), alpha in sorted(ferr_indexed.items()):
print('%d %d %g' % (f, fert, alpha), file=priorsf)

priorsf.flush()

trg_index = None
Expand Down
Loading

0 comments on commit 746f30a

Please sign in to comment.