[DGL-Go] Inference for Graph Prediction Pipeline (dmlc#4157)
* Update * Update * Update * Update * Update * Update * Update * Update * Update * Update
Showing 8 changed files with 211 additions and 10 deletions.
@@ -1,2 +1,3 @@
from .nodepred import ApplyNodepredPipeline
from .nodepred_sample import ApplyNodepredNsPipeline
from .graphpred import ApplyGraphpredPipeline
@@ -0,0 +1 @@
from .gen import *
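Taken together, the two __init__.py hunks above wire the new pipeline into DGL-Go's registration-by-import scheme: importing the apply-pipeline package imports ApplyGraphpredPipeline, whose @ApplyPipelineFactory.register("graphpred") decorator (see the next file) records the class under its pipeline name. The snippet below is only a minimal sketch of that pattern under assumed names, not the actual ApplyPipelineFactory implementation from ...utils.factory:

# Illustrative sketch only; the real factory's name and internals are assumptions.
class PipelineRegistrySketch:
    _registry = {}

    @classmethod
    def register(cls, name):
        # Used as a class decorator: records the pipeline class under `name`
        # and returns it unchanged, so merely importing the module registers it.
        def wrapper(pipeline_cls):
            cls._registry[name] = pipeline_cls
            return pipeline_cls
        return wrapper

    @classmethod
    def get(cls, name):
        # Look up and instantiate a registered pipeline, e.g. "graphpred".
        return cls._registry[name]()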
@@ -0,0 +1,124 @@
import ruamel.yaml
import torch
import typer

from copy import deepcopy
from jinja2 import Template
from pathlib import Path
from pydantic import BaseModel, Field
from typing import Optional

from ...utils.factory import ApplyPipelineFactory, PipelineBase, DataFactory, GraphModelFactory
from ...utils.yaml_dump import deep_convert_dict, merge_comment

pipeline_comments = {
    "batch_size": "Graph batch size",
    "num_workers": "Number of workers for data loading",
    "save_path": "Directory to save the inference results"
}

class ApplyGraphpredPipelineCfg(BaseModel):
    batch_size: int = 32
    num_workers: int = 4
    save_path: str = "apply_results"

@ApplyPipelineFactory.register("graphpred")
class ApplyGraphpredPipeline(PipelineBase):
    def __init__(self):
        self.pipeline = {
            "name": "graphpred",
            "mode": "apply"
        }

    @classmethod
    def setup_user_cfg_cls(cls):
        from ...utils.enter_config import UserConfig
        class ApplyGraphPredUserConfig(UserConfig):
            data: DataFactory.filter("graphpred").get_pydantic_config() = Field(..., discriminator="name")
            general_pipeline: ApplyGraphpredPipelineCfg = ApplyGraphpredPipelineCfg()

        cls.user_cfg_cls = ApplyGraphPredUserConfig

    @property
    def user_cfg_cls(self):
        return self.__class__.user_cfg_cls

    def get_cfg_func(self):
        def config(
            data: DataFactory.filter("graphpred").get_dataset_enum() = typer.Option(None, help="input data name"),
            cfg: Optional[str] = typer.Option(None, help="output configuration file path"),
            cpt: str = typer.Option(..., help="input checkpoint file path")
        ):
            # Training configuration
            train_cfg = torch.load(cpt)["cfg"]
            if data is None:
                print("data is not specified, use the training dataset")
                data = train_cfg["data_name"]
            else:
                data = data.name
            if cfg is None:
                cfg = "_".join(["apply", "graphpred", data, train_cfg["model_name"]]) + ".yaml"

            self.__class__.setup_user_cfg_cls()
            generated_cfg = {
                "pipeline_name": self.pipeline["name"],
                "pipeline_mode": self.pipeline["mode"],
                "device": train_cfg["device"],
                "data": {"name": data},
                "cpt_path": cpt,
                "general_pipeline": {"batch_size": train_cfg["general_pipeline"]["eval_batch_size"],
                                     "num_workers": train_cfg["general_pipeline"]["num_workers"]}
            }
            output_cfg = self.user_cfg_cls(**generated_cfg).dict()
            output_cfg = deep_convert_dict(output_cfg)
            # Not applicable for inference
            output_cfg['data'].pop('split_ratio')
            comment_dict = {
                "device": "Torch device name, e.g., cpu or cuda or cuda:0",
                "cpt_path": "Path to the checkpoint file",
                "general_pipeline": pipeline_comments
            }
            comment_dict = merge_comment(output_cfg, comment_dict)

            yaml = ruamel.yaml.YAML()
            yaml.dump(comment_dict, Path(cfg).open("w"))
            print("Configuration file is generated at {}".format(Path(cfg).absolute()))

        return config

    @classmethod
    def gen_script(cls, user_cfg_dict):
        # Check validation
        cls.setup_user_cfg_cls()
        cls.user_cfg_cls(**user_cfg_dict)

        # Training configuration
        train_cfg = torch.load(user_cfg_dict["cpt_path"])["cfg"]

        # Dict for code rendering
        render_cfg = deepcopy(user_cfg_dict)
        model_name = train_cfg["model_name"]
        model_code = GraphModelFactory.get_source_code(model_name)
        render_cfg["model_code"] = model_code
        render_cfg["model_class_name"] = GraphModelFactory.get_model_class_name(model_name)
        render_cfg.update(DataFactory.get_generated_code_dict(user_cfg_dict["data"]["name"]))

        # Dict for defining cfg in the rendered code
        generated_user_cfg = deepcopy(user_cfg_dict)
        generated_user_cfg.pop("pipeline_name")
        generated_user_cfg.pop("pipeline_mode")
        # model arch configuration
        generated_user_cfg["model"] = train_cfg["model"]

        render_cfg["user_cfg_str"] = f"cfg = {str(generated_user_cfg)}"
        render_cfg["user_cfg"] = user_cfg_dict

        file_current_dir = Path(__file__).resolve().parent
        with open(file_current_dir / "graphpred.jinja-py", "r") as f:
            template = Template(f.read())

        return template.render(**render_cfg)

    @staticmethod
    def get_description() -> str:
        return "Graph classification pipeline for inference on binary classification"
@@ -0,0 +1,77 @@
import torch
import os
import csv

from tqdm import tqdm
from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader
{{ data_import_code }}

{{ model_code }}

def infer(device, loader, model):
    model = model.to(device)
    model.eval()
    all_pred = []

    with torch.no_grad():
        for _, (g, labels) in enumerate(tqdm(loader, desc="Iteration")):
            g = g.to(device)
            node_feat = g.ndata['feat']
            edge_feat = g.edata['feat']
            pred = model(g, node_feat, edge_feat)
            pred = (pred.sigmoid() >= 0.5).long()
            all_pred.append(pred)

    return torch.cat(all_pred, dim=0)

def main():
    {{ user_cfg_str }}

    device = cfg['device']
    if not torch.cuda.is_available():
        device = 'cpu'
    pipeline_cfg = cfg['general_pipeline']

    # load data
    data = AsGraphPredDataset({{data_initialize_code}})
    data_loader = GraphDataLoader(data, batch_size=pipeline_cfg['batch_size'],
                                  num_workers=pipeline_cfg['num_workers'], shuffle=False)

    # validation
    train_data_name = cfg['model']['data_info']['name']
    infer_data_name = cfg['data']['name']
    if train_data_name.startswith('ogbg-mol'):
        assert infer_data_name.startswith('ogbg-mol'), 'Expect the inference data name to start \
            with ogbg-mol, got {}'.format(infer_data_name)
    else:
        assert train_data_name == infer_data_name, 'Expect the training and inference data to \
            have the same name, got {} and {}'.format(train_data_name, infer_data_name)
    model_node_feat_size = cfg['model']['data_info']['node_feat_size']
    model_edge_feat_size = cfg['model']['data_info']['edge_feat_size']
    data_node_feat_size = data.node_feat_size
    data_edge_feat_size = data.edge_feat_size
    assert model_node_feat_size == data_node_feat_size, 'Expect the training data and inference \
        data to have the same number of input node features, got {:d} and {:d}'.format(model_node_feat_size, data_node_feat_size)
    assert model_edge_feat_size == data_edge_feat_size, 'Expect the training data and inference \
        data to have the same number of input edge features, got {:d} and {:d}'.format(model_edge_feat_size, data_edge_feat_size)

    model = {{ model_class_name }}(**cfg['model'])
    model.load_state_dict(torch.load(cfg['cpt_path'], map_location='cpu')['model'])
    pred = infer(device, data_loader, model).detach().cpu()

    # Dump the results
    os.makedirs(cfg['general_pipeline']["save_path"])
    file_path = os.path.join(cfg['general_pipeline']["save_path"], 'output.csv')
    header = ['graph id']
    header.extend(['task_{:d}'.format(i) for i in range(cfg['model']['data_info']['out_size'])])
    with open(file_path, 'w') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows([
            [i] + pred[i].tolist() for i in range(len(pred))
        ])
    print('Saved inference results to {}'.format(file_path))

if __name__ == '__main__':
    main()
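The script rendered from the graphpred.jinja-py template above writes its predictions to <save_path>/output.csv, with a "graph id" column followed by one task_i column per output dimension; because predictions are thresholded at 0.5 after sigmoid(), every entry is a 0/1 label. A minimal sketch for reading those results back, assuming the default save_path of "apply_results":

import csv

# Read back the predictions written by the generated inference script.
# The path assumes the default save_path "apply_results"; adjust as needed.
with open("apply_results/output.csv") as f:
    reader = csv.reader(f)
    header = next(reader)                   # ['graph id', 'task_0', 'task_1', ...]
    for row in reader:
        graph_id = int(row[0])              # index of the graph in the dataset
        labels = [int(x) for x in row[1:]]  # 0/1 prediction per task
        print(graph_id, labels)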