Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add_dynamic_weight_for_muti_label #469

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/quick_start/mc_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pai -name easy_rec_ext -project algo_public
- -Dtables: 定义其他依赖表(可选),如负采样的表
- -Dcluster: 定义PS的数目和worker的数目。具体见:[PAI-TF任务参数介绍](https://help.aliyun.com/document_detail/154186.html?spm=a2c4g.11186623.4.3.e56f1adb7AJ9T5)
- -Deval_method: 评估方法
- separate: 用worker(task_id=1)做评估
- separate: 用worker(task_id=1)做评估。点击训练的logview中worker#1_0的stderr,出现类似字段"Saving dict for global step 3949: auc = 0.7643898, global_step = 3949, loss = 0.38898173, loss/loss/cross_entropy_loss = 0.38898173, loss/loss/total_loss = 0.38898173"即是评估指标
- none: 不需要评估
- master: 在master(task_id=0)上做评估
- -Dfine_tune_checkpoint: 可选,从checkpoint restore参数,进行finetune
Expand Down
18 changes: 17 additions & 1 deletion easy_rec/python/input/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self,
x.default_val for x in data_config.input_fields
]
self._label_fields = list(data_config.label_fields)
self._label_dynamic_weight = list(data_config.label_dynamic_weight)
self._feature_fields = list(data_config.feature_fields)
self._label_sep = list(data_config.label_sep)
self._label_dim = list(data_config.label_dim)
Expand Down Expand Up @@ -139,6 +140,8 @@ def __init__(self,
# add sample weight to effective fields
if self._data_config.HasField('sample_weight'):
self._effective_fields.append(self._data_config.sample_weight)
if len(self._label_dynamic_weight) > 0:
self._effective_fields.extend(self._label_dynamic_weight)

# add uid_field of GAUC and session_fields of SessionAUC
if self._pipeline_config is not None:
Expand Down Expand Up @@ -234,6 +237,7 @@ def get_feature_input_fields(self):
return [
x for x in self._input_fields
if x not in self._label_fields and x != self._data_config.sample_weight
and x not in self._label_dynamic_weight
]

def should_stop(self, curr_epoch):
Expand Down Expand Up @@ -269,13 +273,14 @@ def create_multi_placeholders(self, export_config):
effective_fids = [
fid for fid in range(len(self._input_fields))
if self._input_fields[fid] not in self._label_fields and
self._input_fields[fid] not in self._label_dynamic_weight and
self._input_fields[fid] != sample_weight_field
]

inputs = {}
for fid in effective_fids:
input_name = self._input_fields[fid]
if input_name == sample_weight_field:
if input_name == sample_weight_field or input_name in self._label_dynamic_weight:
continue
if placeholder_named_by_input:
placeholder_name = input_name
Expand Down Expand Up @@ -318,6 +323,7 @@ def create_placeholders(self, export_config):
effective_fids = [
fid for fid in range(len(self._input_fields))
if self._input_fields[fid] not in self._label_fields and
self._input_fields[fid] not in self._label_dynamic_weight and
self._input_fields[fid] != sample_weight_field
]
logging.info(
Expand All @@ -330,6 +336,8 @@ def create_placeholders(self, export_config):
ftype = self._input_field_types[fid]
tf_type = get_tf_type(ftype)
input_name = self._input_fields[fid]
if input_name in self._label_dynamic_weight:
continue
if tf_type in [tf.float32, tf.double, tf.int32, tf.int64]:
features[input_name] = tf.string_to_number(
input_vals[:, tmp_id],
Expand Down Expand Up @@ -925,6 +933,14 @@ def _preprocess(self, field_dict):
if self._mode != tf.estimator.ModeKeys.PREDICT:
parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
self._data_config.sample_weight]
if len(self._label_dynamic_weight
) > 0 and self._mode != tf.estimator.ModeKeys.PREDICT:
for label_weight in self._label_dynamic_weight:
if field_dict[label_weight].dtype == tf.float32:
parsed_dict[label_weight] = field_dict[label_weight]
else:
parsed_dict[label_weight] = tf.cast(
field_dict[label_weight], dtype=tf.float64)

if Input.DATA_OFFSET in field_dict:
parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
Expand Down
2 changes: 2 additions & 0 deletions easy_rec/python/model/multi_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ def build_loss_graph(self):
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
loss_weight = task_tower_cfg.weight
if task_tower_cfg.HasField('dynamic_weight'):
loss_weight *= self._feature_dict[task_tower_cfg.dynamic_weight]
if task_tower_cfg.use_sample_weight:
loss_weight *= self._sample_weight

Expand Down
2 changes: 2 additions & 0 deletions easy_rec/python/protos/dataset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ message DatasetConfig {

// input field for sample weight
optional string sample_weight = 22;
// input field for label dynimic weight
repeated string label_dynamic_weight = 27;
// the compression type of tfrecord
optional string data_compression_type = 23 [default = ''];

Expand Down
4 changes: 4 additions & 0 deletions easy_rec/python/protos/tower.proto
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ message TaskTower {
optional DNN dnn = 6;
// training loss weights
optional float weight = 7 [default = 1.0];
// training loss label dynamic weights
optional string dynamic_weight = 8;
// label name for indicating the sample space for the task tower
optional string task_space_indicator_label = 10;
// the loss weight for sample in the task space
Expand Down Expand Up @@ -72,4 +74,6 @@ message BayesTaskTower {
repeated Loss losses = 15;
// whether to use sample weight in this tower
required bool use_sample_weight = 16 [default = true];
// training loss label dynamic weights
optional string dynamic_weight = 17;
};
12 changes: 12 additions & 0 deletions easy_rec/python/test/train_eval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,12 +937,24 @@ def test_sequence_esmm(self):
self._test_dir)
self.assertTrue(self._success)

def test_label_dynamic_weight_esmm(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/esmm_on_label_dynamic_weight_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_sequence_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_label_dynamic_weight_sequence_mmoe(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/mmoe_on_label_dynamic_weight_sequence_feature_taobao.config',
self._test_dir)
self.assertTrue(self._success)

def test_sequence_ple(self):
self._success = test_utils.test_single_train_eval(
'samples/model_config/ple_on_sequence_feature_taobao.config',
Expand Down
Loading