From b2f6f323fe9f0c4c8cd1c84830fa6c37882ad6cb Mon Sep 17 00:00:00 2001
From: Deeptendu Santra
Date: Mon, 18 Apr 2022 18:46:47 +0530
Subject: [PATCH] Accordance with the rest of the graph datasets + optimized

---
 docs/src/datasets/graphs.md         |   2 +
 src/datasets_graph/Reddit/Reddit.jl | 140 +++++++++++++++-------------
 2 files changed, 75 insertions(+), 67 deletions(-)

diff --git a/docs/src/datasets/graphs.md b/docs/src/datasets/graphs.md
index 5092c512..4baaf2dd 100644
--- a/docs/src/datasets/graphs.md
+++ b/docs/src/datasets/graphs.md
@@ -23,4 +23,6 @@ TUDataset
 KarateClub
 KarateClub.edge_index
 KarateClub.labels
+Reddit
+Reddit.dataset
 ```

diff --git a/src/datasets_graph/Reddit/Reddit.jl b/src/datasets_graph/Reddit/Reddit.jl
index 5ae21590..7d623e6f 100644
--- a/src/datasets_graph/Reddit/Reddit.jl
+++ b/src/datasets_graph/Reddit/Reddit.jl
@@ -3,9 +3,21 @@ export Reddit
 """
     Reddit
 
-# References
+The Reddit dataset was introduced in [1].
+It is a graph dataset of Reddit posts made in September 2014.
+The dataset contains a single post-to-post graph, where two posts are connected if the same user commented on both.
+The node label is the community ("subreddit"), out of 50, that a post belongs to.
+The dataset contains 231,443 posts.
+The first 20 days are used for training and the remaining days for testing (with 30% used for validation).
+Node features are off-the-shelf 300-dimensional GloVe CommonCrawl word vectors.
+
+# Interface
+- [`Reddit.dataset`](@ref)
+
+# References
 [1]: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216)
+[2]: [Benchmarks on the Reddit Dataset](https://paperswithcode.com/dataset/reddit)
 """
 module Reddit
 using DataDeps
@@ -13,16 +25,17 @@ using JSON3
 using ..MLDatasets: datafile
 using NPZ: npzread
 
-DATA = [
+DATAFILES = [
     "reddit-G.json", "reddit-G_full.json", "reddit-adjlist.txt",
-    "reddit-class_map.json", "reddit-feats.npy", "reddit-id_map.json", "reddit-walks.txt"
+    "reddit-class_map.json", "reddit-feats.npy", "reddit-id_map.json"
 ]
+DATA = joinpath.("reddit", DATAFILES)
 
 DEPNAME = "Reddit"
 
 function __init__()
     DEPNAME = "Reddit"
-    LINK = ""
-    DOCS = ""
+    LINK = "http://snap.stanford.edu/graphsage/reddit.zip"
+    DOCS = "http://snap.stanford.edu/graphsage/"
 
     register(DataDep(
         DEPNAME,
@@ -37,94 +50,87 @@ function __init__()
 end
 
 """
-    edge_index(self_loops=true)
-
+    dataset(self_loops=true; dir=nothing)
+
+Retrieve the Reddit graph dataset. The output is a named tuple with the following fields:
+```julia-repl
+julia> keys(Reddit.dataset())
+(:directed, :multigraph, :num_graphs, :num_edges, :num_nodes, :edge_index, :node_labels, :node_features, :split)
+```
+See also [`CiteSeer`](@ref).
+
+# Usage Examples
+
+```julia
+using MLDatasets: Reddit
+data = Reddit.dataset()
+train_indices = data.split["train"]
+train_features = data.node_features[train_indices, :]
+train_labels = data.node_labels[train_indices]
+```
 """
-function edge_index(self_loops=true)
+function dataset(self_loops=true; dir=nothing)
     if self_loops
-        data_path = datafile(DEPNAME, DATA[1])
+        graph_json = datafile(DEPNAME, DATA[1], dir)
     else
-        data_path = datafile(DEPNAME, DATA[0])
+        graph_json = datafile(DEPNAME, DATA[2], dir)  # DATA[0] is out of bounds in 1-based Julia
     end
-    graph = open(JSON3.read, data_path)
+    class_map_json = datafile(DEPNAME, DATA[4], dir)
+    id_map_json = datafile(DEPNAME, DATA[6], dir)
+    feat_path = datafile(DEPNAME, DATA[5], dir)
+
+    # Read the JSON files
+    graph = open(JSON3.read, graph_json)
+    class_map = open(JSON3.read, class_map_json)
+    id_map = open(JSON3.read, id_map_json)
+
+    # Metadata
+    directed = graph["directed"]
+    multigraph = graph["multigraph"]
     links = graph["links"]
+    nodes = graph["nodes"]
+    num_edges = directed ? length(links) : length(links) * 2
+    num_nodes = length(nodes)
+    num_graphs = length(graph["graph"])  # should be zero
+
+    # Edges
     s = map(link->link["source"], links)
     t = map(link->link["target"], links)
+    edge_index = directed ? [s t] : [s t; t s]  # matrix of (source, target) rows, not a vector of vectors; both directions for undirected graphs
-    edge_index = [s t; t s]
-
-    return edge_index
-end
-
-"""
-    labels()
-
-"""
-function labels()
-    data_path = datafile(DEPNAME, DATA[1])
-    graph = open(JSON3.read, data_path)
-    nodes = graph[:nodes]
-
-    class_map_file = datafile(DEPNAME, DATA[4])
-    id_map_file = datafile(DEPNAME, DATA[6])
-    class_map = open(JSON3.read, class_map_file)
-    id_map = open(JSON3.read, id_map_file)
-
+    # Labels
     node_keys = get.(nodes, "id", nothing)
     node_idx = [id_map[key] for key in node_keys]
-    class_idx = [class_map[key] for key in node_keys]
-    # Sort according to the node_idx
-    @assert length(node_idx) == length(class_idx)
-    return class_idx
-end
+    labels = [class_map[key] for key in node_keys]
+    # TODO: sort according to node_idx
+    @assert length(node_idx) == length(labels)
 
-"""
-    node_features()
-
-"""
-function node_features()
-    feat_path = datafile(DEPNAME, DATA[5])
+    # Features
     features = npzread(feat_path)
-    return features
-end
-
-"""
-    split()
-
-"""
-function split()
-    data_path = datafile(DEPNAME, DATA[1])
-    graph = open(JSON3.read, data_path)
-    nodes = graph["nodes"]
-
-    id_map_file = datafile(DEPNAME, DATA[6])
-    id_map = open(JSON3.read, id_map_file)
+    # Split
     test_mask = get.(nodes, "test", nothing)
     val_mask = get.(nodes, "val", nothing)
+    # A node should not be in both the test and the validation split
     @assert sum(val_mask .& test_mask) == 0
     train_mask = nor.(test_mask, val_mask)
-    train_nodes = nodes[train_mask]
-    val_nodes = nodes[val_mask]
-    test_nodes = nodes[test_mask]
-
-    train_keys = get.(train_nodes, "id", nothing)
-    test_keys = get.(test_nodes, "id", nothing)
-    val_keys = get.(val_nodes, "id", nothing)
+    train_idx = node_idx[train_mask]
+    test_idx = node_idx[test_mask]
+    val_idx = node_idx[val_mask]
 
-    train_idx = [id_map[key] for key in train_keys]
-    test_idx = [id_map[key] for key in test_keys]
-    val_idx = [id_map[key] for key in val_keys]
-
-    split_dict = Dict(
+    split = Dict(
         "train" => train_idx,
         "test" => test_idx,
         "val" => val_idx
     )
+
+    return (
+        directed=directed, multigraph=multigraph, num_graphs=num_graphs,
+        num_edges=num_edges, num_nodes=num_nodes, edge_index=edge_index,
+        node_labels=labels, node_features=features, split=split
+    )
 end
 
 end #module
\ No newline at end of file
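
The `edge_index` matrix returned by `dataset` pairs naturally with a sparse adjacency representation. Below is a minimal sketch, not part of the patch, assuming the node indices in the raw GraphSAGE files are 0-based (they are produced by Python tooling) and therefore need a one-position shift for Julia's 1-based arrays:

```julia
using SparseArrays
using MLDatasets: Reddit

data = Reddit.dataset()

# Rows of `edge_index` are (source, target) pairs. Assumed 0-based in the
# raw files, hence the `.+ 1` shift into Julia indexing.
s = data.edge_index[:, 1] .+ 1
t = data.edge_index[:, 2] .+ 1

# Boolean sparse adjacency matrix; duplicate edges collapse via `|`.
A = sparse(s, t, trues(length(s)), data.num_nodes, data.num_nodes, |)
```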
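
Similarly, the `split` indices can be used to restrict the edge list to training nodes, as in the inductive GraphSAGE setup. Another sketch under the same 0-based-index assumption:

```julia
using MLDatasets: Reddit

data = Reddit.dataset()

# Same 0-based assumption as above for the raw indices.
train_nodes = Set(data.split["train"] .+ 1)
s = data.edge_index[:, 1] .+ 1
t = data.edge_index[:, 2] .+ 1

# Keep only edges whose endpoints are both training nodes.
keep = [u in train_nodes && v in train_nodes for (u, v) in zip(s, t)]
train_edge_index = [s[keep] t[keep]]
```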