Skip to content

Commit

Permalink
add yelp dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
gyzhou2000 committed Jun 5, 2024
1 parent 64d3909 commit 5c7ccfe
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 20 deletions.
20 changes: 2 additions & 18 deletions examples/gnnlfhf/gnnlfhf_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TL_BACKEND'] = 'torch'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import sys
import argparse
Expand Down Expand Up @@ -116,8 +116,6 @@ def main(args):
test_acc = calculate_acc(test_logits, test_y, metrics)
print("Test acc: {:.4f}".format(test_acc))

return test_acc


if __name__ == '__main__':
# parameters setting
Expand Down Expand Up @@ -145,18 +143,4 @@ def main(args):
else:
tlx.set_device("CPU")

import numpy as np

number = []
for i in range(5):
acc = main(args)

number.append(acc)

print("实验结果:")
print(np.mean(number))
print(np.std(number))
print(number)

# main(args)

main(args)
3 changes: 2 additions & 1 deletion gammagl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .heterograph import HeteroGraph
from .dataset import Dataset
from .batch import BatchGraph
from .download import download_url
from .download import download_url, download_google_url
from .in_memory_dataset import InMemoryDataset
from .extract import extract_zip, extract_tar
from .utils import global_config_init
Expand All @@ -14,6 +14,7 @@
'HeteroGraph',
'Dataset',
'download_url',
'download_google_url',
'InMemoryDataset',
'extract_zip',
'extract_tar',
Expand Down
7 changes: 7 additions & 0 deletions gammagl/data/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ def download_url(url: str, folder: str, log: bool = True,
pbar.update(chunk_size)

return path


def download_google_url(id: str, folder: str,
                        filename: str, log: bool = True):
    r"""Download the file behind a Google Drive ID into *folder* as *filename*.

    Builds the direct-download URL for the given Drive *id* (``confirm=t``
    skips the large-file confirmation page) and delegates the actual
    transfer to :func:`download_url`.
    """
    base = 'https://drive.usercontent.google.com/download'
    url = '{}?id={}&confirm=t'.format(base, id)
    return download_url(url, folder, log, filename)
4 changes: 3 additions & 1 deletion gammagl/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .wikics import WikiCS
from .blogcatalog import BlogCatalog
from .molecule_net import MoleculeNet
from .yelp import Yelp

__all__ = [
'Amazon',
Expand All @@ -40,7 +41,8 @@
'AMiner',
'PolBlogs',
'WikiCS',
'MoleculeNet'
'MoleculeNet',
'Yelp'
]

classes = __all__
115 changes: 115 additions & 0 deletions gammagl/datasets/yelp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import os.path as osp
from typing import Callable, List, Optional

import numpy as np
import scipy.sparse as sp
import tensorlayerx as tlx

from gammagl.data import Graph, InMemoryDataset, download_google_url


class Yelp(InMemoryDataset):
    r"""The Yelp dataset from the `"GraphSAINT: Graph Sampling Based
    Inductive Learning Method" <https://arxiv.org/abs/1907.04931>`_ paper,
    containing customer reviewers and their friendship.

    Parameters
    ----------
    root: str, optional
        Root directory where the dataset should be saved.
    transform: callable, optional
        A function/transform that takes in an
        :obj:`gammagl.data.Graph` object and returns a transformed
        version. The data object will be transformed before every access.
        (default: :obj:`None`)
    pre_transform: callable, optional
        A function/transform that takes in
        an :obj:`gammagl.data.Graph` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk. (default: :obj:`None`)
    force_reload: bool, optional
        Whether to re-process the dataset.
        (default: :obj:`False`)

    Tip
    ---
    .. list-table::
        :widths: 10 10 10 10
        :header-rows: 1

        * - #nodes
          - #edges
          - #features
          - #tasks
        * - 716,847
          - 13,954,819
          - 300
          - 100
    """

    # Google Drive file IDs of the raw GraphSAINT release of Yelp.
    adj_full_id = '1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA'
    feats_id = '1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM'
    class_map_id = '1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU'
    role_id = '1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_'

    def __init__(
        self,
        root: str = None,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(root, transform, pre_transform,
                         force_reload=force_reload)
        self.data, self.slices = self.load_data(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        # The four files published with the GraphSAINT paper.
        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']

    @property
    def processed_file_names(self) -> str:
        return 'data.pt'

    def download(self) -> None:
        """Fetch the four raw files from Google Drive into ``self.raw_dir``."""
        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')
        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')
        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')
        download_google_url(self.role_id, self.raw_dir, 'role.json')

    def process(self) -> None:
        """Convert the raw GraphSAINT files into a single :obj:`Graph`.

        Builds the COO edge index from the CSR adjacency, loads node
        features and multi-hot labels, and turns the ``role.json`` index
        lists into boolean train/val/test masks.
        """
        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))
        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])
        adj = adj.tocoo()
        row = tlx.convert_to_tensor(adj.row, dtype=tlx.int64)
        col = tlx.convert_to_tensor(adj.col, dtype=tlx.int64)
        edge_index = tlx.stack([row, col], axis=0)

        x = np.load(osp.join(self.raw_dir, 'feats.npy'))
        # Take the node count from the NumPy array: tensor.size(0) is a
        # torch-only method and breaks under the other tlx backends.
        num_nodes = x.shape[0]
        x = tlx.convert_to_tensor(x, dtype=tlx.float32)

        ys = [-1] * num_nodes
        with open(osp.join(self.raw_dir, 'class_map.json')) as f:
            class_map = json.load(f)
        for key, item in class_map.items():
            ys[int(key)] = item
        y = tlx.convert_to_tensor(ys)

        with open(osp.join(self.raw_dir, 'role.json')) as f:
            role = json.load(f)

        def _index_to_mask(indices) -> 'tlx.Tensor':
            # Build the mask in NumPy first: in-place fancy-index assignment
            # on tensors is not supported by every tlx backend.
            mask = np.zeros(num_nodes, dtype=bool)
            mask[indices] = True
            return tlx.convert_to_tensor(mask, dtype=tlx.bool)

        train_mask = _index_to_mask(role['tr'])
        val_mask = _index_to_mask(role['va'])
        test_mask = _index_to_mask(role['te'])

        data = Graph(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                     val_mask=val_mask, test_mask=test_mask)

        data = data if self.pre_transform is None else self.pre_transform(data)

        self.save_data(self.collate([data]), self.processed_paths[0])
15 changes: 15 additions & 0 deletions tests/datasets/test_yelp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# os.environ['TL_BACKEND'] = 'paddle'
import tensorlayerx as tlx
from gammagl.datasets import Yelp

# NOTE(review): this function is not named `test_yelp`, so pytest will NOT
# collect or run it. That may be deliberate — running it downloads the full
# Yelp dataset (~716k nodes / ~14M edges) — but confirm against the project's
# convention for dataset tests before relying on it in CI.
def yelp():
    """Smoke-check the Yelp dataset against the sizes published in GraphSAINT."""
    dataset = Yelp()
    graph = dataset[0]
    # A single-graph dataset.
    assert len(dataset) == 1
    # 100 tasks / 300 node features, per the dataset release.
    assert dataset.num_classes == 100
    assert dataset.num_node_features == 300
    # Edge index is (2, num_edges); x is (num_nodes, num_features).
    assert graph.edge_index.shape[1] == 13954819
    assert graph.x.shape[0] == 716847

0 comments on commit 5c7ccfe

Please sign in to comment.