Skip to content

Commit

Permalink
Dev 1654 pagerank (Pometry#869)
Browse files Browse the repository at this point in the history
* add dangling test

* simplify page_rank and set initial value to 1

* added initial support for dangling

* remove println

* a few more changes to match PageRank on old Raphtory

* pagerank has the expected results
  • Loading branch information
fabianmurariu authored May 10, 2023
1 parent 91ae71c commit d1bbeee
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 41 deletions.
179 changes: 140 additions & 39 deletions raphtory/src/algorithms/pagerank.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
task::{
context::Context,
task::{ATask, Job, Step},
task_runner::TaskRunner
task_runner::TaskRunner,
},
view_api::{GraphViewOps, VertexViewOps},
},
Expand All @@ -32,15 +32,11 @@ pub fn unweighted_page_rank<G: GraphViewOps>(
let score = accumulators::val::<f32>(0).init::<InitOneF32>();
let recv_score = accumulators::sum::<f32>(1);
let max_diff = accumulators::max::<f32>(2);
let dangling = accumulators::sum::<f32>(3);

ctx.agg_reset(recv_score);
ctx.global_agg_reset(max_diff);

let step1 = ATask::new(move |vv| {
let initial_score = 1f32 / total_vertices as f32;
vv.update_local(&score, initial_score);
Step::Continue
});
ctx.global_agg_reset(dangling);

let step2 = ATask::new(move |s| {
let out_degree = s.out_degree();
Expand All @@ -49,14 +45,18 @@ pub fn unweighted_page_rank<G: GraphViewOps>(
for t in s.neighbours_out() {
t.update(&recv_score, new_score)
}
} else {
s.global_update(&dangling, s.read_local(&score) / total_vertices as f32);
}
Step::Continue
});

let step3 = ATask::new(move |s| {
let dangling_v = s.read_global_state(&dangling).unwrap_or_default();

s.update_local(
&score,
(1f32 - damping_factor) + (damping_factor * s.read(&recv_score)),
(1f32 - damping_factor) + (damping_factor * (s.read(&recv_score) + dangling_v)),
);
let prev = s.read_local_prev(&score);
let curr = s.read_local(&score);
Expand All @@ -77,7 +77,7 @@ pub fn unweighted_page_rank<G: GraphViewOps>(
let mut runner: TaskRunner<G, _> = TaskRunner::new(ctx);

let (_, _, local_states) = runner.run(
vec![Job::new(step1)],
vec![],
vec![Job::new(step2), Job::new(step3), step4],
threads,
iter_count,
Expand All @@ -87,6 +87,8 @@ pub fn unweighted_page_rank<G: GraphViewOps>(

let mut map: FxHashMap<String, f32> = FxHashMap::default();

let num_vertices = g.num_vertices() as f32;

for state in local_states {
if let Some(state) = state.as_ref() {
state.fold_state_internal(
Expand All @@ -95,7 +97,7 @@ pub fn unweighted_page_rank<G: GraphViewOps>(
&score,
|res, shard, pid, score| {
if let Some(v_ref) = g.lookup_by_pid_and_shard(pid, shard) {
res.insert(g.vertex(v_ref.g_id).unwrap().name(), score);
res.insert(g.vertex(v_ref.g_id).unwrap().name(), score / num_vertices);
}
res
},
Expand All @@ -108,6 +110,9 @@ pub fn unweighted_page_rank<G: GraphViewOps>(

#[cfg(test)]
mod page_rank_tests {
use std::borrow::Borrow;

use itertools::Itertools;
use pretty_assertions::assert_eq;

use crate::db::graph::Graph;
Expand Down Expand Up @@ -135,10 +140,10 @@ mod page_rank_tests {
assert_eq!(
results,
vec![
("2".to_string(), 0.78044075),
("4".to_string(), 0.78044075),
("1".to_string(), 1.4930439),
("3".to_string(), 0.8092761)
("2".to_string(), 0.20249715),
("4".to_string(), 0.20249715),
("1".to_string(), 0.38669053),
("3".to_string(), 0.20831521)
]
.into_iter()
.collect::<FxHashMap<String, f32>>()
Expand Down Expand Up @@ -205,36 +210,132 @@ mod page_rank_tests {
.collect();

let expected_2 = vec![
("10".to_string(), 0.6598998),
("7".to_string(), 0.14999998),
("4".to_string(), 0.72722703),
("1".to_string(), 1.0329459),
("11".to_string(), 0.5662594),
("8".to_string(), 1.2494258),
("5".to_string(), 1.7996559),
("2".to_string(), 0.32559997),
("9".to_string(), 0.5662594),
("6".to_string(), 0.6598998),
("3".to_string(), 1.4175149),
("10".to_string(), 0.07208286),
("11".to_string(), 0.061855234),
("5".to_string(), 0.19658245),
("4".to_string(), 0.07943771),
("9".to_string(), 0.061855234),
("3".to_string(), 0.15484008),
("8".to_string(), 0.136479),
("2".to_string(), 0.035566494),
("7".to_string(), 0.016384698),
("1".to_string(), 0.1128334),
("6".to_string(), 0.07208286),
];

// let expected = vec![
// (1, 1.2411863819664029),
// (2, 0.39123721383779864),
// (3, 1.7032272385548306),
// (4, 0.873814473224871),
// (5, 2.162387978524525),
// (6, 0.7929037468922092),
// (8, 1.5012556698522248),
// (7, 0.1802324126887131),
// (9, 0.6804255687831074),
// (10, 0.7929037468922092),
// (11, 0.6804255687831074),
// ];

assert_eq!(
results,
expected_2.into_iter().collect::<FxHashMap<String, f32>>()
);
}

#[test]
fn two_nodes_page_rank() {
    // Two vertices whose only edges point at each other.
    let graph = Graph::new(4);

    let links = vec![(1, 2), (2, 1)];
    for (idx, (from, to)) in links.into_iter().enumerate() {
        // The position of the edge in the list doubles as its timestamp.
        graph.add_edge(idx as i64, from, to, &vec![], None).unwrap();
    }

    let ranks: FxHashMap<String, f32> =
        unweighted_page_rank(&graph, 1000, Some(4), Some(0.00001))
            .into_iter()
            .collect();

    // Each vertex is expected to score 0.5, compared at three decimals.
    for name in ["1", "2"] {
        assert_eq_f32(ranks.get(name), Some(&0.5), 3);
    }
}

#[test]
fn three_nodes_page_rank_one_dangling() {
    // Vertices 1 and 2 link to each other; vertex 3 only receives an edge
    // (from 2) and has no outgoing edges of its own.
    let graph = Graph::new(4);

    let links = vec![(1, 2), (2, 1), (2, 3)];
    for (idx, (from, to)) in links.into_iter().enumerate() {
        // The position of the edge in the list doubles as its timestamp.
        graph.add_edge(idx as i64, from, to, &vec![], None).unwrap();
    }

    let ranks: FxHashMap<String, f32> =
        unweighted_page_rank(&graph, 1000, Some(4), Some(0.0000001))
            .into_iter()
            .collect();

    // Expected score per vertex name, compared at three decimals.
    let expected = [("1", 0.303_f32), ("2", 0.394), ("3", 0.303)];
    for (name, score) in expected {
        assert_eq_f32(ranks.get(name), Some(&score), 3);
    }
}

#[test]
fn dangling_page_rank() {
    let graph = Graph::new(4);

    // Vertices 1-3 form a small cycle-rich core; from vertex 4 onwards the
    // edges form a single chain ending at vertex 11, which has no outgoing
    // edges.
    let timed_edges = vec![
        (1, 2),
        (1, 3),
        (2, 3),
        (3, 1),
        (3, 2),
        (3, 4),
        // dangling from here
        (4, 5),
        (5, 6),
        (6, 7),
        (7, 8),
        (8, 9),
        (9, 10),
        (10, 11),
    ]
    .into_iter()
    .enumerate()
    // The position of the edge in the list doubles as its timestamp.
    .map(|(idx, (from, to))| (from, to, idx as i64))
    .collect_vec();

    for (from, to, time) in timed_edges {
        graph.add_edge(time, from, to, &vec![], None).unwrap();
    }

    let ranks: FxHashMap<String, f32> =
        unweighted_page_rank(&graph, 1000, Some(4), Some(0.00001))
            .into_iter()
            .collect();

    // Expected score per vertex name, compared at three decimals.
    let expected = [
        ("1", 0.055_f32),
        ("2", 0.079),
        ("3", 0.113),
        ("4", 0.055),
        ("5", 0.070),
        ("6", 0.083),
        ("7", 0.093),
        ("8", 0.102),
        ("9", 0.110),
        ("10", 0.117),
        ("11", 0.122),
    ];
    for (name, score) in expected {
        assert_eq_f32(ranks.get(name), Some(&score), 3);
    }
}

/// Asserts that two optional `f32`-like values are equal once rounded to
/// `decimals` decimal places.
///
/// * When both sides are `Some`, each value is multiplied by `10^decimals`,
///   rounded to the nearest integer, and the rounded values are compared.
/// * When either side is `None`, the two options are compared directly, so
///   `None`/`None` passes and `None`/`Some(_)` fails.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when the comparison described above fails.
fn assert_eq_f32<T: Borrow<f32> + PartialEq + std::fmt::Debug>(
    a: Option<T>,
    b: Option<T>,
    decimals: u8,
) {
    match (a, b) {
        (Some(a), Some(b)) => {
            // Scaling by 10^decimals moves the digits of interest in front of
            // the decimal point so that `round` discards the rest.
            let factor = 10.0_f32.powi(i32::from(decimals));
            assert_eq!(
                (a.borrow() * factor).round(),
                (b.borrow() * factor).round(),
                "{:?} != {:?}",
                a,
                b
            );
        }
        // At least one side is `None`: plain equality decides the outcome.
        (a, b) => assert_eq!(a, b),
    }
}
}
3 changes: 1 addition & 2 deletions raphtory/src/db/view_api/internal.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::core::tgraph::{EdgeRef, VertexRef};
use crate::core::{Direction, Prop};
use rayon::prelude::*;
use std::collections::HashMap;
use std::{ops::Range, sync::Arc};
use std::ops::Range;

/// The GraphViewInternalOps trait provides a set of methods to query a directed graph
/// represented by the raphtory_core::tgraph::TGraph struct.
Expand Down

0 comments on commit d1bbeee

Please sign in to comment.