-
Notifications
You must be signed in to change notification settings - Fork 2
/
em.ml
96 lines (94 loc) · 3.91 KB
/
em.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
open Expression
open Type
open Task
open Library
open Enumerate
open Utils
open Compress
(* One expectation-maximization iteration over [tasks].

   1. Enumerate frontiers of candidate programs for [tasks] under [grammar]
      (sized by [frontier_size]) and score every program against its task.
      Print how many tasks were hit (some program scores above
      [log 0.999]) and how many received any score at all.
   2. Hand the hitting programs of each task to [compress]. Its result is
      discarded; presumably it is called for its side effects (TODO confirm).
   3. Per task, turn (score + log-likelihood under [grammar]) into a
      log-posterior normalized with [lse_list].
   4. Credit each program's summed posterior mass to every node of its
      expression tree. Build a flat library from every leaf node that was
      reached, plus every expression whose accumulated weight exceeds
      [lambda].
   5. Fit that library's continuous parameters with [fit_grammar], using
      [smoothing] and the (program, request type) posterior corpus.

   [lambda] is both the library-inclusion weight threshold and an argument
   to [compress]. Returns the fitted grammar. Progress is printed to
   stdout. *)
let expectation_maximization_iteration
    lambda smoothing frontier_size
    tasks grammar =
  (* Add [w] to the weight stored under [key] in [tbl], creating the entry
     if it is absent. *)
  let add_weight tbl key w =
    try
      let old = Hashtbl.find tbl key in
      Hashtbl.replace tbl key (old +. w)
    with Not_found -> Hashtbl.add tbl key w
  in
  let non_empty = function [] -> false | _ :: _ -> true in
  (* A program whose score exceeds this threshold counts as solving its
     task. A single definition is shared by the hit-rate report and the
     solutions given to [compress], so the two cannot drift apart. *)
  let hit_threshold = log 0.999 in
  let is_hit (_, s) = s > hit_threshold in
  let (frontiers, dagger) =
    enumerate_frontiers_for_tasks grammar frontier_size tasks in
  print_string "Scoring programs... \n";
  let (i2n, _, _) = dagger in
  (* One (program id, score) list per task, in the order of [tasks]. *)
  let program_scores = score_programs dagger frontiers tasks in
  (* display the hit rate *)
  let number_hit =
    List.length (List.filter (List.exists is_hit) program_scores) in
  let number_of_partial =
    List.length (List.filter non_empty program_scores) in
  Printf.printf "Hit %i / %i \n" number_hit (List.length tasks);
  Printf.printf "Partial credit %i / %i \n" number_of_partial (List.length tasks);
  (* compute likelihoods under grammar and then normalize the frontiers *)
  let type_array = infer_graph_types dagger in
  (* For every program id in every frontier: the duplicate-free list of
     request types under which it was enumerated. *)
  let requests =
    List.fold_left (fun requests (requested_type, frontier) ->
        List.fold_left (fun (a : (tp list) IntMap.t) (i : int) ->
            try
              let old = IntMap.find i a in
              if List.mem requested_type old
              then a else IntMap.add i (requested_type :: old) a
            with Not_found -> IntMap.add i [requested_type] a
          ) requests frontier
      ) IntMap.empty frontiers
  in
  (* Each task paired with only its hitting programs. Tasks without any
     hit are dropped. *)
  let task_solutions =
    List.filter (fun (_, s) -> non_empty s)
      (List.combine tasks (List.map (List.filter is_hit) program_scores))
  in
  ignore (compress lambda dagger type_array requests task_solutions);
  let likelihoods = program_likelihoods grammar dagger type_array requests in
  (* Log-posterior of each program given its task: score plus grammar
     log-likelihood, normalized by their log-sum-exp. A task with no scored
     programs gets an empty posterior, so [lse_list] is never applied to an
     empty list. *)
  let task_posteriors =
    List.map2 (fun task scores ->
        match scores with
        | [] -> []
        | _ :: _ ->
          let scores =
            List.map (fun (i, s) ->
                (i, s +. Hashtbl.find likelihoods (i, task.task_type)))
              scores
          in
          let z = lse_list (List.map snd scores) in
          List.map (fun (i, s) -> (i, s -. z)) scores
      ) tasks program_scores
  in
  (* compute rewards for each program: its posterior mass summed over tasks *)
  let rewards = Hashtbl.create 100000 in
  List.iter (List.iter (fun (i, r) -> add_weight rewards i (exp r)))
    task_posteriors;
  (* compute rewards for each expression: every node of a program's tree
     receives that program's reward. The walk does not deduplicate, so a
     sub-expression occurring k times in one program is credited k times. *)
  let expression_rewards = Hashtbl.create 100000 in
  let reward_expression weight i =
    let rec reward j =
      add_weight expression_rewards j weight;
      match Hashtbl.find i2n j with
      | ExpressionBranch (l, r) -> reward l; reward r
      | _ -> ()
    in
    reward i
  in
  Hashtbl.iter (fun i w -> reward_expression w i) rewards;
  (* find those productions that have enough weight to make it into the library *)
  let productions =
    hash_bindings expression_rewards
    |> List.filter (fun (i, r) -> is_leaf_ID dagger i || r > lambda)
    |> List.map (fun (i, _) -> extract_expression dagger i)
  in
  let new_grammar = make_flat_library productions in
  print_string "Computed posterior probabilities. \n";
  (* assembled corpus: posterior mass per (program id, task type) *)
  let corpus = Hashtbl.create 100000 in
  (* [task_posteriors] was built by [List.map2] over [tasks], so the two
     lists always have the same length. *)
  List.iter2 (fun task posterior ->
      List.iter (fun (i, log_posterior) ->
          add_weight corpus (i, task.task_type) (exp log_posterior))
        posterior)
    tasks task_posteriors;
  (* fit the continuous parameters of the new grammar and then return it *)
  let likelihoods = program_likelihoods new_grammar dagger type_array requests in
  let final_grammar =
    fit_grammar smoothing new_grammar dagger type_array likelihoods
      (hash_bindings corpus)
  in
  (* print_string (string_of_library final_grammar); *)
  print_newline ();
  final_grammar
;;