-
Notifications
You must be signed in to change notification settings - Fork 430
/
Copy pathexpressions.rs
91 lines (82 loc) · 2.61 KB
/
expressions.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
extern crate rand;
extern crate tensorflow;
use std::error::Error;
use std::result::Result;
use tensorflow::expr::{Compiler, Placeholder};
use tensorflow::Code;
use tensorflow::Graph;
use tensorflow::Session;
use tensorflow::SessionOptions;
use tensorflow::SessionRunArgs;
use tensorflow::Status;
use tensorflow::Tensor;
// When the `examples_system_alloc` feature is enabled, register
// `std::alloc::System` as the process-wide global allocator. With the
// feature disabled this item is compiled out entirely.
#[cfg(feature = "examples_system_alloc")]
#[global_allocator]
static ALLOCATOR: std::alloc::System = std::alloc::System;
/// Accumulates the outcome of a series of approximate `f32` comparisons so a
/// single pass/fail result can be reported once all checks have run.
struct Checker {
    /// Starts `true`; becomes `false` permanently once any check fails.
    success: bool,
    /// Absolute tolerance: a check passes only if `|expected - actual| < epsilon`.
    epsilon: f32,
}

impl Checker {
    /// Creates a checker with no recorded failures that uses `epsilon` as the
    /// absolute tolerance for every subsequent [`Checker::check`].
    fn new(epsilon: f32) -> Self {
        Checker {
            success: true,
            epsilon,
        }
    }

    /// Compares `actual` with `expected`, prints a one-line report labelled
    /// with `name`, and records a failure unless the absolute difference is
    /// strictly below `epsilon`.
    ///
    /// If either value is NaN the comparison is false, so the check is
    /// recorded as a failure.
    fn check(&mut self, name: &str, expected: f32, actual: f32) {
        let success = (expected - actual).abs() < self.epsilon;
        println!(
            "Checking {}: expected {}, got {}. {}",
            name,
            expected,
            actual,
            if success { "Success!" } else { "FAIL" }
        );
        // Once false, stays false: a later passing check cannot mask an earlier failure.
        self.success &= success;
    }

    /// Returns `Ok(())` if every check performed so far passed.
    ///
    /// # Errors
    ///
    /// Returns a `Status` with code `Code::Internal` if at least one check
    /// failed, or propagates the error from `Status::new_set` if building
    /// that status fails.
    fn result(&self) -> Result<(), Box<dyn Error>> {
        if self.success {
            Ok(())
        } else {
            Err(Box::new(Status::new_set(
                Code::Internal,
                "At least one check failed",
            )?))
        }
    }
}
/// Builds the graph `y = x * 2 + 1` over a shape-`[2]` `f32` placeholder `x`,
/// runs it with `x = [2.0, 3.0]`, and checks that `y` is `[5.0, 7.0]`.
///
/// # Errors
///
/// Propagates any error from compiling the expression, looking up the
/// placeholder, creating or running the session, or fetching the output.
/// Also returns an error if either output element differs from its
/// expected value by `1e-3` or more.
fn main() -> Result<(), Box<dyn Error>> {
    // Name of the placeholder operation; used both to create it and to look it up.
    let input_name = "x";

    // Build the graph
    let mut g = Graph::new();
    let y_node = {
        let mut compiler = Compiler::new(&mut g);
        let x_expr = <Placeholder<f32>>::new_expr(&[2], input_name);
        compiler.compile(x_expr * 2.0f32 + 1.0f32)?
    };
    let x_node = g.operation_by_name_required(input_name)?;
    // This is another valid way to get x_node and y_node:
    // let (x_node, y_node) = {
    //     let mut compiler = Compiler::new(&mut g);
    //     let x_expr = <Placeholder<f32>>::new_expr(&[2], input_name);
    //     let x_node = compiler.compile(x_expr.clone())?;
    //     let y_node = compiler.compile(x_expr * 2.0f32 + 1.0f32)?;
    //     (x_node, y_node)
    // };
    let options = SessionOptions::new();
    let session = Session::new(&options, &g)?;

    // Evaluate the graph.
    let mut x = <Tensor<f32>>::new(&[2]);
    x[0] = 2.0;
    x[1] = 3.0;
    let mut step = SessionRunArgs::new();
    step.add_feed(&x_node, 0, &x);
    let output_token = step.request_fetch(&y_node, 0);
    // Propagate a failed run as an error, like every other fallible call here,
    // instead of panicking.
    session.run(&mut step)?;

    // Check our results.
    let output_tensor = step.fetch::<f32>(output_token)?;
    let mut checker = Checker::new(1e-3);
    checker.check("output_tensor[0]", 5.0, output_tensor[0]);
    checker.check("output_tensor[1]", 7.0, output_tensor[1]);
    checker.result()
}