[DGL-Go] Inference for Node Prediction Pipeline (full & ns) (dmlc#4095)
Showing 43 changed files with 719 additions and 186 deletions.
@@ -0,0 +1,2 @@
from .nodepred import ApplyNodepredPipeline
from .nodepred_sample import ApplyNodepredNsPipeline
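This package `__init__.py` is imported purely for its side effects: importing each submodule runs the `@ApplyPipelineFactory.register(...)` decorator, which records the apply pipeline class under its name so the CLI can look it up later. The factory itself lives in `utils/factory.py` and is not part of this hunk; the sketch below only illustrates the general decorator-registry pattern, not the actual DGL-Go implementation.

from typing import Callable, Dict, Type

class PipelineRegistry:
    """Minimal decorator-based registry (illustrative only, not DGL-Go's factory)."""
    _registry: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        def decorator(pipeline_cls: Type) -> Type:
            cls._registry[name] = pipeline_cls  # record the class under its pipeline name
            return pipeline_cls
        return decorator

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._registry[name]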
@@ -0,0 +1 @@
from .gen import *
@@ -0,0 +1,113 @@
import ruamel.yaml
import torch
import typer

from copy import deepcopy
from jinja2 import Template
from pathlib import Path
from pydantic import Field
from typing import Optional

from ...utils.factory import ApplyPipelineFactory, PipelineBase, DataFactory, NodeModelFactory
from ...utils.yaml_dump import deep_convert_dict, merge_comment

@ApplyPipelineFactory.register("nodepred")
class ApplyNodepredPipeline(PipelineBase):

    def __init__(self):
        self.pipeline = {
            "name": "nodepred",
            "mode": "apply"
        }

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig
        class ApplyNodePredUserConfig(UserConfig):
            data: DataFactory.filter("nodepred").get_pydantic_config() = Field(..., discriminator="name")

        cls.user_cfg_cls = ApplyNodePredUserConfig

    @property
    def user_cfg_cls(self):
        return self.__class__.user_cfg_cls

    def get_cfg_func(self):
        def config(
            data: DataFactory.filter("nodepred").get_dataset_enum() = typer.Option(None, help="input data name"),
            cfg: Optional[str] = typer.Option(None, help="output configuration file path"),
            cpt: str = typer.Option(..., help="input checkpoint file path")
        ):
            # Training configuration
            train_cfg = torch.load(cpt)["cfg"]
            if data is None:
                print("data is not specified, use the training dataset")
                data = train_cfg["data_name"]
            else:
                data = data.name
            if cfg is None:
                cfg = "_".join(["apply", "nodepred", data, train_cfg["model_name"]]) + ".yaml"

            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
                "pipeline_name": self.pipeline["name"],
                "pipeline_mode": self.pipeline["mode"],
                "device": train_cfg["device"],
                "data": {"name": data},
                "cpt_path": cpt,
                "general_pipeline": {"save_path": "apply_results"}
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            # Not applicable for inference
            output_cfg['data'].pop('split_ratio')
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "cpt_path": "Path to the checkpoint file",
                "general_pipeline": {"save_path": "Directory to save the inference results"}
            }
            comment_dict = merge_comment(output_cfg, comment_dict)

            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print("Configuration file is generated at {}".format(Path(cfg).absolute()))

        return config

    @classmethod
    def gen_script(cls, user_cfg_dict):
        # Check validation
        cls.setup_user_cfg_cls()
        cls.user_cfg_cls(**user_cfg_dict)

        # Training configuration
        train_cfg = torch.load(user_cfg_dict["cpt_path"])["cfg"]

        # Dict for code rendering
        render_cfg = deepcopy(user_cfg_dict)
        model_name = train_cfg["model_name"]
        model_code = NodeModelFactory.get_source_code(model_name)
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = NodeModelFactory.get_model_class_name(model_name)
        render_cfg.update(DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"]))

        # Dict for defining cfg in the rendered code
        generated_user_cfg = deepcopy(user_cfg_dict)
        generated_user_cfg["data"].pop("name")
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        # model arch configuration
        generated_user_cfg["model"] = train_cfg["model"]

        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict

        file_current_dir = Path(__file__).resolve().parent
        with open(file_current_dir / "nodepred.jinja-py", "r") as f:
            template = Template(f.read())

        return template.render(**render_cfg)

    @staticmethod
    def get_description() -> str:
        return "Node classification pipeline for inference"
@@ -0,0 +1,67 @@
import torch
import dgl
import os
import csv

from dgl.data import AsNodePredDataset
{{ data_import_code }}

{{ model_code }}

def infer(device, data, model):
    g = data[0]  # Only infer on the first graph
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    g = g.to(device)

    node_feat = g.ndata.get('feat', None)
    edge_feat = g.edata.get('feat', None)

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        logits = model(g, node_feat, edge_feat)

    return logits

def main():
    {{ user_cfg_str }}

    device = cfg['device']
    if not torch.cuda.is_available():
        device = 'cpu'

    # load data
    data = AsNodePredDataset({{data_initialize_code}})
    # validation
    if cfg['model']['embed_size'] > 0:
        model_num_nodes = cfg['model']['data_info']['num_nodes']
        data_num_nodes = data[0].num_nodes()
        assert model_num_nodes == data_num_nodes, \
            'Training and inference need to be on the same dataset when node embeddings were learned from scratch'
    else:
        model_in_size = cfg['model']['data_info']['in_size']
        data_in_size = data[0].ndata['feat'].shape[1]
        assert model_in_size == data_in_size, \
            'Expect the training data and inference data to have the same number of input node ' \
            'features, got {:d} and {:d}'.format(model_in_size, data_in_size)

    model = {{ model_class_name }}(**cfg['model'])
    model.load_state_dict(torch.load(cfg['cpt_path'], map_location='cpu')['model'])
    logits = infer(device, data, model)
    pred = logits.argmax(dim=1).cpu()

    # Dump the results
    os.makedirs(cfg['general_pipeline']["save_path"])
    file_path = os.path.join(cfg['general_pipeline']["save_path"], 'output.csv')
    with open(file_path, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(['node id', 'predicted label'])
        writer.writerows([
            [i, pred[i].item()] for i in range(len(pred))
        ])
    print('Saved inference results to {}'.format(file_path))

if __name__ == '__main__':
    main()
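For concreteness, here is how the template placeholders above might expand for a hypothetical GraphSAGE checkpoint trained on Cora; every concrete value below is illustrative and not taken from this diff. `{{ user_cfg_str }}` becomes a literal `cfg = {...}` assignment, `{{ data_initialize_code }}` becomes the dataset constructor, and `{{ model_class_name }}` becomes the saved model class.

# Illustrative expansion of {{ user_cfg_str }} (all values hypothetical):
cfg = {
    'device': 'cuda:0',
    'data': {},                                      # 'name' is popped before rendering
    'cpt_path': 'results/run_0.pth',                 # hypothetical checkpoint path
    'general_pipeline': {'save_path': 'apply_results'},
    # model architecture copied from the checkpoint's training config
    'model': {
        'embed_size': -1,                            # <= 0 means input node features were used
        'hidden_size': 16,
        'data_info': {'in_size': 1433, 'out_size': 7, 'num_nodes': 2708},
    },
}
# {{ data_initialize_code }} might expand to a dataset constructor such as
# CoraGraphDataset(), and {{ model_class_name }}(**cfg['model']) then rebuilds
# the saved architecture before loading its state dict.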
@@ -0,0 +1 @@
from .gen import *