diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a223b9c2c..d61941412 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,11 +21,6 @@ jobs:
         TEST_DEVICES: ""
       run: |
         source activate /home/admin/tf12_py2/
-        if [ ! -e "/tmp/easyrec_data_20220113.tar.gz" ]
-        then
-          wget https://easyrec.oss-cn-beijing.aliyuncs.com/data/easyrec_data_20220113.tar.gz -O /tmp/easyrec_data_20220113.tar.gz
-        fi
-        tar -zvxf /tmp/easyrec_data_20220113.tar.gz
        source scripts/ci_test.sh
    - name: LabelAndComment
      env:
diff --git a/docs/images/models/eges_1.png b/docs/images/models/eges_1.png
new file mode 100644
index 000000000..900f847ce
Binary files /dev/null and b/docs/images/models/eges_1.png differ
diff --git a/docs/images/models/eges_2.png b/docs/images/models/eges_2.png
new file mode 100644
index 000000000..a10de3c17
Binary files /dev/null and b/docs/images/models/eges_2.png differ
diff --git a/docs/source/models/eges.md b/docs/source/models/eges.md
new file mode 100644
index 000000000..455bb3adf
--- /dev/null
+++ b/docs/source/models/eges.md
@@ -0,0 +1,170 @@
+# EGES
+
+### Introduction
+
+A graph-based i2i recall model: random walks over the item graph generate node sequences, which are then trained with the skip-gram algorithm.
+![eges_1](../../images/models/eges_1.png)
+![eges_2](../../images/models/eges_2.png)
+
+### Configuration
+#### Input Configuration
+```protobuf
+graph_train_input {
+  node_inputs: 'data/test/graph/taobao_data/ad_feature_5k.csv'
+  edge_inputs: 'data/test/graph/taobao_data/graph_edges.txt'
+}
+
+graph_eval_input {
+  node_inputs: 'data/test/graph/taobao_data/ad_feature_5k.csv'
+  edge_inputs: 'data/test/graph/taobao_data/graph_edges.txt'
+}
+```
+- node_inputs: the nodes of the graph are items; this input provides each item's node id, weight (used in sampling) and feature (side info)
+  - example input:
+    ```
+    id:int64 weight:float feature:string
+    521512 1 521512,4282,173332,237,NULL,298.0
+    476210 1 476210,4292,418411,515,377957,249.0
+    646682 1 646682,7205,365036,676,321803,9.9
+    ...
+    ```
+- edge_inputs: the edges of the graph describe how often two items co-occur in the same session
+  - example input:
+    ```
+    src_id:int64 dst_id:int64 weight:float
+    565248 565248 100
+    565248 786433 2
+    565248 638980 20
+    ...
+    ```
+- on MaxCompute, node_inputs and edge_inputs take similar inputs, with each field stored in its own column
+  - the node table has 3 columns: id, weight, feature
+  - the edge table has 3 columns: src_id, dst_id, weight
+  - int64 maps to bigint
+  - float maps to double
+  - string maps to string
+
+#### Data Configuration
+```protobuf
+data_config {
+  input_fields {
+    input_name: 'adgroup_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'cate_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'campaign_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'customer'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'brand'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'price'
+    input_type: DOUBLE
+  }
+
+  graph_config {
+    random_walk_len: 10
+    window_size: 5
+    negative_num: 10
+    directed: true
+  }
+
+  batch_size: 64
+  num_epochs: 2
+  prefetch_size: 32
+  input_type: GraphInput
+}
+```
+
+- input_fields:
+  - input_name: input feature name, matching the column name of the odps table or the header of the csv file (without a header, fields are matched by position)
+  - input_type: data type, one of STRING, DOUBLE, INT32, INT64; defaults to STRING
+- graph_config: parameters for random walks on the graph (see the sketch after this list)
+  - random_walk_len: length of each random walk
+  - window_size: skip-gram window size
+  - negative_num: number of negative samples per positive sample
+  - directed: whether the graph is directed, defaults to false
+- batch_size: number of start nodes per batch of random walks
+- num_epochs: number of passes over the data
+- prefetch_size: number of batches to prefetch
+- input_type: input data format; for graph algorithms use GraphInput
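+
+How a walk is expanded into skip-gram training pairs can be illustrated with a
+small sketch (plain Python with a hypothetical helper name; illustrative only,
+not the EasyRec implementation):
+
+```python
+def gen_skip_gram_pairs(walk, window_size):
+  """Pair every node with its neighbors within window_size on either side."""
+  pairs = []
+  for i, src in enumerate(walk):
+    lo = max(0, i - window_size)
+    hi = min(len(walk), i + window_size + 1)
+    pairs.extend((src, walk[j]) for j in range(lo, hi) if j != i)
+  return pairs
+
+# one walk of random_walk_len=5 node ids, with window_size=2
+print(gen_skip_gram_pairs([1, 2, 3, 4, 5], 2))
+```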
+
+#### Feature Configuration
+```protobuf
+feature_config: {
+  features: {
+    input_names: 'adgroup_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'cate_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10000
+  }
+  ...
+  features: {
+    input_names: 'brand'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'price'
+    feature_type: RawFeature
+  }
+}
+```
+- features.input_names: the feature's input, matching data_config.input_fields.input_name
+
+#### Model Configuration
+```protobuf
+model_config:{
+  model_class: "EGES"
+
+  feature_groups: {
+    group_name: "item"
+    feature_names: 'adgroup_id'
+    feature_names: 'cate_id'
+    feature_names: 'campaign_id'
+    feature_names: 'customer'
+    feature_names: 'brand'
+    feature_names: 'price'
+    wide_deep:DEEP
+  }
+  eges {
+    dnn {
+      hidden_units: [256, 128, 64, 32]
+    }
+    l2_regularization: 1e-6
+  }
+  loss_type: SOFTMAX_CROSS_ENTROPY
+  embedding_regularization: 0.0
+
+  group_as_scope: true
+}
+```
+- model_class: 'EGES'
+- feature_groups: feature group; exactly one group is required, and its group_name must be 'item'
+  - feature_names: matching feature_config.features.input_names[0] (or feature_name, if set)
+- eges: dnn is the MLP that transforms the item features
+- loss_type: SOFTMAX_CROSS_ENTROPY, since training uses negative sampling (a sketch of this loss follows this list)
+- group_as_scope: use group_name as the scope_name of embedding and other variables; recommended to set to true
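+
+Under this loss, each source item embedding is scored against its positive
+target plus `negative_num` sampled negatives, with the positive always at
+index 0 of the target axis. A rough NumPy sketch (shapes are hypothetical,
+for illustration only):
+
+```python
+import numpy as np
+
+src = np.random.randn(2, 8)         # [batch, embed_dim]
+targets = np.random.randn(2, 4, 8)  # [batch, 1 + negative_num, embed_dim]
+
+logits = np.einsum('be,bne->bn', src, targets)   # dot-product scores
+logits -= logits.max(axis=1, keepdims=True)      # numerical stability
+# the positive target sits at index 0, so the label is always 0
+log_prob = logits[:, 0] - np.log(np.exp(logits).sum(axis=1))
+print('loss =', -log_prob.mean())
+```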
+
+### Example Config
+[EGES_demo.config](https://easyrec.oss-cn-beijing.aliyuncs.com/config/eges_on_taobao.config)
+
+### Reference Paper
+[EGES.pdf](https://arxiv.org/pdf/1803.02349.pdf)
diff --git a/docs/source/models/recall.rst b/docs/source/models/recall.rst
index 2b0839471..7a9028004 100644
--- a/docs/source/models/recall.rst
+++ b/docs/source/models/recall.rst
@@ -8,6 +8,7 @@
    dssm_neg_sampler
    mind
    co_metric_learning_i2i
+   eges
 
 冷启动召回模型
 ========
diff --git a/easy_rec/python/compat/early_stopping.py b/easy_rec/python/compat/early_stopping.py
index d68ee618a..a4f9092d9 100644
--- a/easy_rec/python/compat/early_stopping.py
+++ b/easy_rec/python/compat/early_stopping.py
@@ -299,7 +299,8 @@ def custom_early_stop_hook(estimator,
   if eval_dir is None:
     eval_dir = estimator.eval_dir()
 
-  if isinstance(custom_stop_func, str) or isinstance(custom_stop_func, unicode):
+  if isinstance(custom_stop_func, str) or isinstance(custom_stop_func,
+                                                     type(u'')):
     custom_stop_func = load_by_path(custom_stop_func)
 
   def _custom_stop_fn():
diff --git a/easy_rec/python/compat/feature_column/feature_column.py b/easy_rec/python/compat/feature_column/feature_column.py
index fbadcf2d8..172173dc4 100644
--- a/easy_rec/python/compat/feature_column/feature_column.py
+++ b/easy_rec/python/compat/feature_column/feature_column.py
@@ -228,8 +228,12 @@ def _get_logits():  # pylint: disable=missing-docstring
   if from_template:
     return _get_logits()
   else:
+    reuse = None if scope is None else variable_scope.AUTO_REUSE
     with variable_scope.variable_scope(
-        scope, default_name='input_layer', values=features.values()):
+        scope,
+        default_name='input_layer',
+        values=features.values(),
+        reuse=reuse):
       return _get_logits()
 
 
@@ -239,7 +243,8 @@ def input_layer(features,
                 trainable=True,
                 cols_to_vars=None,
                 cols_to_output_tensors=None,
-                feature_name_to_output_tensors=None):
+                feature_name_to_output_tensors=None,
+                scope=None):
   """Returns a dense `Tensor` as input layer based on given `feature_columns`.
 
   Generally a single example in training data is described with FeatureColumns.
@@ -287,6 +292,7 @@ def input_layer(features,
     cols_to_output_tensors: If not `None`, must be a dictionary that will be
       filled with a mapping from '_FeatureColumn' to the associated output
       `Tensor`s.
+    scope: variable scope.
 
   Returns:
     A `Tensor` which represents input layer of a model. Its shape
@@ -303,7 +309,8 @@ def input_layer(features,
       trainable=trainable,
       cols_to_vars=cols_to_vars,
       cols_to_output_tensors=cols_to_output_tensors,
-      feature_name_to_output_tensors=feature_name_to_output_tensors)
+      feature_name_to_output_tensors=feature_name_to_output_tensors,
+      scope=scope)
 
 
 # TODO(akshayka): InputLayer should be a subclass of Layer, and it
diff --git a/easy_rec/python/core/sampler.py b/easy_rec/python/core/sampler.py
index ca3a8f15d..2d7bc1d2a 100644
--- a/easy_rec/python/core/sampler.py
+++ b/easy_rec/python/core/sampler.py
@@ -3,7 +3,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import json
 import logging
 import math
 import os
@@ -13,6 +12,7 @@
 import tensorflow as tf
 
 from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.utils import graph_utils
 
 try:
   import graphlearn as gl
@@ -76,46 +76,7 @@ def __init__(self, fields, num_sample, num_eval_sample=None):
     self._build_field_types(fields)
 
   def _init_graph(self):
-    if 'TF_CONFIG' in os.environ:
-      tf_config = json.loads(os.environ['TF_CONFIG'])
-      if 'ps' in tf_config['cluster']:
-        # ps mode
-        tf_config = json.loads(os.environ['TF_CONFIG'])
-        ps_count = len(tf_config['cluster']['ps'])
-        task_count = len(tf_config['cluster']['worker']) + 2
-        cluster = {'server_count': ps_count, 'client_count': task_count}
-        if tf_config['task']['type'] in ['chief', 'master']:
-          self._g.init(cluster=cluster, job_name='client', task_index=0)
-        elif tf_config['task']['type'] == 'worker':
-          self._g.init(
-              cluster=cluster,
-              job_name='client',
-              task_index=tf_config['task']['index'] + 2)
-        # TODO(hongsheng.jhs): check cluster has evaluator or not?
-        elif tf_config['task']['type'] == 'evaluator':
-          self._g.init(
-              cluster=cluster,
-              job_name='client',
-              task_index=tf_config['task']['index'] + 1)
-          if self._num_eval_sample is not None and self._num_eval_sample > 0:
-            self._num_sample = self._num_eval_sample
-        elif tf_config['task']['type'] == 'ps':
-          self._g.init(
-              cluster=cluster,
-              job_name='server',
-              task_index=tf_config['task']['index'])
-      else:
-        # worker mode
-        task_count = len(tf_config['cluster']['worker']) + 1
-        if tf_config['task']['type'] in ['chief', 'master']:
-          self._g.init(task_index=0, task_count=task_count)
-        elif tf_config['task']['type'] == 'worker':
-          self._g.init(
-              task_index=tf_config['task']['index'] + 1, task_count=task_count)
-        # TODO(hongsheng.jhs): check cluster has evaluator or not?
-    else:
-      # local mode
-      self._g.init()
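+    # cluster setup (ps / worker-only / local mode) is shared with GraphInput;
+    # see easy_rec/python/utils/graph_utils.graph_init.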
+    graph_utils.graph_init(self._g, os.environ.get('TF_CONFIG', None))
 
   def _build_field_types(self, fields):
     self._attr_names = []
diff --git a/easy_rec/python/inference/vector_retrieve.py b/easy_rec/python/inference/vector_retrieve.py
index 4baca38db..7da123579 100644
--- a/easy_rec/python/inference/vector_retrieve.py
+++ b/easy_rec/python/inference/vector_retrieve.py
@@ -13,7 +13,7 @@
 try:
   import graphlearn as gl
-except:
+except ImportError:
   logging.WARN(
       'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl.'  # noqa: E501
   )
diff --git a/easy_rec/python/input/csv_input.py b/easy_rec/python/input/csv_input.py
index 50ecb668a..92fd29109 100644
--- a/easy_rec/python/input/csv_input.py
+++ b/easy_rec/python/input/csv_input.py
@@ -24,6 +24,7 @@ def __init__(self,
     super(CSVInput, self).__init__(data_config, feature_config, input_path,
                                    task_index, task_num)
     self._with_header = data_config.with_header
+    # only used for csv files with a header
    self._field_names = None
 
   def _parse_csv(self, line):
diff --git a/easy_rec/python/input/graph_input.py b/easy_rec/python/input/graph_input.py
new file mode 100644
index 000000000..758d508b1
--- /dev/null
+++ b/easy_rec/python/input/graph_input.py
@@ -0,0 +1,257 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+import os
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+from easy_rec.python.core import sampler
+from easy_rec.python.input.input import Input
+from easy_rec.python.utils import graph_utils
+
+try:
+  import graphlearn as gl
+except ImportError:
+  logging.error(
+      'GraphLearn is not installed. You can install it by "pip install https://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/graphlearn-0.7-cp27-cp27mu-linux_x86_64.whl"'  # noqa: E501
+  )
+  sys.exit(1)
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class GraphInput(Input):
+
+  node_type = 'item'
+  edge_type = 'relation'
+
+  graph = None
+
+  def __init__(self,
+               data_config,
+               feature_configs,
+               input_path,
+               task_index=0,
+               task_num=1):
+    super(GraphInput, self).__init__(data_config, feature_configs, input_path,
+                                     task_index, task_num)
+    self._model = None
+    self._build_field_types()
+
+    self._walk_len = self._data_config.graph_config.random_walk_len
+    self._window_size = self._data_config.graph_config.window_size
+    self._negative_num = self._data_config.graph_config.negative_num
+    logging.info('walk_len=%d window_size=%d negative_num=%d' %
+                 (self._walk_len, self._window_size, self._negative_num))
+
+    # build the item co-occurrence graph, where edge weight is the
+    # co-occurrence frequency
+    if input_path:
+      if GraphInput.graph is None:
+        GraphInput.graph = gl.Graph().node(
+            str(input_path.node_inputs[0]),
+            node_type=GraphInput.node_type,
+            decoder=gl.Decoder(
+                attr_types=self._attr_gl_types,
+                weighted=True,
+                attr_delimiter=',')).edge(
+                    str(input_path.edge_inputs[0]),
+                    edge_type=(GraphInput.node_type, GraphInput.node_type,
+                               GraphInput.edge_type),
+                    decoder=gl.Decoder(weighted=True),
+                    directed=data_config.graph_config.directed)
+        graph_utils.graph_init(GraphInput.graph,
+                               os.environ.get('TF_CONFIG', None))
+    if GraphInput.graph is not None:
+      self._neg_sampler = GraphInput.graph.negative_sampler(
+          GraphInput.node_type,
+          expand_factor=self._negative_num,
+          strategy='node_weight')
+
+  def _build_field_types(self):
+    """Build field types for item features."""
+    self._attr_names = []
+    self._attr_types = []
+    self._attr_gl_types = []
+    self._attr_np_types = []
+    self._attr_tf_types = []
+    for field_name, field_type in zip(self._input_fields,
+                                      self._input_field_types):
+      self._attr_names.append(field_name)
+      self._attr_types.append(field_type)
+      self._attr_gl_types.append(sampler._get_gl_type(field_type))
+      self._attr_np_types.append(sampler._get_np_type(field_type))
+      self._attr_tf_types.append(sampler._get_tf_type(field_type))
+
+  def _parse_nodes(self, nodes):
+    features = []
+    int_idx = 0
+    float_idx = 0
+    string_idx = 0
+
+    for attr_gl_type, attr_np_type in zip(self._attr_gl_types,
+                                          self._attr_np_types):
+      if attr_gl_type == 'int':
+        if len(nodes.shape) == 1:
+          feature = nodes.int_attrs[:, int_idx]
+        elif len(nodes.shape) == 2:
+          feature = nodes.int_attrs[:, :, int_idx]
+        int_idx += 1
+      elif attr_gl_type == 'float':
+        if len(nodes.shape) == 1:
+          feature = nodes.float_attrs[:, float_idx]
+        elif len(nodes.shape) == 2:
+          feature = nodes.float_attrs[:, :, float_idx]
+        float_idx += 1
+      elif attr_gl_type == 'string':
+        if len(nodes.shape) == 1:
+          feature = nodes.string_attrs[:, string_idx]
+        elif len(nodes.shape) == 2:
+          feature = nodes.string_attrs[:, :, string_idx]
+        string_idx += 1
+      else:
+        raise ValueError('Unknown attr type %s' % attr_gl_type)
+      feature = np.reshape(feature, [-1]).astype(attr_np_type)
+      if attr_gl_type == 'string':
+        feature = np.asarray(feature, order='C', dtype=object)
+      features.append(feature)
+    return features
+
+  def _gen_pair(self, path, left_window_size, right_window_size):
+    """Generate skip-gram pairs as positive pairs.
+
+    Args:
+      path: a list of id arrays, starting with the root nodes' ids; each
+        element is a 1d numpy array of the same size.
+      left_window_size: number of neighbors paired to the left of the center.
+      right_window_size: number of neighbors paired to the right of the center.
+
+    Returns:
+      a pair of numpy arrays of ids.
+
+    Example:
+      >>> path = [np.array([1, 2]), np.array([3, 4]), np.array([5, 6])]
+      >>> left_window_size = right_window_size = 1
+      >>> src_ids, dst_ids = self._gen_pair(path, left_window_size, right_window_size)
+      >>> print(src_ids, dst_ids)
+      (array([1, 2, 3, 4, 3, 4, 5, 6]), array([3, 4, 1, 2, 5, 6, 3, 4]))
+    """
+    path_len = len(path)
+    pairs = [[], []]  # [src ids list, dst ids list]
+
+    for center_idx in range(path_len):
+      cursor = 0
+      while center_idx - cursor > 0 and cursor < left_window_size:
+        pairs[0].append(path[center_idx])
+        pairs[1].append(path[center_idx - cursor - 1])
+        cursor += 1
+
+      cursor = 0
+      while center_idx + cursor + 1 < path_len and cursor < right_window_size:
+        pairs[0].append(path[center_idx])
+        pairs[1].append(path[center_idx + cursor + 1])
+        cursor += 1
+    return np.concatenate(pairs[0]), np.concatenate(pairs[1])
+
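+  # Each generator pass runs a GraphLearn GSL query: sample batch_size start
+  # nodes, take (random_walk_len - 1) random out-hops to build walks, expand
+  # each walk into skip-gram pairs, then draw negative_num weighted negatives
+  # per positive pair.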
+  def _sample_generator(self):
+    epoch_id = 0
+    while self.num_epochs is None or epoch_id < self.num_epochs:
+      # sample start nodes; the same sampling is used for train and eval
+      start_nodes = GraphInput.graph.V(GraphInput.node_type).batch(
+          self._data_config.batch_size).alias('rand_walk_0')
+      # sample paths
+      for i in range(1, self._walk_len):
+        out_alias = 'rand_walk_%d' % i
+        start_nodes = start_nodes.outV(
+            GraphInput.edge_type).sample(1).by('random').alias(out_alias)
+
+      ds = gl.Dataset(start_nodes.values())
+
+      while True:
+        try:
+          paths = ds.next()
+          paths = [
+              paths['rand_walk_%d' % i].ids.reshape([-1])
+              for i in range(0, self._walk_len)
+          ]
+          # build positive pairs
+          src_ids, dst_ids = self._gen_pair(paths, self._window_size,
+                                            self._window_size)
+          src_nodes = GraphInput.graph.get_nodes(GraphInput.node_type, src_ids)
+          dst_nodes = GraphInput.graph.get_nodes(GraphInput.node_type, dst_ids)
+          neg_nodes = self._neg_sampler.get(dst_ids)
+
+          src_node_fea_arr = self._parse_nodes(src_nodes)
+          dst_node_fea_arr = self._parse_nodes(dst_nodes)
+          neg_node_fea_arr = self._parse_nodes(neg_nodes)
+
+          yield tuple(src_node_fea_arr + dst_node_fea_arr + neg_node_fea_arr)
+        except gl.OutOfRangeError:
+          break
+      if self._mode != tf.estimator.ModeKeys.TRAIN:
+        break
+      epoch_id += 1
+
+  def _to_fea_dict(self, *features):
+    fea_num = len(self._input_fields)
+    assert fea_num * 3 == len(features)
+    fea_dict_groups = {'src_fea': {}, 'positive_fea': {}, 'negative_fea': {}}
+    for fid, fea in enumerate(features[:fea_num]):
+      fea_dict_groups['src_fea'][self._input_fields[fid]] = fea
+    for fid, fea in enumerate(features[fea_num:(fea_num * 2)]):
+      fea_dict_groups['positive_fea'][self._input_fields[fid]] = fea
+    for fid, fea in enumerate(features[(fea_num * 2):]):
+      fea_dict_groups['negative_fea'][self._input_fields[fid]] = fea
+    return fea_dict_groups
+
+  def _group_preprocess(self, field_dict_groups):
+    for g in field_dict_groups:
+      field_dict_groups[g] = self._preprocess(field_dict_groups[g])
+    return field_dict_groups
+
+  def _get_labels(self, field_dict):
+    return {
+        'positive_fea': field_dict['positive_fea'],
+        'negative_fea': field_dict['negative_fea']
+    }
+
+  def _get_features(self, field_dict_groups):
+    return field_dict_groups['src_fea']
+
+  def _build(self, mode, params):
+    """Build graph dataset input for estimator.
+
+    Args:
+      mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
+      params: `dict` of hyper parameters, from Estimator
+
+    Return:
+      dataset: dataset for graph models.
+    """
+    self._mode = mode
+    # input types and shapes for one group of features
+    list_type = tuple(self.get_tf_type(x) for x in self._input_field_types)
+    list_shapes = tuple(tf.TensorShape([None]) for _ in list_type)
+    dataset = tf.data.Dataset.from_generator(
+        self._sample_generator,
+        output_types=list_type * 3,
+        output_shapes=list_shapes * 3)
+
+    # transform list to feature dict
+    dataset = dataset.map(map_func=self._to_fea_dict)
+
+    dataset = dataset.map(
+        map_func=self._group_preprocess,
+        num_parallel_calls=self._data_config.num_parallel_calls)
+    dataset = dataset.prefetch(buffer_size=self._prefetch_size)
+    if mode != tf.estimator.ModeKeys.PREDICT:
+      dataset = dataset.map(lambda x:
+                            (self._get_features(x), self._get_labels(x)))
+    else:
+      dataset = dataset.map(lambda x: self._get_features(x))
+    return dataset
diff --git a/easy_rec/python/input/hive_input.py b/easy_rec/python/input/hive_input.py
index e4a978e74..c96b59391 100644
--- a/easy_rec/python/input/hive_input.py
+++ b/easy_rec/python/input/hive_input.py
@@ -4,11 +4,15 @@
+import logging
+
 import numpy as np
 import tensorflow as tf
-from pyhive import hive
 
 from easy_rec.python.input.input import Input
 from easy_rec.python.utils import odps_util
 
+try:
+  from pyhive import hive
+except ImportError:
+  logging.warning('pyhive is not installed.')
+
 
 class TableInfo(object):
 
@@ -42,7 +46,9 @@ def gen_sql(self):
     sql = """select {} from {}""".format(self.selected_cols, self.tablename)
     assert self.hash_fields is not None, 'hash_fields must not be empty'
-    fields = ['cast({} as string)'.format(key) for key in self.hash_fields.split(',')]
+    fields = [
+        'cast({} as string)'.format(key) for key in self.hash_fields.split(',')
+    ]
     str_fields = ','.join(fields)
     if not part:
       sql += """
@@ -59,12 +65,7 @@
 class HiveManager(object):
 
-  def __init__(self,
-               host,
-               port,
-               username,
-               info,
-               database='default'):
+  def __init__(self, host, port, username, info, database='default'):
     self.host = host
     self.port = port
     self.username = username
diff --git a/easy_rec/python/input/input.py b/easy_rec/python/input/input.py
index c0d8653bf..db81a8d2c 100644
--- a/easy_rec/python/input/input.py
+++ b/easy_rec/python/input/input.py
@@ -654,8 +654,8 @@ def _input_fn(mode=None, params=None, config=None):
 
       Args:
        mode: tf.estimator.ModeKeys.(TRAIN, EVAL, PREDICT)
-        params: `dict` of hyper parameters, from Estimator
-        config: tf.estimator.RunConfig instance
+       params: `dict` of hyper parameters, from Estimator
+       config: tf.estimator.RunConfig instance
 
       Return:
        if mode is not None, return:
diff --git a/easy_rec/python/layers/common_layers.py b/easy_rec/python/layers/common_layers.py
index 883f2a67c..80ad1496f 100644
--- a/easy_rec/python/layers/common_layers.py
+++ b/easy_rec/python/layers/common_layers.py
@@ -65,8 +65,8 @@ def text_cnn(x,
     # conv shape: (batch_size, seq_len - filter_size + 1, num_filters)
     conv = tf.layers.conv1d(
         x,
-        filters=num_filter,
-        kernel_size=filter_size,
+        filters=int(num_filter),
+        kernel_size=int(filter_size),
         activation=tf.nn.relu,
         name='conv_layer',
         reuse=reuse,
diff --git a/easy_rec/python/layers/input_layer.py b/easy_rec/python/layers/input_layer.py
index 001873abb..517d028f4 100644
--- a/easy_rec/python/layers/input_layer.py
+++ b/easy_rec/python/layers/input_layer.py
@@ -3,6 +3,7 @@
 import logging
 
 import tensorflow as tf
+from tensorflow.python.ops import variable_scope
 
 from easy_rec.python.compat import regularizers
 from easy_rec.python.compat.feature_column import feature_column
@@ -34,7 +35,22 @@ def __init__(self,
                use_embedding_variable=False,
                embedding_regularizer=None,
                kernel_regularizer=None,
-               is_training=False):
+               is_training=False,
+               group_as_scope=False):
+    """Build an input_layer to generate features specified by feature_configs.
+
+    Args:
+      feature_configs: feature_config.features in pipeline.config.
+      feature_groups_config: feature_groups defined in model_config.
+      variational_dropout_config: for variational dropout.
+      wide_output_dim: for wide_and_deep and deepfm models, the wide part
+        generates embedding.
+      use_embedding_variable: whether to use sparse embedding (kv indexed).
+      embedding_regularizer: regularization loss over the embedding_lookup results.
+      kernel_regularizer: regularization loss over dnn kernel parameters.
+      is_training: true in the train phase, otherwise false.
+      group_as_scope: use group name as variable scope name to ensure embedding
+        sharing between different feature groups.
+    """
     self._feature_groups = {
         x.group_name: FeatureGroup(x) for x in feature_groups_config
     }
@@ -64,6 +80,7 @@ def __init__(self,
     self._kernel_regularizer = kernel_regularizer
     self._is_training = is_training
     self._variational_dropout_config = variational_dropout_config
+    self._group_as_scope = group_as_scope
 
   def has_group(self, group_name):
     return group_name in self._feature_groups
@@ -197,18 +214,24 @@ def single_call_input_layer(self,
     feature_group = self._feature_groups[group_name]
     group_columns, group_seq_columns = feature_group.select_columns(
         self._fc_parser)
+    scope_name = group_name if self._group_as_scope else None
+    reuse = variable_scope.AUTO_REUSE if self._group_as_scope else None
     if is_combine:
       cols_to_output_tensors = {}
       output_features = feature_column.input_layer(
           features,
           group_columns,
           cols_to_output_tensors=cols_to_output_tensors,
-          feature_name_to_output_tensors=feature_name_to_output_tensors)
+          feature_name_to_output_tensors=feature_name_to_output_tensors,
+          scope=scope_name)
       embedding_reg_lst = [output_features]
       builder = feature_column._LazyBuilder(features)
       seq_features = []
+      if scope_name is None:
+        scope_name = 'input_layer'
       for column in sorted(group_seq_columns, key=lambda x: x.name):
-        with tf.variable_scope(None, default_name=column._var_scope_name):
+        with variable_scope.variable_scope(
+            scope_name + '/' + column._var_scope_name, reuse=reuse):
           seq_feature, seq_len = column._get_sequence_dense_tensor(builder)
           embedding_reg_lst.append(seq_feature)
 
@@ -241,7 +264,8 @@ def single_call_input_layer(self,
             seq_features.append(cnn_feature)
             cols_to_output_tensors[column] = cnn_feature
           else:
-            raise NotImplementedError
+            raise NotImplementedError('unknown sequence combiner type: %s' %
+                                      sequence_combiner.WhichOneof('combiner'))
       if self._variational_dropout_config is not None:
         features_dimension = [
             cols_to_output_tensors[x].get_shape()[-1] for x in group_columns
@@ -266,8 +290,11 @@ def single_call_input_layer(self,
       builder = feature_column._LazyBuilder(features)
       seq_features = []
       embedding_reg_lst = []
+      if scope_name is None:
+        scope_name = 'input_layer'
       for fc in group_seq_columns:
-        with tf.variable_scope('input_layer/' + fc.categorical_column.name):
+        with variable_scope.variable_scope(
+            scope_name + '/' + fc._var_scope_name, reuse=reuse):
           tmp_embedding, tmp_seq_len = fc._get_sequence_dense_tensor(builder)
           seq_features.append((tmp_embedding, tmp_seq_len))
           embedding_reg_lst.append(tmp_embedding)
diff --git a/easy_rec/python/main.py b/easy_rec/python/main.py
index cbaaf5ed2..332fbdc7a 100644
--- a/easy_rec/python/main.py
+++ b/easy_rec/python/main.py
@@ -246,29 +246,15 @@ def train_and_evaluate(pipeline_config_path, continue_train=False):
   return pipeline_config
 
 
-def _get_input_object_by_name(pipeline_config, worker_type):
-  """"
-  get object by worker type
+def _get_input_object_by_task_type(pipeline_config, task_type):
+  """Get the input config object by task type.
 
-  pipeline_config: pipeline_config
-  worker_type: train or eval
+  Args:
+    pipeline_config: pipeline_config
+    task_type: train or eval
   """
-  input_type = "{}_path".format(worker_type)
+  input_type = '{}_path'.format(task_type)
   input_name = pipeline_config.WhichOneof(input_type)
-  _dict = {"kafka_train_input": pipeline_config.kafka_train_input,
-           "kafka_eval_input": pipeline_config.kafka_eval_input,
-           "datahub_train_input": pipeline_config.datahub_train_input,
-           "datahub_eval_input": pipeline_config.datahub_eval_input,
-           "hive_train_input": pipeline_config.hive_train_input,
-           "hive_eval_input": pipeline_config.hive_eval_input
-           }
-  if input_name in _dict:
-    return _dict[input_name]
-
-  if worker_type == "train":
-    return pipeline_config.train_input_path
-  else:
-    return pipeline_config.eval_input_path
+  return getattr(pipeline_config, input_name)
 
 
 def _train_and_evaluate_impl(pipeline_config, continue_train=False):
@@ -284,8 +270,8 @@ def _train_and_evaluate_impl(pipeline_config, continue_train=False):
         % pipeline_config.train_config.train_distribute)
     pipeline_config.train_config.sync_replicas = False
 
-  train_data = _get_input_object_by_name(pipeline_config, 'train')
-  eval_data = _get_input_object_by_name(pipeline_config, 'eval')
+  train_data = _get_input_object_by_task_type(pipeline_config, 'train')
+  eval_data = _get_input_object_by_task_type(pipeline_config, 'eval')
 
   distribution = strategy_builder.build(train_config)
   estimator, run_config = _create_estimator(
@@ -362,10 +348,7 @@ def evaluate(pipeline_config,
     pipeline_config.eval_input_path = eval_data_path
 
   train_config = pipeline_config.train_config
-  if pipeline_config.WhichOneof('eval_path') == 'kafka_eval_input':
-    eval_data = pipeline_config.kafka_eval_input
-  else:
-    eval_data = pipeline_config.eval_input_path
+  eval_data = _get_input_object_by_task_type(pipeline_config, 'eval')
 
   server_target = None
   if 'TF_CONFIG' in os.environ:
diff --git a/easy_rec/python/model/deepfm.py b/easy_rec/python/model/deepfm.py
index 4e96bd024..2de68ef41 100644
--- a/easy_rec/python/model/deepfm.py
+++ b/easy_rec/python/model/deepfm.py
@@ -49,7 +49,8 @@ def build_input_layer(self, model_config, feature_configs):
         wide_output_dim=model_config.deepfm.wide_output_dim,
         use_embedding_variable=model_config.use_embedding_variable,
         embedding_regularizer=self._emb_reg,
-        kernel_regularizer=self._l2_reg)
+        kernel_regularizer=self._l2_reg,
+        group_as_scope=model_config.group_as_scope)
 
   def build_predict_graph(self):
     # Wide
diff --git a/easy_rec/python/model/easy_rec_estimator.py b/easy_rec/python/model/easy_rec_estimator.py
index 77ec6bbf0..7ec25421a 100644
--- a/easy_rec/python/model/easy_rec_estimator.py
+++ b/easy_rec/python/model/easy_rec_estimator.py
@@ -86,6 +86,7 @@ def _train_model_fn(self, features, labels, run_config):
         is_training=True)
     predict_dict = model.build_predict_graph()
     loss_dict = model.build_loss_graph()
+    assert loss_dict is not None, 'loss_dict should be returned by build_loss_graph'
 
     regularization_losses = tf.get_collection(
         tf.GraphKeys.REGULARIZATION_LOSSES)
diff --git a/easy_rec/python/model/easy_rec_model.py b/easy_rec/python/model/easy_rec_model.py
index 5f8256e27..10fb0513e 100644
--- a/easy_rec/python/model/easy_rec_model.py
+++ b/easy_rec/python/model/easy_rec_model.py
@@ -85,7 +85,8 @@ def build_input_layer(self, model_config, feature_configs):
         kernel_regularizer=self._l2_reg,
         variational_dropout_config=model_config.variational_dropout
         if model_config.HasField('variational_dropout') else None,
-        is_training=False)
+        is_training=False,
+        group_as_scope=model_config.group_as_scope)
 
   @abstractmethod
   def build_predict_graph(self):
diff --git a/easy_rec/python/model/eges.py b/easy_rec/python/model/eges.py
new file mode 100644
index 000000000..5a466e294
--- /dev/null
+++ b/easy_rec/python/model/eges.py
@@ -0,0 +1,108 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import logging
+
+import tensorflow as tf
+
+from easy_rec.python.layers import dnn
+from easy_rec.python.model.easy_rec_model import EasyRecModel
+
+from easy_rec.python.protos.eges_pb2 import EGES as EGESConfig  # NOQA
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+
+class EGES(EasyRecModel):
+
+  def __init__(self,
+               model_config,
+               feature_configs,
+               features,
+               labels=None,
+               is_training=False):
+    super(EGES, self).__init__(model_config, feature_configs, features, labels,
+                               is_training)
+    self._model_config = model_config.eges
+    assert isinstance(self._model_config, EGESConfig)
+    self._group_name = 'item'
+    assert self._input_layer.has_group(
+        self._group_name), 'group[%s] is not specified' % self._group_name
+
+    if labels is not None:
+      self._positive_features = labels['positive_fea']
+      self._negative_features = labels['negative_fea']
+      self._src_features = features
+      self._positive_embedding, _ = self._input_layer(self._positive_features,
+                                                      self._group_name)
+      self._negative_embedding, _ = self._input_layer(self._negative_features,
+                                                      self._group_name)
+      self._src_embedding, _ = self._input_layer(self._src_features,
+                                                 self._group_name)
+    else:
+      self._src_embedding, _ = self._input_layer(features, self._group_name)
+      self._negative_embedding = None
+      self._positive_embedding = None
+
+  def build_predict_graph(self):
+    if self._negative_embedding is None:
+      logging.info('build predict item embedding graph.')
+      src_embedding = tf.layers.batch_normalization(
+          self._src_embedding,
+          training=self._is_training,
+          trainable=True,
+          name='%s_fea_bn' % self._group_name)
+      dnn_layer = dnn.DNN(self._model_config.dnn, self._l2_reg, 'dnn',
+                          self._is_training)
+      src_embedding = dnn_layer(src_embedding)
+      self._prediction_dict['item_embedding'] = src_embedding
+      return self._prediction_dict
+
+    all_embedding = tf.concat([
+        self._src_embedding, self._positive_embedding, self._negative_embedding
+    ], axis=0)  # noqa: E126
+    all_embedding = tf.layers.batch_normalization(
+        all_embedding,
+        training=self._is_training,
+        trainable=True,
+        name='%s_fea_bn' % self._group_name)
+    batch_size = tf.shape(self._src_embedding)[0]
+    src_embedding = all_embedding[:batch_size]
+    pos_embedding = all_embedding[batch_size:(batch_size * 2)]
+    neg_embedding = all_embedding[(batch_size * 2):]
+    tf.summary.scalar('actual_batch_size', tf.shape(src_embedding)[0])
+    tf.summary.histogram('src_fea', src_embedding)
+    tf.summary.histogram('neg_fea', neg_embedding)
+    tf.summary.histogram('pos_fea', pos_embedding)
+    dnn_layer = dnn.DNN(self._model_config.dnn, self._l2_reg, 'dnn',
+                        self._is_training)
+    all_embedding = dnn_layer(all_embedding)
+
+    embed_dim = all_embedding.get_shape()[-1]
+    src_embedding = all_embedding[:batch_size]
+    pos_embedding = all_embedding[batch_size:(batch_size * 2)]
+    neg_embedding = all_embedding[(batch_size * 2):]
+    neg_embedding = tf.reshape(neg_embedding, [batch_size, -1, embed_dim])
+    target_embedding = tf.concat([pos_embedding[:, None, :], neg_embedding],
+                                 axis=1)
+
+    self._prediction_dict['item_embedding'] = src_embedding
+    self._prediction_dict['target_embedding'] = target_embedding
+    return self._prediction_dict
+
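+  # target_embedding stacks the positive target at index 0 followed by the
+  # negative_num sampled negatives, so the softmax label is always 0.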
+  def build_loss_graph(self):
+    src_embedding = self._prediction_dict['item_embedding']
+    target_embedding = self._prediction_dict['target_embedding']
+    logits = tf.einsum('be,bne->bn', src_embedding, target_embedding)
+    batch_size = tf.shape(src_embedding)[0]
+    labels = tf.zeros([batch_size], dtype=tf.int32)
+    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        labels=labels, logits=logits)
+    return {'cross_entropy': tf.reduce_mean(loss)}
+
+  def build_metric_graph(self, eval_config):
+    return {}
+
+  def get_outputs(self):
+    return ['item_embedding']
diff --git a/easy_rec/python/model/mind.py b/easy_rec/python/model/mind.py
index a749e6dfe..b3a66b304 100644
--- a/easy_rec/python/model/mind.py
+++ b/easy_rec/python/model/mind.py
@@ -72,12 +72,12 @@ def build_predict_graph(self):
                         self._is_training)
 
     time_id_fea = [
-        x[0] for x in self._hist_seq_features if 'time_id/' in x[0].name
+        x[0] for x in self._hist_seq_features if 'time_id_embedding/' in x[0].name
    ]
    time_id_fea = time_id_fea[0] if len(time_id_fea) > 0 else None
 
    hist_seq_feas = [
-        x[0] for x in self._hist_seq_features if 'time_id/' not in x[0].name
+        x[0] for x in self._hist_seq_features if 'time_id_embedding/' not in x[0].name
    ]
    # it is assumed that all hist have the same length
    hist_seq_len = self._hist_seq_features[0][1]
diff --git a/easy_rec/python/model/wide_and_deep.py b/easy_rec/python/model/wide_and_deep.py
index 119af575c..b26989d90 100755
--- a/easy_rec/python/model/wide_and_deep.py
+++ b/easy_rec/python/model/wide_and_deep.py
@@ -46,7 +46,8 @@ def build_input_layer(self, model_config, feature_configs):
         wide_output_dim=wide_output_dim,
         use_embedding_variable=model_config.use_embedding_variable,
         embedding_regularizer=self._emb_reg,
-        kernel_regularizer=self._l2_reg)
+        kernel_regularizer=self._l2_reg,
+        group_as_scope=model_config.group_as_scope)
 
   def build_predict_graph(self):
     wide_fea = tf.add_n(self._wide_features)
diff --git a/easy_rec/python/protos/data_source.proto b/easy_rec/python/protos/data_source.proto
index a05134d12..8edc4ee81 100644
--- a/easy_rec/python/protos/data_source.proto
+++ b/easy_rec/python/protos/data_source.proto
@@ -18,3 +18,8 @@ message DatahubServer{
   required uint32 shard_num = 6;
   required uint32 life_cycle = 7;
 }
+
+message GraphLearnInput {
+  repeated string node_inputs = 1;
+  repeated string edge_inputs = 2;
+}
diff --git a/easy_rec/python/protos/dataset.proto b/easy_rec/python/protos/dataset.proto
index 326e03e88..88432405c 100644
--- a/easy_rec/python/protos/dataset.proto
+++ b/easy_rec/python/protos/dataset.proto
@@ -1,6 +1,8 @@
 syntax = "proto2";
 package protos;
 
+import "easy_rec/python/protos/graph.proto";
+
 // Weighted Random Sampling ItemID not in Batch
 message NegativeSampler {
   // sample data path
@@ -179,6 +181,7 @@ message DatasetConfig {
     DummyInput = 8;
     KafkaInput = 13;
     HiveInput = 17;
+    GraphInput = 18;
   }
   required InputType input_type = 10;
 
@@ -248,4 +251,5 @@ message DatasetConfig {
     HardNegativeSamplerV2 hard_negative_sampler_v2 = 104;
   }
 
+  optional GraphConfig graph_config = 26;
 }
diff --git a/easy_rec/python/protos/easy_rec_model.proto b/easy_rec/python/protos/easy_rec_model.proto
index 6f8ca590d..70e1318c0 100644
--- a/easy_rec/python/protos/easy_rec_model.proto
+++ b/easy_rec/python/protos/easy_rec_model.proto
@@ -10,6 +10,8 @@ import "easy_rec/python/protos/feature_config.proto";
 import "easy_rec/python/protos/dropoutnet.proto";
 import "easy_rec/python/protos/dssm.proto";
 import "easy_rec/python/protos/collaborative_metric_learning.proto";
+import "easy_rec/python/protos/mind.proto";
+import "easy_rec/python/protos/eges.proto";
 import "easy_rec/python/protos/mmoe.proto";
 import "easy_rec/python/protos/esmm.proto";
 import "easy_rec/python/protos/dbmtl.proto";
@@ -17,7 +19,6 @@ import "easy_rec/python/protos/ple.proto";
 import "easy_rec/python/protos/simple_multi_task.proto";
 import "easy_rec/python/protos/dcn.proto";
 import "easy_rec/python/protos/autoint.proto";
-import "easy_rec/python/protos/mind.proto";
 import "easy_rec/python/protos/loss.proto";
 import "easy_rec/python/protos/rocket_launching.proto";
 import "easy_rec/python/protos/variational_dropout.proto";
@@ -65,6 +66,7 @@ message EasyRecModel {
     MIND mind = 202;
     DropoutNet dropoutnet = 203;
     CoMetricLearningI2I metric_learning = 204;
+    EGES eges = 205;
 
     MMoE mmoe = 301;
     ESMM esmm = 302;
@@ -95,4 +97,8 @@ message EasyRecModel {
   optional VariationalDropoutLayer variational_dropout = 14;
 
   repeated Loss losses = 15;
+
+  // use group_name as scope_name to ensure embedding sharing
+  // between different feature_groups.
+  optional bool group_as_scope = 1001 [default=false];
 }
diff --git a/easy_rec/python/protos/eges.proto b/easy_rec/python/protos/eges.proto
new file mode 100644
index 000000000..61409b2f7
--- /dev/null
+++ b/easy_rec/python/protos/eges.proto
@@ -0,0 +1,10 @@
+syntax = "proto2";
+package protos;
+
+import "easy_rec/python/protos/dnn.proto";
+
+message EGES {
+  required DNN dnn = 1;
+
+  optional float l2_regularization = 2 [default=0.0];
+}
diff --git a/easy_rec/python/protos/graph.proto b/easy_rec/python/protos/graph.proto
new file mode 100644
index 000000000..6235c3a60
--- /dev/null
+++ b/easy_rec/python/protos/graph.proto
@@ -0,0 +1,17 @@
+syntax = "proto2";
+package protos;
+
+// for graph based algorithms: EGES, DeepWalk, etc.
+message GraphConfig {
+  // whether graph edges are directed or undirected, defaults to undirected
+  optional bool directed = 1 [default=false];
+
+  // length of each random walk
+  optional uint32 random_walk_len = 2 [default=10];
+
+  // skip-gram window size
+  optional uint32 window_size = 3 [default=5];
+
+  // number of negatives sampled for each positive sample
+  optional uint32 negative_num = 4 [default=10];
+}
diff --git a/easy_rec/python/protos/pipeline.proto b/easy_rec/python/protos/pipeline.proto
index 1a351337b..136e89d38 100644
--- a/easy_rec/python/protos/pipeline.proto
+++ b/easy_rec/python/protos/pipeline.proto
@@ -14,17 +14,18 @@ import "easy_rec/python/protos/hive_config.proto";
 
 message EasyRecConfig {
   oneof train_path {
-    string train_input_path = 1;
-    KafkaServer kafka_train_input = 2;
-    DatahubServer datahub_train_input = 12;
-    HiveConfig hive_train_input = 21;
+    string train_input_path = 101;
+    KafkaServer kafka_train_input = 102;
+    DatahubServer datahub_train_input = 103;
+    HiveConfig hive_train_input = 104;
+    GraphLearnInput graph_train_input = 105;
   }
   oneof eval_path {
-    string eval_input_path = 3;
-    KafkaServer kafka_eval_input = 4;
-    DatahubServer datahub_eval_input = 13;
-    HiveConfig hive_eval_input= 22;
-
+    string eval_input_path = 201;
+    KafkaServer kafka_eval_input = 202;
+    DatahubServer datahub_eval_input = 203;
+    HiveConfig hive_eval_input = 204;
+    GraphLearnInput graph_eval_input = 205;
   }
 
   required string model_dir = 5;
diff --git a/easy_rec/python/test/hive_input_test.py b/easy_rec/python/test/hive_input_test.py
index 07d219c11..95cafe22b 100644
--- a/easy_rec/python/test/hive_input_test.py
+++ b/easy_rec/python/test/hive_input_test.py
@@ -1,20 +1,24 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
 """Define cv_input, the base class for cv tasks."""
-import tensorflow as tf
+import logging
+import os
 import unittest
+
+import tensorflow as tf
+from google.protobuf import text_format
+
 from easy_rec.python.input.hive_input import HiveInput
 from easy_rec.python.protos.dataset_pb2 import DatasetConfig
+from easy_rec.python.protos.feature_config_pb2 import FeatureConfig
+from easy_rec.python.protos.hive_config_pb2 import HiveConfig
+from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
+from easy_rec.python.utils import config_util
 from easy_rec.python.utils import test_utils
-from easy_rec.python.utils.config_util import *
-from easy_rec.python.utils.test_utils import *
 from easy_rec.python.utils.test_utils import _load_config_for_test
-from easy_rec.python.protos.hive_config_pb2 import HiveConfig
-import os
 
 if tf.__version__ >= '2.0':
-  #tf = tf.compat.v1
-  import tensorflow.compat.v1 as tf
+  tf = tf.compat.v1
 gfile = tf.gfile
 
@@ -28,10 +32,10 @@ class HiveInputTest(tf.test.TestCase):
 
   def _init_config(self):
-    hive_host = os.environ["hive_host"]
-    hive_username = os.environ["hive_username"]
-    hive_table_name = os.environ["hive_table_name"]
-    hive_hash_fields = os.environ["hive_hash_fields"]
+    hive_host = os.environ['hive_host']
+    hive_username = os.environ['hive_username']
+    hive_table_name = os.environ['hive_table_name']
+    hive_hash_fields = os.environ['hive_hash_fields']
 
     hive_train_input = """
       host: "{}"
@@ -40,7 +44,7 @@ def _init_config(self):
       limit_num: 500
       hash_fields: "{}"
     """.format(hive_host, hive_username, hive_table_name, hive_hash_fields)
-    hive_eval_input ="""
+    hive_eval_input = """
       host: "{}"
       username: "{}"
       table_name: "{}"
@@ -56,10 +60,11 @@ def _init_config(self):
 
   def __init__(self, methodName='HiveInputTest'):
     super(HiveInputTest, self).__init__(methodName=methodName)
 
-  @unittest.skipIf(
-      'hive_host' not in os.environ or 'hive_username' not in os.environ or
-      'hive_table_name' not in os.environ or 'hive_hash_fields' not in os.environ,
-      """Only execute hive_config var are specified,hive_host、
-      hive_username、hive_table_name、hive_hash_fields is available.""")
+  @unittest.skipIf('hive_host' not in os.environ or
+                   'hive_username' not in os.environ or
+                   'hive_table_name' not in os.environ or
+                   'hive_hash_fields' not in os.environ,
+                   'Only executed when the hive config vars hive_host, '
+                   'hive_username, hive_table_name and hive_hash_fields '
+                   'are all specified in the environment.')
   def test_hive_input(self):
     self._init_config()
@@ -244,23 +249,24 @@ def test_hive_input(self):
       feature_dict, label_dict = sess.run([features, labels])
       for key in feature_dict:
         print(key, feature_dict[key][:5])
-
+
       for key in label_dict:
         print(key, label_dict[key][:5])
     return 0
 
-  @unittest.skipIf(
-      'hive_host' not in os.environ or 'hive_username' not in os.environ or
-      'hive_table_name' not in os.environ or 'hive_hash_fields' not in os.environ,
-      """Only execute hive_config var are specified,hive_host、
-      hive_username、hive_table_name、hive_hash_fields is available.""")
+  @unittest.skipIf('hive_host' not in os.environ or
+                   'hive_username' not in os.environ or
+                   'hive_table_name' not in os.environ or
+                   'hive_hash_fields' not in os.environ,
+                   'Only executed when the hive config vars hive_host, '
+                   'hive_username, hive_table_name and hive_hash_fields '
+                   'are all specified in the environment.')
   def test_mmoe(self):
     pipeline_config_path = 'samples/emr_script/mmoe/mmoe_census_income.config'
-    gpus = get_available_gpus()
+    gpus = test_utils.get_available_gpus()
     if len(gpus) > 0:
-      set_gpu_id(gpus[0])
+      test_utils.set_gpu_id(gpus[0])
     else:
-      set_gpu_id(None)
+      test_utils.set_gpu_id(None)
 
     if not isinstance(pipeline_config_path, EasyRecConfig):
       logging.info('testing pipeline config %s' % pipeline_config_path)
 
@@ -270,7 +276,8 @@ def test_mmoe(self):
     if isinstance(pipeline_config_path, EasyRecConfig):
       pipeline_config = pipeline_config_path
     else:
-      pipeline_config = _load_config_for_test(pipeline_config_path, self._test_dir)
+      pipeline_config = _load_config_for_test(pipeline_config_path,
+                                              self._test_dir)
 
     pipeline_config.train_config.train_distribute = 0
     pipeline_config.train_config.num_gpus_per_worker = 1
@@ -278,10 +285,11 @@ def test_mmoe(self):
     config_util.save_pipeline_config(pipeline_config, self._test_dir)
     test_pipeline_config_path = os.path.join(self._test_dir, 'pipeline.config')
-    hyperparam_str = ""
+    hyperparam_str = ''
     train_cmd = 'python -m easy_rec.python.train_eval --pipeline_config_path %s %s' % (
-      test_pipeline_config_path, hyperparam_str)
-    proc = run_cmd(train_cmd, '%s/log_%s.txt' % (self._test_dir, 'master'))
+        test_pipeline_config_path, hyperparam_str)
+    proc = test_utils.run_cmd(train_cmd,
+                              '%s/log_%s.txt' % (self._test_dir, 'master'))
     proc.wait()
     if proc.returncode != 0:
       logging.error('train %s failed' % test_pipeline_config_path)
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index a3078e090..657123154 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -561,6 +561,10 @@ def test_sequence_fm(self):
                                            self._test_dir)
     self.assertTrue(self._success)
 
+  def test_eges(self):
+    self._success = test_utils.test_single_train_eval(
+        'samples/model_config/eges_on_taobao.config', self._test_dir)
+    self.assertTrue(self._success)
+
   def test_sequence_mmoe(self):
     self._success = test_utils.test_single_train_eval(
         'samples/model_config/mmoe_on_sequence_feature_taobao.config',
diff --git a/easy_rec/python/utils/graph_utils.py b/easy_rec/python/utils/graph_utils.py
new file mode 100644
index 000000000..0291ea4ac
--- /dev/null
+++ b/easy_rec/python/utils/graph_utils.py
@@ -0,0 +1,56 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import logging
+
+from easy_rec.python.utils import pai_util
+
+
+def graph_init(graph, tf_config=None):
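+  """Initialize a GraphLearn graph according to TF_CONFIG.
+
+  In ps mode, ps tasks join as GraphLearn servers while chief/worker/evaluator
+  tasks join as clients; in worker-only mode every task joins with its
+  task_index/task_count; without tf_config the graph is initialized locally.
+
+  Args:
+    graph: a graphlearn Graph instance.
+    tf_config: the TF_CONFIG json string, or None for local mode.
+  """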
+  if tf_config:
+    if isinstance(tf_config, str) or isinstance(tf_config, type(u'')):
+      tf_config = json.loads(tf_config)
+    if 'ps' in tf_config['cluster']:
+      # ps mode
+      ps_count = len(tf_config['cluster']['ps'])
+      evaluator_cnt = 1 if pai_util.has_evaluator() else 0
+      if evaluator_cnt == 0:
+        logging.warning(
+            'evaluator is not set as a client in GraphLearn; '
+            'if you actually set an evaluator in TF_CONFIG, please do: '
+            'export HAS_EVALUATOR=1.')
+      task_count = len(tf_config['cluster']['worker']) + 1 + evaluator_cnt
+      cluster = {
+          'server_count': ps_count,
+          'client_count': task_count,
+          'tracker': 'graph_test_tracker/'
+      }
+      if tf_config['task']['type'] in ['chief', 'master']:
+        graph.init(cluster=cluster, job_name='client', task_index=0)
+      elif tf_config['task']['type'] == 'worker':
+        graph.init(
+            cluster=cluster,
+            job_name='client',
+            task_index=tf_config['task']['index'] + 1 + evaluator_cnt)
+      elif tf_config['task']['type'] == 'evaluator':
+        graph.init(
+            cluster=cluster,
+            job_name='client',
+            task_index=tf_config['task']['index'] + 1)
+      elif tf_config['task']['type'] == 'ps':
+        graph.init(
+            cluster=cluster,
+            job_name='server',
+            task_index=tf_config['task']['index'])
+    else:
+      # worker mode: chief is client 0, workers follow at index + 1
+      task_count = len(tf_config['cluster']['worker']) + 1
+      if tf_config['task']['type'] in ['chief', 'master']:
+        graph.init(task_index=0, task_count=task_count)
+      elif tf_config['task']['type'] == 'worker':
+        graph.init(
+            task_index=tf_config['task']['index'] + 1,
+            task_count=task_count)
+  else:
+    # local mode
+    graph.init()
diff --git a/easy_rec/python/utils/pai_util.py b/easy_rec/python/utils/pai_util.py
index de7ce99aa..c4ddbe7c2 100644
--- a/easy_rec/python/utils/pai_util.py
+++ b/easy_rec/python/utils/pai_util.py
@@ -26,6 +26,15 @@ def set_on_pai():
   os.environ['IS_ON_PAI'] = '1'
 
 
+def set_has_evaluator():
+  logging.info('set environment variable: HAS_EVALUATOR')
+  os.environ['HAS_EVALUATOR'] = '1'
+
+
+def has_evaluator():
+  return 'HAS_EVALUATOR' in os.environ
+
+
 def download(url):
   _, fname = os.path.split(url)
   request = Request(url=url)
diff --git a/pai_jobs/run.py b/pai_jobs/run.py
index 0961f7eb9..2ea72c9c0 100644
--- a/pai_jobs/run.py
+++ b/pai_jobs/run.py
@@ -283,6 +283,10 @@ def main(argv):
   ], 'invalid evalaute_method: %s' % FLAGS.eval_method
   if FLAGS.with_evaluator:
     FLAGS.eval_method = 'separate'
+
+  if FLAGS.eval_method == 'separate':
+    pai_util.set_has_evaluator()
+
   num_worker = set_tf_config_and_get_train_worker_num(
       FLAGS.ps_hosts,
       FLAGS.worker_hosts,
diff --git a/samples/emr_script/mmoe/mmoe_census_income.config b/samples/emr_script/mmoe/mmoe_census_income.config
index a6e820ca7..9201af431 100644
--- a/samples/emr_script/mmoe/mmoe_census_income.config
+++ b/samples/emr_script/mmoe/mmoe_census_income.config
@@ -564,4 +564,4 @@ feature_configs {
   embedding_dim: 9
   hash_bucket_size: 400
   embedding_name: "feature"
-}
\ No newline at end of file
+}
diff --git a/samples/model_config/eges_on_taobao.config b/samples/model_config/eges_on_taobao.config
new file mode 100644
index 000000000..76bb005c0
--- /dev/null
+++ b/samples/model_config/eges_on_taobao.config
@@ -0,0 +1,136 @@
+model_dir: "experiments/eges_taobao_ckpt"
+
+graph_train_input {
+  node_inputs: 'data/test/graph/taobao_data/ad_feature_5k.csv'
+  edge_inputs: 'data/test/graph/taobao_data/graph_edges.txt'
+}
+
+graph_eval_input {
+  node_inputs: 'data/test/graph/taobao_data/ad_feature_5k.csv'
+  edge_inputs: 'data/test/graph/taobao_data/graph_edges.txt'
+}
+
+train_config {
+  log_step_count_steps: 100
+  optimizer_config: {
+    adam_optimizer: {
+      learning_rate: {
+        constant_learning_rate {
+          learning_rate: 1e-4
+        }
+      }
+    }
+    use_moving_average: false
+  }
+  save_checkpoints_steps: 500
+  save_summary_steps: 10
+  sync_replicas: false
+  num_steps: 2500
+}
+
+eval_config {
+}
+
+data_config {
+  input_fields {
+    input_name: 'adgroup_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'cate_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'campaign_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'customer'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'brand'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'price'
+    input_type: DOUBLE
+  }
+
+  graph_config {
+    random_walk_len: 10
+    window_size: 5
+    negative_num: 10
+    directed: true
+  }
+
+  batch_size: 64
+  num_epochs: 2
+  prefetch_size: 32
+  input_type: GraphInput
+}
+
+feature_config: {
+  features: {
+    input_names: 'adgroup_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'cate_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10000
+  }
+  features: {
+    input_names: 'campaign_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'customer'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'brand'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'price'
+    feature_type: RawFeature
+  }
+}
+
+model_config:{
+  model_class: "EGES"
+
+  feature_groups: {
+    group_name: "item"
+    feature_names: 'adgroup_id'
+    feature_names: 'cate_id'
+    feature_names: 'campaign_id'
+    feature_names: 'customer'
+    feature_names: 'brand'
+    feature_names: 'price'
+    wide_deep:DEEP
+  }
+  eges {
+    dnn {
+      hidden_units: [256, 128, 64, 32]
+    }
+    l2_regularization: 1e-6
+  }
+  loss_type: SOFTMAX_CROSS_ENTROPY
+  embedding_regularization: 0.0
+
+  group_as_scope: true
+}
+
+export_config {
+}
diff --git a/scripts/ci_test.sh b/scripts/ci_test.sh
index 1e6696353..1c06b159c 100755
--- a/scripts/ci_test.sh
+++ b/scripts/ci_test.sh
@@ -5,7 +5,11 @@ pip install oss2
 pip install -r requirements.txt
 
 # update/generate proto
-bash scripts/gen_proto.sh
+bash scripts/init.sh
+if [ $? -ne 0 ]
+then
+  exit 1
+fi
 
 export CUDA_VISIBLE_DEVICES=""
 export TEST_DEVICES=""
diff --git a/scripts/init.sh b/scripts/init.sh
index f6e266faf..20c23ab80 100644
--- a/scripts/init.sh
+++ b/scripts/init.sh
@@ -7,3 +7,35 @@ chmod a+rx .git/hooks/pre-commit
 
 # compile proto files
 source scripts/gen_proto.sh
+
+if [ $? -ne 0 ]
+then
+  echo "generate proto failed."
+  exit 1
+fi
+
+file_name=easyrec_data_20220304.tar.gz
+
+tmp_dir=$TMPDIR
+
+if [ -z "$tmp_dir" ]
+then
+  tmp_dir="/tmp"
+fi
+
+tmp_path="$tmp_dir/$file_name"
+if [ ! -e "$tmp_path" ]
+then
+  wget https://easyrec.oss-cn-beijing.aliyuncs.com/data/$file_name -O $tmp_path
+  if [ $? -ne 0 ]
+  then
+    echo "download data failed"
+    exit 1
+  fi
+fi
+tar -zvxf $tmp_path
+if [ $? -ne 0 ]
+then
+  echo "extract $file_name failed"
+  exit 1
+fi