forked from herbie-fp/herbie
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvanilla.patch
146 lines (134 loc) · 4.44 KB
/
vanilla.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
diff --git a/egg-herbie/Cargo.lock b/egg-herbie/Cargo.lock
index d2213ed14..b451675c1 100644
--- a/egg-herbie/Cargo.lock
+++ b/egg-herbie/Cargo.lock
@@ -76,6 +76,7 @@ dependencies = [
"egg",
"env_logger",
"indexmap",
+ "lazy_static",
"libc",
"log",
"num-bigint",
@@ -142,6 +143,12 @@ dependencies = [
"wasm-bindgen",
]
+[[package]]
+name = "lazy_static"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
+
[[package]]
name = "libc"
version = "0.2.153"
diff --git a/egg-herbie/Cargo.toml b/egg-herbie/Cargo.toml
index 013456ffb..ea5603f1e 100644
--- a/egg-herbie/Cargo.toml
+++ b/egg-herbie/Cargo.toml
@@ -16,6 +16,7 @@ num-integer = "0.1.45"
num-rational = "0.4.0"
num-traits = "0.2.15"
env_logger = { version = "0.9", default-features = false }
+lazy_static = "1.5.0"
[lib]
name = "egg_math"
diff --git a/egg-herbie/src/lib.rs b/egg-herbie/src/lib.rs
index cba0d6124..827793796 100644
--- a/egg-herbie/src/lib.rs
+++ b/egg-herbie/src/lib.rs
@@ -14,6 +14,13 @@ use std::os::raw::c_char;
use std::time::Duration;
use std::{slice, sync::atomic::Ordering};
+use lazy_static::lazy_static;
+use std::sync::Mutex;
+
+lazy_static! {
+ static ref INC_ITERDATA: Mutex<Vec<Iteration>> = Mutex::new(vec![]);
+}
+
pub struct Context {
iteration: usize,
runner: Runner,
@@ -113,13 +120,14 @@ pub unsafe extern "C" fn egraph_add_node(
#[no_mangle]
pub unsafe extern "C" fn egraph_copy(ptr: *mut Context) -> *mut Context {
// Safety: `ptr` was box allocated by `egraph_create`
- let context = Box::from_raw(ptr);
+ let mut context = Box::from_raw(ptr);
let mut runner = Runner::new(Default::default())
.with_explanations_enabled()
- .with_egraph(context.runner.egraph.clone());
+ .with_egraph(context.runner.egraph);
runner.roots = context.runner.roots.clone();
runner.egraph.rebuild();
+ context.runner.egraph = EGraph::default();
mem::forget(context);
Box::into_raw(Box::new(Context {
@@ -190,6 +198,7 @@ pub unsafe extern "C" fn egraph_run(
.with_time_limit(Duration::from_secs(u64::MAX))
.with_hook(|r| {
if r.egraph.analysis.unsound.load(Ordering::SeqCst) {
+ panic!("Unsoundness detected");
Err("Unsoundness detected".into())
} else {
Ok(())
@@ -198,6 +207,19 @@ pub unsafe extern "C" fn egraph_run(
.run(&context.rules);
}
+ let mut inc_iterdata = INC_ITERDATA.lock().unwrap();
+ inc_iterdata.extend(context.runner.iterations.clone());
+
+ // Construct a fresh Runner to print the aggregate report
+ let mut tmp = Runner::new(Default::default());
+ tmp.iterations = inc_iterdata.clone();
+ tmp.stop_reason = Some(StopReason::Other("Tmp Runner".to_string()));
+ println!(
+ "Stop reason: {:?}",
+ context.runner.stop_reason.clone().unwrap()
+ );
+ println!("{}", tmp.report());
+
let iterations = context
.runner
.iterations
diff --git a/egg-herbie/src/math.rs b/egg-herbie/src/math.rs
index dce400f7b..d1f471d2d 100644
--- a/egg-herbie/src/math.rs
+++ b/egg-herbie/src/math.rs
@@ -15,10 +15,12 @@ pub type Rewrite = egg::Rewrite<Math, ConstantFold>;
pub type Runner = egg::Runner<Math, ConstantFold, IterData>;
pub type Iteration = egg::Iteration<IterData>;
+#[derive(Clone)]
pub struct IterData {
pub extracted: Vec<(Id, Extracted)>,
}
+#[derive(Clone)]
pub struct Extracted {
pub best: RecExpr,
pub cost: usize,
diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt
index c6315ba73..e82b7f89d 100644
--- a/src/core/egg-herbie.rkt
+++ b/src/core/egg-herbie.rkt
@@ -1,6 +1,6 @@
#lang racket
-(require egg-herbie
+(require "../../egg-herbie/main.rkt"
(only-in ffi/vector
make-u32vector
u32vector-length
diff --git a/src/core/rules.rkt b/src/core/rules.rkt
index 1117a72f2..6077c800e 100644
--- a/src/core/rules.rkt
+++ b/src/core/rules.rkt
@@ -28,10 +28,10 @@
(set-member? (rule-tags rule) tag))
(define (*rules*)
- (filter rule-enabled? *all-rules*))
+ (filter (conjoin rule-enabled? (has-tag? 'sound)) *all-rules*))
(define (*simplify-rules*)
- (filter (conjoin rule-enabled? (has-tag? 'simplify)) *all-rules*))
+ (filter (conjoin rule-enabled? (has-tag? 'simplify) (has-tag? 'sound)) *all-rules*))
;;
;; Rule loading