Arbitrary priors should now work.
robertostling committed Oct 2, 2018
1 parent 2aba26a commit 94926a3
Showing 4 changed files with 138 additions and 7 deletions.
26 changes: 25 additions & 1 deletion README.md
@@ -49,9 +49,33 @@ details.

## Input data format

The `align.py` interface expects one sentence per line with space-separated
When used with the `-s` and `-t` options for separate source/target files, the
`align.py` interface expects one sentence per line with space-separated
tokens, similar to most word alignment software.
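
For example (hypothetical contents), a source file passed with `-s` might contain:
```
source sentence one
another source sentence
...
```
with the corresponding target file passed with `-t` holding the translations line by line:
```
target sentence one
another target sentence
...
```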

The `-i` option assumes a `fast_align` style joint source/target file of the
format
```
source sentence ||| target sentence
another source sentence ||| another target sentence
...
```

The `--priors` option expects a file of the following structure:
```
sourceword1<TAB>targetword1<TAB>alpha1
sourceword2<TAB>targetword2<TAB>alpha2
sourceword3<TAB>targetword3<TAB>alpha3
...
```
where the `alphaN` values will be added to the Dirichlet prior for generating
`targetwordN` from `sourcewordN` (or vice versa, for reverse-direction
alignments). Note that the source and target words are processed in the same
way as the input text, i.e. lower-cased (always) and optionally stemmed
according to the `--source-prefix-len`, `--source-suffix-len`,
`--target-prefix-len` and `--target-suffix-len` options. In other words, you
should be able to pass a raw word list without any preprocessing.
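
For instance, assuming a joint-format input file `parallel.txt` and a priors
file `word_priors.txt` in the format above (both file names are hypothetical),
an invocation might look like:
```
python3 align.py -i parallel.txt --priors word_priors.txt -r alignments.rev
```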

## Output data format

The alignment output contains the same number of lines as the input files,
84 changes: 81 additions & 3 deletions align.py
@@ -68,7 +68,10 @@ def main():
'-r', '--reverse-links', dest='links_filename_rev', type=str,
metavar='filename',
help='Filename to write reverse direction alignments to')

parser.add_argument(
'-p', '--priors', dest='priors_filename', type=str, metavar='filename',
help='Filename of lexical priors')

args = parser.parse_args()

if not (args.joint_filename or (args.source_filename and
@@ -92,6 +95,32 @@ def main():
file=sys.stderr, flush=True)
sys.exit(1)

if args.priors_filename:
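# Each line of the priors file is expected to contain three tab-separated
# fields: source word, target word, and a numeric alpha value.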
if args.verbose:
print('Reading lexical priors from %s...' %
args.priors_filename,
file=sys.stderr, flush=True)

priors_list = []
with open(args.priors_filename, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
fields = line.rstrip('\n').split('\t')
if len(fields) != 3:
print('ERROR: priors file %s line %d contains %d '
'tab-separated fields but should have 3' % (
args.priors_filename, i+1, len(fields)),
file=sys.stderr, flush=True)
sys.exit(1)
try:
alpha = float(fields[2])
except ValueError:
print('ERROR: priors file %s line %d contains alpha '
'value of "%s" which is not numeric' % (
args.priors_filename, i+1, fields[2]),
file=sys.stderr, flush=True)
sys.exit(1)
priors_list.append((fields[0], fields[1], alpha))

if args.joint_filename:
if args.verbose:
print('Reading source/target sentences from %s...' %
@@ -120,7 +149,6 @@ def main():
f, True, args.source_prefix_len, args.source_suffix_len)
n_src_sents = len(src_sents)
src_voc_size = len(src_index)
src_index = None
srcf = NamedTemporaryFile('wb')
write_text(srcf, tuple(src_sents), src_voc_size)
src_sents = None
@@ -131,11 +159,57 @@ def main():
f, True, args.target_prefix_len, args.target_suffix_len)
trg_voc_size = len(trg_index)
n_trg_sents = len(trg_sents)
trg_index = None
trgf = NamedTemporaryFile('wb')
write_text(trgf, tuple(trg_sents), trg_voc_size)
trg_sents = None
trg_text = None

if args.priors_filename:
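# Map the word pairs from the priors file to vocabulary indices, applying
# the same lower-casing and prefix/suffix truncation as the input text.
# Index 0 is reserved for <NULL>, so real words are offset by +1.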
priors_indexed = {}
for src_word, trg_word, alpha in priors_list:
if src_word == '<NULL>':
e = 0
else:
src_word = src_word.lower()
if args.source_prefix_len != 0:
src_word = src_word[:args.source_prefix_len]
if args.source_suffix_len != 0:
src_word = src_word[-args.source_suffix_len:]
e = src_index.get(src_word)
if e is not None:
e = e + 1

if trg_word == '<NULL>':
f = 0
else:
trg_word = trg_word.lower()
if args.target_prefix_len != 0:
trg_word = trg_word[:args.target_prefix_len]
if args.target_suffix_len != 0:
trg_word = trg_word[-args.target_suffix_len:]
f = trg_index.get(trg_word)
if f is not None:
f = f + 1

if (e is not None) and (f is not None):
priors_indexed[(e,f)] = priors_indexed.get((e,f), 0.0) \
+ alpha

if args.verbose:
print('%d (of %d) pairs of lexical priors used' % (
len(priors_indexed), len(priors_list)),
file=sys.stderr)
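# Write the indexed priors to a temporary file for the C aligner: a header
# line with the source/target vocabulary sizes and the number of entries,
# followed by one "e f alpha" triple per line.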
priorsf = NamedTemporaryFile('w', encoding='utf-8')
print('%d %d %d' % (
len(src_index)+1, len(trg_index)+1, len(priors_indexed)),
file=priorsf)
for (e, f), alpha in sorted(priors_indexed.items()):
print('%d %d %g' % (e, f, alpha), file=priorsf)
priorsf.flush()
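# Keep priorsf open: its name is passed to the aligner as priors_filename
# below, and it is only closed after alignment. The word indexes are no
# longer needed once the priors have been mapped.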

trg_index = None
src_index = None

else:
if args.verbose:
print('Reading source text from %s...' % args.source_filename,
@@ -182,6 +256,8 @@ def main():
links_filename_rev=args.links_filename_rev,
statistics_filename=None,
scores_filename=None,
priors_filename=(None if args.priors_filename is None
else priorsf.name),
model=args.model,
n_iterations=iters,
n_samplers=args.n_samplers,
@@ -192,6 +268,8 @@ def main():

srcf.close()
trgf.close()
if args.priors_filename:
priorsf.close()


if __name__ == '__main__': main()
32 changes: 29 additions & 3 deletions eflomal.c
@@ -712,7 +712,7 @@ void text_alignment_randomize(struct text_alignment *ta, random_state *state) {
}

int text_alignment_load_priors(
struct text_alignment *ta, const char *filename)
struct text_alignment *ta, const char *filename, int reverse)
{
FILE *file = (!strcmp(filename, "-"))? stdin: fopen(filename, "r");
if (file == NULL) {
@@ -754,8 +754,29 @@ int text_alignment_load_priors(
return -1;
}

size_t t;
if (reverse) {
t = source_vocabulary_size;
source_vocabulary_size = target_vocabulary_size;
target_vocabulary_size = t;
}
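/* The priors file must have been generated from the same source and
 * target vocabularies as the texts being aligned. */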

if (source_vocabulary_size != ta->source->vocabulary_size ||
target_vocabulary_size != ta->target->vocabulary_size)
{
fprintf(stderr,
"text_alignment_load_priors(): vocabulary size mismatch in %s, "
"source is %zd (expected %zd) "
"and target is %zd (expected %zd)\n",
filename,
source_vocabulary_size, ta->source->vocabulary_size,
target_vocabulary_size, ta->target->vocabulary_size);
if (file != stdin) fclose(file);
return -1;
}

for (size_t i=0; i<n_lex_priors; i++) {
token e, f;
token e, f, t;
float alpha;
if (fscanf(file, "%"SCNtoken" %"SCNtoken" %f", &e, &f, &alpha) != 3) {
fprintf(stderr,
@@ -764,6 +785,11 @@ int text_alignment_load_priors(
if (file != stdin) fclose(file);
return -1;
}
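/* Prior entries are stored in forward (source, target) order, so swap
 * the token indices when aligning in the reverse direction. */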
if (reverse) {
t = e;
e = f;
f = t;
}
// TODO: fix this properly
map_token_u32_add(ta->source_prior + e, f, *((uint32_t*)&alpha));
ta->source_prior_sum[e] += alpha;
@@ -986,7 +1012,7 @@ static void align(
// TODO: since read-only, could use the pointer from tas[0]
// for everything, but this would require careful
// initialization/destruction
if(text_alignment_load_priors(tas[i], priors_filename)) {
if(text_alignment_load_priors(tas[i], priors_filename, reverse)) {
fprintf(stderr, "Unable to load %s, exiting\n",
priors_filename);
exit(1);
3 changes: 3 additions & 0 deletions python/eflomal/eflomal.pyx
@@ -94,6 +94,7 @@ def align(
str links_filename_rev=None,
str statistics_filename=None,
str scores_filename=None,
str priors_filename=None,
int model=3,
tuple n_iterations=None,
int n_samplers=1,
@@ -111,6 +112,7 @@ def align(
links_filename_rev -- if given, write links here (reverse direction)
statistics_filename -- if given, write alignment statistics here
scores_filename -- if given, write sentence alignment scores here
priors_filename -- if given, read Dirichlet priors from here
model -- alignment model (1 = IBM1, 2 = HMM, 3 = HMM+fertility)
n_iterations -- 3-tuple with number of iterations per model, if this is
not given the numbers will be computed automatically based
@@ -160,6 +162,7 @@ def align(
if links_filename_rev: args.extend(['-r', links_filename_rev])
if statistics_filename: args.extend(['-S', statistics_filename])
if scores_filename: args.extend(['-x', scores_filename])
if priors_filename: args.extend(['-p', priors_filename])
if not quiet: sys.stderr.write(' '.join(args) + '\n')
if use_gdb: args = ['gdb', '-ex=run', '--args'] + args
subprocess.call(args)