Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
AI: GraphSAGE final fixes for removing original dataset in distribute…
Browse files Browse the repository at this point in the history
…d training (#1615)
  • Loading branch information
jerrychenhf authored Jul 1, 2023
1 parent 9cfd129 commit fdf3c61
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def get_eids_mask(g, relations, mask_name, reverse_etypes=None):

def get_eids_from_mask(g, relations, mask_name, reverse_etypes=None):
eids_mask = get_eids_mask(g, relations, mask_name, reverse_etypes)
# eids_mask is dict with DistTensor, convert to a torch Tensor by copying
eids_mask = {k: v[0: v.shape[0]] for k, v in eids_mask.items()}
return {k: torch.nonzero(v, as_tuple=False).squeeze() for k, v in eids_mask.items()}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
homogeneous.transductive.model import TransductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.\
homogeneous.inductive.model import InductiveGraphSAGEModel
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.utils import get_common_node_features, \
get_common_edge_features
from cloudtik.runtime.ai.modeling.graph_modeling.graph_sage.modeling.model.utils \
import get_common_node_features, get_common_edge_features, get_in_feats_of_feature


def predict(dataset_dir, model_file,
Expand Down Expand Up @@ -52,7 +52,7 @@ def predict(dataset_dir, model_file,
print("Predicting with an inductive model on homogeneous graph")
in_feats = 1
if node_feature:
in_feats = graph.ndata[node_feature].shape[1]
in_feats = get_in_feats_of_feature(graph, node_feature)
model = InductiveGraphSAGEModel(
in_feats, num_hidden, num_layers,
node_feature=node_feature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,9 @@ def get_common_features(g, node_or_edge):
def get_in_feats_of_feature(g, node_feature):
in_feats = 1
if node_feature:
# Use the first node type
ntype = g.ntypes[0]
# The feature dimension must be the same for all the nodes
# for the time being
feature = g.ndata[node_feature]
if not feature:
raise RuntimeError("The graph has no node feature: ".format(
node_feature))
if isinstance(feature, dict):
any_value = next(iter(feature.values()))
else:
any_value = feature
in_feats = any_value.shape[1]
in_feats = g.nodes[ntype].data[node_feature].shape[1]
return in_feats

0 comments on commit fdf3c61

Please sign in to comment.