[Model] Add HGAT #213

Open · wants to merge 3 commits into base: main
183 changes: 183 additions & 0 deletions examples/hgat/hgat_trainer.py
@@ -0,0 +1,183 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-

# @Time    : 2022/04/16
# @Author  : Jingyu Huang
# @FileName: hgat_trainer.py
import os
# Select the visible GPU and the TensorLayerX backend before tensorlayerx is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
os.environ['TL_BACKEND'] = 'tensorflow'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
import argparse

import numpy as np
import tensorlayerx as tlx
from tensorlayerx.model import TrainOneStep, WithLoss

from gammagl.datasets import AGNews, IMDB, OHSUMED, Twitter
from gammagl.models import HGATModel
from gammagl.utils import mask_to_index

class SemiSpvzLoss(WithLoss):
    """Wrap the backbone network and the loss function so that ``TrainOneStep``
    computes the semi-supervised loss on the training nodes only."""

    def __init__(self, net, loss_fn):
        super(SemiSpvzLoss, self).__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, y, node_type):
        # Forward over the whole heterogeneous graph, then keep the logits
        # of the target node type at the training indices.
        logits = self.backbone_network(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
        train_logits = tlx.gather(logits[node_type], data['train_idx'])
        train_y = tlx.gather(data['y'], data['train_idx'])
        loss = self._loss_fn(train_logits, train_y)
        return loss


def calculate_acc(logits, y, metrics):
    """
    Args:
        logits: node logits
        y: node labels
        metrics: tensorlayerx.metrics

    Returns:
        the metric result (accuracy)
    """
    metrics.update(logits, y)
    rst = metrics.result()
    metrics.reset()
    return rst


def main(args):
    # Each supported dataset is heterogeneous; pick the node type that
    # carries the labels and the train/val/test masks.
    if args.dataset == "IMDB":
        dataset = IMDB(args.dataset_path)
        node_type = 'movie'
    elif args.dataset == "agnews":
        dataset = AGNews(args.dataset_path)
        node_type = 'text'
    elif args.dataset == "ohsumed":
        dataset = OHSUMED(args.dataset_path)
        node_type = 'documents'
    elif args.dataset == "twitter":
        dataset = Twitter(args.dataset_path)
        node_type = 'twitter'
    else:
        raise ValueError(f"unknown dataset: {args.dataset}")
    graph = dataset[0]
    print(graph)
    y = graph[node_type].y

    # For the MindSpore backend, node indices (not boolean masks) must be passed.
    train_idx = mask_to_index(graph[node_type].train_mask)
    test_idx = mask_to_index(graph[node_type].test_mask)
    val_idx = mask_to_index(graph[node_type].val_mask)

    # Per-node-type input feature dimensions and node counts.
    in_channel = {}
    num_nodes_dict = {}
    for ntype in graph.metadata()[0]:
        in_channel[ntype] = graph.x_dict[ntype].shape[1]
        num_nodes_dict[ntype] = graph.x_dict[ntype].shape[0]

    net = HGATModel(
        in_channels=in_channel,
        out_channels=len(np.unique(tlx.convert_to_numpy(y))),
        metadata=graph.metadata(),
        drop_rate=args.drop_rate,
        hidden_channels=args.hidden_dim,
        name='hgat',
    )


    optimizer = tlx.optimizers.Adam(lr=args.lr, weight_decay=args.l2_coef)
    metrics = tlx.metrics.Accuracy()
    train_weights = net.trainable_weights

    loss_func = tlx.losses.softmax_cross_entropy_with_logits
    semi_spvz_loss = SemiSpvzLoss(net, loss_func)
    train_one_step = TrainOneStep(semi_spvz_loss, optimizer, train_weights)

    data = {
        "x_dict": graph.x_dict,
        "y": y,
        "edge_index_dict": graph.edge_index_dict,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "num_nodes_dict": num_nodes_dict,
    }
    best_val_acc = 0

    for epoch in range(args.n_epoch):
        net.set_train()
        train_loss = train_one_step(data, y, node_type)
        net.set_eval()

        logits = net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
        val_logits = tlx.gather(logits[node_type], data['val_idx'])
        val_y = tlx.gather(data['y'], data['val_idx'])
        val_acc = calculate_acc(val_logits, val_y, metrics)

        print("Epoch [{:0>3d}] ".format(epoch + 1)
              + " train_loss: {:.4f}".format(train_loss.item())
              + " val_acc: {:.4f}".format(val_acc))

        # save the best model on the validation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            net.save_weights(args.best_model_path + net.name + ".npz", format='npz_dict')

    net.load_weights(args.best_model_path + net.name + ".npz", format='npz_dict')
    net.set_eval()
    logits = net(data['x_dict'], data['edge_index_dict'], data['num_nodes_dict'])
    test_logits = tlx.gather(logits[node_type], data['test_idx'])
    test_y = tlx.gather(data['y'], data['test_idx'])
    test_acc = calculate_acc(test_logits, test_y, metrics)
    print("Test acc: {:.4f}".format(test_acc))


if __name__ == '__main__':
    # parameter settings
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.005, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs")
    parser.add_argument("--hidden_dim", type=int, default=64, help="dimension of hidden layers")
    parser.add_argument("--l2_coef", type=float, default=1e-3, help="l2 loss coefficient")
    parser.add_argument("--heads", type=int, default=8, help="number of heads for stabilization")
    parser.add_argument("--drop_rate", type=float, default=0.6, help="drop rate")
    parser.add_argument("--gpu", type=int, default=0, help="gpu id")
    parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset")
    parser.add_argument('--dataset', type=str, default='IMDB', help="dataset to use: IMDB, agnews, ohsumed, or twitter")
    parser.add_argument("--best_model_path", type=str, default=r'./', help="path to save the best model")

    args = parser.parse_args()
    if args.gpu >= 0:
        tlx.set_device("GPU", args.gpu)
    else:
        tlx.set_device("CPU")

    main(args)
29 changes: 29 additions & 0 deletions examples/hgat/readme.md
@@ -0,0 +1,29 @@
# Heterogeneous Graph Attention Network (HGAT)

This is an implementation of `HGAT` for semi-supervised classification on heterogeneous graphs.

- Paper link: [https://aclanthology.org/D19-1488/](https://aclanthology.org/D19-1488/)
- Author's code repo: [https://github.com/BUPT-GAMMA/HGAT](https://github.com/BUPT-GAMMA/HGAT). Note that the original code for the paper is implemented with TensorFlow.

## Usage

Run `python hgat_trainer.py` to reproduce HGAT on IMDB; pass `--dataset agnews`, `--dataset ohsumed`, or `--dataset twitter` to train on the other datasets. The snippet below sketches the model API in code.
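The following is a minimal sketch of how `HGATModel` is constructed and called, mirroring `hgat_trainer.py`; the dataset path and hyperparameters are illustrative, not tuned values.

```python
import tensorlayerx as tlx
from gammagl.datasets import IMDB
from gammagl.models import HGATModel

graph = IMDB('')[0]  # heterogeneous graph; path '' is a placeholder
node_types = graph.metadata()[0]

# Per-node-type input dimensions and node counts, as built in hgat_trainer.py.
in_channels = {nt: graph.x_dict[nt].shape[1] for nt in node_types}
num_nodes_dict = {nt: graph.x_dict[nt].shape[0] for nt in node_types}

model = HGATModel(
    in_channels=in_channels,
    out_channels=3,            # IMDB has three movie classes
    metadata=graph.metadata(),
    drop_rate=0.6,
    hidden_channels=64,
    name='hgat',
)

# The forward pass returns one logits tensor per node type.
logits = model(graph.x_dict, graph.edge_index_dict, num_nodes_dict)
print(tlx.get_tensor_shape(logits['movie']))  # (num_movies, 3)
```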



## Performance

| Dataset | Paper (80% training) | Our (tf) | Our (th) | Our (pd) |
| ------- | -------------------- | -------- | -------- | -------- |
| AGNews  | 72.10                | 63.80    |          |          |
| Ohsumed | 42.68                | 25.82    |          |          |
| Twitter | 63.21                | 61.06    |          |          |
| IMDB    |                      | 57.71    |          |          |

```bash
TL_BACKEND="tensorflow" python3 hgat_trainer.py --n_epoch 100 --lr 0.01 --l2_coef 0.0001 --drop_rate 0.8
```
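
For the other datasets, commands along the following lines should work; these are untested sketches that reuse the trainer's existing flags with its default hyperparameters, not tuned settings.

```bash
TL_BACKEND="tensorflow" python3 hgat_trainer.py --dataset agnews --n_epoch 100 --lr 0.005
TL_BACKEND="tensorflow" python3 hgat_trainer.py --dataset ohsumed --n_epoch 100 --lr 0.005
TL_BACKEND="tensorflow" python3 hgat_trainer.py --dataset twitter --n_epoch 100 --lr 0.005
```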
8 changes: 7 additions & 1 deletion gammagl/datasets/__init__.py
@@ -22,6 +22,9 @@
from .molecule_net import MoleculeNet
from .acm4heco import ACM4HeCo
from .yelp import Yelp
from .agnews import AGNews
from .ohsumed import OHSUMED
from .twitter import Twitter

__all__ = [
    'ACM4HeCo',
@@ -46,7 +49,10 @@
    'WikiCS',
    'MoleculeNet',
    'NGSIM_US_101',
    'Yelp'
    'Yelp',
    'AGNews',
    'OHSUMED',
    'Twitter'
]

classes = __all__
81 changes: 81 additions & 0 deletions gammagl/datasets/agnews.py
@@ -0,0 +1,81 @@
import os
import os.path as osp
from itertools import product
from typing import Callable, List, Optional

import numpy as np
import scipy.sparse as sp
import tensorlayerx as tlx

from gammagl.data import (HeteroGraph, InMemoryDataset, download_url,
                          extract_zip)

class AGNews(InMemoryDataset):
    r"""Heterogeneous AGNews short-text classification dataset, processed for
    use in GNN models. It contains ``text``, ``topic``, and ``entity`` nodes;
    labels and the train/val/test masks live on the ``text`` nodes.
    """

    url = 'https://www.dropbox.com/scl/fi/m809k1xdqzf0rhdmb83jf/agnews.zip?rlkey=wrz4by7f4tvtsdte2scuiec5k&st=s3ty36oi&dl=1'

    def __init__(self, root: str = None, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None, force_reload: bool = False):
        super().__init__(root, transform, pre_transform, force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return [
            'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz',
            'labels.npy', 'train_val_test_idx.npz'
        ]

    @property
    def processed_file_names(self) -> str:
        return tlx.BACKEND + 'data.pt'

    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.remove(path)

    def process(self):
        data = HeteroGraph()

        # Node features: one sparse feature matrix per node type.
        node_types = ['text', 'topic', 'entity']
        for i, node_type in enumerate(node_types):
            x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz'))
            data[node_type].x = tlx.convert_to_tensor(x.todense(), dtype=tlx.float32)

        # Labels are stored one-hot; keep the class index. Only 'text' nodes are labeled.
        y = np.load(osp.join(self.raw_dir, 'labels.npy'))
        y = np.argmax(y, axis=1)
        data['text'].y = tlx.convert_to_tensor(y, dtype=tlx.int64)

        split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz'))
        for name in ['train', 'val', 'test']:
            idx = split[f'{name}_idx']
            mask = np.zeros(data['text'].num_nodes, dtype=np.bool_)
            mask[idx] = True
            data['text'][f'{name}_mask'] = tlx.convert_to_tensor(mask, dtype=tlx.bool)

        # adjM.npz stores one global adjacency over all nodes; record the
        # row/column offset range of each node type inside it.
        s = {}
        N_text = data['text'].num_nodes
        N_topic = data['topic'].num_nodes
        N_entity = data['entity'].num_nodes
        s['text'] = (0, N_text)
        s['topic'] = (N_text, N_text + N_topic)
        s['entity'] = (N_text + N_topic, N_text + N_topic + N_entity)

        # Slice the global adjacency into per-(src, dst) blocks and keep the
        # non-empty blocks as edge types.
        A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')).tocsr()
        for src, dst in product(node_types, node_types):
            A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo()
            if A_sub.nnz > 0:
                row = tlx.convert_to_tensor(A_sub.row, dtype=tlx.int64)
                col = tlx.convert_to_tensor(A_sub.col, dtype=tlx.int64)
                data[src, dst].edge_index = tlx.stack([row, col], axis=0)

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'