Skip to content

Commit 676cd3f

Browse files
committed
Upgrade to TensorFlow 0.10.0
1 parent 0c0d0ac commit 676cd3f

File tree

7 files changed

+261
-223
lines changed

7 files changed

+261
-223
lines changed

examples/expressions.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ fn run() -> Result<(), Box<Error>> {
6565
let x_expr = <Placeholder<f32>>::new_expr(&vec![2], "x");
6666
try!(compiler.compile(x_expr * 2.0f32 + 1.0f32))
6767
};
68-
let x_node = try!(g.node_by_name_required("x"));
68+
let x_node = try!(g.operation_by_name_required("x"));
6969
// This is another valid way to get x_node and y_node:
7070
// let (x_node, y_node) = {
7171
// let mut compiler = Compiler::new(&mut g);

src/expr.rs

+35-35
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::ops;
1818
use std::rc::Rc;
1919
use super::DataType;
2020
use super::Graph;
21-
use super::Node;
21+
use super::Operation;
2222
use super::Port;
2323
use super::Status;
2424
use super::Tensor;
@@ -37,7 +37,7 @@ pub enum OpLevel {
3737

3838
////////////////////////
3939

40-
/// A node in an expression tree, which is a thin wrapper around an ExprImpl.
40+
/// A operation in an expression tree, which is a thin wrapper around an ExprImpl.
4141
///
4242
/// This is separate from ExprImpl because we want expressions to be wrapped in an Rc,
4343
/// and we can't directly implement std::ops::Add, etc., for Rc<E: ExprImpl<T>>.
@@ -68,7 +68,7 @@ pub trait ExprImpl<T: TensorType>: Display + Debug {
6868
/// Returns the precedence level for this operator.
6969
fn op_level(&self) -> OpLevel;
7070
fn children(&self) -> Vec<Box<AnyExpr>>; // TODO: return an iterator
71-
fn create_node(&self, graph: &mut Graph, children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status>;
71+
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status>;
7272
}
7373

7474
impl<T: TensorType> ExprImpl<T> for T {
@@ -80,8 +80,8 @@ impl<T: TensorType> ExprImpl<T> for T {
8080
vec![]
8181
}
8282

83-
fn create_node(&self, graph: &mut Graph, _children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
84-
let mut nd = try!(graph.new_node("Const", &id_gen()));
83+
fn create_operation(&self, graph: &mut Graph, _children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
84+
let mut nd = try!(graph.new_operation("Const", &id_gen()));
8585
try!(nd.set_attr_type("dtype", DataType::Float));
8686
let mut value = Tensor::new(&[1]);
8787
value[0] = *self;
@@ -157,10 +157,10 @@ macro_rules! impl_bin_op {
157157
vec![Box::new(self.left.clone()), Box::new(self.right.clone())]
158158
}
159159

160-
fn create_node(&self, graph: &mut Graph, children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
161-
let mut nd = try!(graph.new_node($tf_op, &id_gen()));
162-
nd.add_input(Port {node: &children[0], index: 0});
163-
nd.add_input(Port {node: &children[1], index: 0});
160+
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
161+
let mut nd = try!(graph.new_operation($tf_op, &id_gen()));
162+
nd.add_input(Port {operation: &children[0], index: 0});
163+
nd.add_input(Port {operation: &children[1], index: 0});
164164
nd.finish()
165165
}
166166
}
@@ -213,9 +213,9 @@ impl<T: TensorType> ExprImpl<T> for Neg<T> {
213213
vec![Box::new(self.expr.clone())]
214214
}
215215

216-
fn create_node(&self, graph: &mut Graph, children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
217-
let mut nd = try!(graph.new_node("Neg", &id_gen()));
218-
nd.add_input(Port {node: &children[0], index: 0});
216+
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
217+
let mut nd = try!(graph.new_operation("Neg", &id_gen()));
218+
nd.add_input(Port {operation: &children[0], index: 0});
219219
nd.finish()
220220
}
221221
}
@@ -261,8 +261,8 @@ impl<T: TensorType> ExprImpl<T> for Variable<T> {
261261
vec![]
262262
}
263263

264-
fn create_node(&self, graph: &mut Graph, _children: &[Rc<Node>], _id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
265-
let mut nd = try!(graph.new_node("Variable", &self.name));
264+
fn create_operation(&self, graph: &mut Graph, _children: &[Rc<Operation>], _id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
265+
let mut nd = try!(graph.new_operation("Variable", &self.name));
266266
nd.set_attr_type("dtype", DataType::Float).unwrap();
267267
nd.set_attr_shape("shape", &vec![]).unwrap();
268268
nd.finish()
@@ -310,8 +310,8 @@ impl<T: TensorType> ExprImpl<T> for Placeholder<T> {
310310
vec![]
311311
}
312312

313-
fn create_node(&self, graph: &mut Graph, _children: &[Rc<Node>], _id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
314-
let mut nd = try!(graph.new_node("Placeholder", &self.name));
313+
fn create_operation(&self, graph: &mut Graph, _children: &[Rc<Operation>], _id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
314+
let mut nd = try!(graph.new_operation("Placeholder", &self.name));
315315
nd.set_attr_type("dtype", DataType::Float).unwrap();
316316
nd.set_attr_shape("shape", &vec![]).unwrap();
317317
nd.finish()
@@ -358,10 +358,10 @@ impl<T: TensorType> ExprImpl<T> for Assign<T> {
358358
vec![Box::new(self.variable.clone()), Box::new(self.value.clone())]
359359
}
360360

361-
fn create_node(&self, graph: &mut Graph, children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
362-
let mut nd = try!(graph.new_node("Assign", &id_gen()));
363-
nd.add_input(Port {node: &children[0], index: 0});
364-
nd.add_input(Port {node: &children[1], index: 0});
361+
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
362+
let mut nd = try!(graph.new_operation("Assign", &id_gen()));
363+
nd.add_input(Port {operation: &children[0], index: 0});
364+
nd.add_input(Port {operation: &children[1], index: 0});
365365
nd.finish()
366366
}
367367
}
@@ -372,7 +372,7 @@ impl<T: TensorType> ExprImpl<T> for Assign<T> {
372372
pub trait AnyExpr {
373373
fn key(&self) -> *const ();
374374
fn children(&self) -> Vec<Box<AnyExpr>>; // TODO: return an iterator
375-
fn create_node(&self, graph: &mut Graph, children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status>;
375+
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status>;
376376
fn clone_box(&self) -> Box<AnyExpr>;
377377
}
378378

@@ -385,8 +385,8 @@ impl<T: TensorType> AnyExpr for Expr<T> {
385385
self.expr.children()
386386
}
387387

388-
fn create_node(&self, graph: &mut Graph, children: &[Rc<Node>], id_gen: &mut FnMut() -> String) -> Result<Node, Status> {
389-
self.expr.create_node(graph, children, id_gen)
388+
fn create_operation(&self, graph: &mut Graph, children: &[Rc<Operation>], id_gen: &mut FnMut() -> String) -> Result<Operation, Status> {
389+
self.expr.create_operation(graph, children, id_gen)
390390
}
391391

392392
fn clone_box(&self) -> Box<AnyExpr> {
@@ -414,45 +414,45 @@ impl Hash for Key {
414414

415415
pub struct Compiler<'l> {
416416
graph: &'l mut Graph,
417-
nodes: HashMap<Key, Rc<Node>>,
417+
operations: HashMap<Key, Rc<Operation>>,
418418
next_id: i32,
419419
}
420420

421421
impl<'l> Compiler<'l> {
422422
pub fn new(graph: &'l mut Graph) -> Self {
423423
Compiler {
424424
graph: graph,
425-
nodes: HashMap::new(),
425+
operations: HashMap::new(),
426426
next_id: 0,
427427
}
428428
}
429429

430-
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Rc<Node>, Status> {
430+
pub fn compile<T: TensorType>(&mut self, expr: Expr<T>) -> Result<Rc<Operation>, Status> {
431431
self.compile_any(Box::new(expr))
432432
}
433433

434-
pub fn compile_any(&mut self, expr: Box<AnyExpr>) -> Result<Rc<Node>, Status> {
435-
let mut child_nodes = vec![];
434+
pub fn compile_any(&mut self, expr: Box<AnyExpr>) -> Result<Rc<Operation>, Status> {
435+
let mut child_operations = vec![];
436436
for child in expr.children() {
437437
let key = Key(child.clone_box());
438438
// The result is mapped separately from the match statement below to avoid
439439
// reference lifetime isues.
440-
let value = self.nodes.get(&key).map(|v| v.clone());
441-
child_nodes.push(match value {
440+
let value = self.operations.get(&key).map(|v| v.clone());
441+
child_operations.push(match value {
442442
Some(v) => v,
443443
None => try!(self.compile_any(child)),
444444
});
445445
}
446446
let mut next_id = self.next_id;
447-
let result = expr.create_node(self.graph, &child_nodes, &mut || {
448-
let id = format!("node_{}", next_id);
447+
let result = expr.create_operation(self.graph, &child_operations, &mut || {
448+
let id = format!("operation_{}", next_id);
449449
next_id += 1;
450450
id
451451
});
452452
self.next_id = next_id;
453-
let node = Rc::new(try!(result));
454-
self.nodes.insert(Key(expr), node.clone());
455-
Ok(node)
453+
let operation = Rc::new(try!(result));
454+
self.operations.insert(Key(expr), operation.clone());
455+
Ok(operation)
456456
}
457457
}
458458

0 commit comments

Comments
 (0)