Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
AI: GraphSAGE transductive model for heterogeneous graph (#1607)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerrychenhf authored Jun 29, 2023
1 parent 2598ffd commit 35121db
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
heterogeneous.distributed.trainer import Trainer
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.inductive.distributed.model import DistInductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.heterogeneous.utils import \
get_in_feats_of_feature
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.transductive.distributed.model import DistTransductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.utils import get_in_feats_of_feature, get_node_types


def main(args):
Expand Down Expand Up @@ -64,7 +66,15 @@ def main(args):
in_feats, args.num_hidden, args.num_layers,
relations=args.relations, node_feature=args.node_feature)
else:
raise NotImplementedError("Transductive model on heterogeneous graph not supported")
# vocab_size is a dict of node type as key
node_types = get_node_types(graph, args.relations)
vocab_size = {k: graph.num_nodes(k) for k in node_types}
model = DistTransductiveGraphSAGEModel(
vocab_size, args.num_hidden, args.num_layers,
relations=args.relations)
model_eval = DistTransductiveGraphSAGEModel(
vocab_size, args.num_hidden, args.num_layers,
relations=args.relations)

trainer = Trainer(model, model_eval, args)
trainer.train(graph)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ def inference(self, g, x, device, batch_size):
for input_nodes, output_nodes, blocks in tqdm.tqdm(
dataloader, desc="Inference"
):
# Inductive: use node feature or id
# Heterogeneous: everything is a dict here
# input_nodes is a dict, output_nodes is dict
# input_block.srcdata[dgl.NID] is a dict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
heterogeneous.predictor import Predictor
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.inductive.model import InductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.heterogeneous.utils import \
get_in_feats_of_feature
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.transductive.model import TransductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model. \
heterogeneous.utils import get_in_feats_of_feature, get_node_types


def predict(dataset_dir, model_file,
Expand Down Expand Up @@ -55,7 +57,13 @@ def predict(dataset_dir, model_file,
in_feats, num_hidden, num_layers,
relations=relations, node_feature=node_feature)
else:
raise NotImplementedError("Transductive model on heterogeneous graph not supported")
print("Predicting with a transductive model on heterogeneous graph")
# vocab_size is a dict of node type as key
node_types = get_node_types(graph, relations)
vocab_size = {k: graph.num_nodes(k) for k in node_types}
model = TransductiveGraphSAGEModel(
vocab_size, num_hidden, num_layers,
relations=relations)

predictor = Predictor(
model, model_file=model_file,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
heterogeneous.trainer import Trainer
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.inductive.model import InductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.transductive.model import TransductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model. \
heterogeneous.utils import tensor_dict_shape, get_in_feats_of_feature
heterogeneous.utils import tensor_dict_shape, get_in_feats_of_feature, \
get_node_types


def main(args):
Expand Down Expand Up @@ -73,7 +76,12 @@ def main(args):
in_feats, args.num_hidden, args.num_layers,
relations=args.relations, node_feature=args.node_feature)
else:
raise NotImplementedError("Transductive model on heterogeneous graph not supported")
# vocab_size is a dict of node type as key
node_types = get_node_types(graph, args.relations)
vocab_size = {k: graph.num_nodes(k) for k in node_types}
model = TransductiveGraphSAGEModel(
vocab_size, args.num_hidden, args.num_layers,
relations=args.relations)

# train
trainer = Trainer(model, args)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Chen Haifeng
"""

from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.distributed.model import DistGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.transductive.model import TransductiveGraphSAGEModel


class DistTransductiveGraphSAGEModel(DistGraphSAGEModel, TransductiveGraphSAGEModel):
def __init__(self, vocab_size, hidden_size, num_layers, relations):
TransductiveGraphSAGEModel.__init__(
self, vocab_size, hidden_size, num_layers, relations)
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Author: Chen Haifeng
"""

import torch

from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
heterogeneous.model import GraphSAGEModel


class TransductiveGraphSAGEModel(GraphSAGEModel):
def __init__(self, vocab_size, hidden_size, num_layers, relations):
super().__init__(hidden_size, hidden_size, num_layers, relations)

# node embedding
# vocab_size is a dict of vocab_size of each node type in the relations
self.emb = torch.nn.ModuleDict(
{node_type: torch.nn.Embedding(
node_vocab_size, hidden_size) for node_type, node_vocab_size in vocab_size.items()})

def forward(self, pair_graph, neg_pair_graph, blocks, x):
h = {k: self.emb[k](v) for k, v in x.items()}
return super().forward(
pair_graph, neg_pair_graph, blocks, h)

def get_input_embeddings(self):
return {k: v.weight.data for k, v in self.emb}

def get_inputs(self, input_nodes, blocks):
return input_nodes

def get_inference_inputs(self, g):
return self.get_input_embeddings()

def get_encoder_inputs(self, input_nodes, blocks):
x = self.get_inputs(input_nodes, blocks)
return {k: self.emb[k].weight.data[v] for k, v in x}

0 comments on commit 35121db

Please sign in to comment.