From cbec53934cca03039cdb711650969588c2411b0c Mon Sep 17 00:00:00 2001 From: Eric Ge Date: Mon, 9 Sep 2024 11:45:19 +0800 Subject: [PATCH] Add temperature coefficient for DSSM model (#482) * add temperature coefficient for DSSM model * github actions_runner test * fix code style check * do normalization when use cosine loss * remove labelcomment in code style check * fix simi func in test case dssm_distribute_eval_reg --- .github/workflows/code_style.yml | 78 ------------------- easy_rec/python/model/dssm.py | 8 +- easy_rec/python/protos/dssm.proto | 2 + easy_rec/python/test/run.py | 2 +- .../dssm_distribute_eval_reg_on_taobao.config | 2 + 5 files changed, 10 insertions(+), 82 deletions(-) diff --git a/.github/workflows/code_style.yml b/.github/workflows/code_style.yml index 80316c4a9..7a7a79d83 100644 --- a/.github/workflows/code_style.yml +++ b/.github/workflows/code_style.yml @@ -36,85 +36,7 @@ jobs: echo "ci_test_passed=0" >> $GITHUB_OUTPUT fi fi - - name: LabelAndComment - env: - CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}} - uses: actions/github-script@v5 - with: - script: | - const { CI_TEST_PASSED } = process.env - labels = await github.rest.issues.listLabelsOnIssue({ - issue_number: context.issue.number, - repo:context.repo.repo, - owner:context.repo.owner - }) - console.log('labels.url=' + labels.url) - - labels = labels.data - - var label_names = [] - if (labels != null) { - labels.forEach(tmp_lbl => label_names.push(tmp_lbl.name)) - } - console.log(`ci_test_passed=${CI_TEST_PASSED} labels=${label_names}`); - - var pass_label = null; - if (labels != null) { - pass_label = labels.find(label=>label.name=='code_style_test_passed'); - } - - var fail_label = null; - if (labels != null) { - fail_label = labels.find(label=>label.name=='code_style_test_failed'); - } - - if (pass_label) { - github.rest.issues.removeLabel({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - name: 
'code_style_test_passed' - }) - } - - if (fail_label) { - github.rest.issues.removeLabel({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - name: 'code_style_test_failed' - }) - } - - if (CI_TEST_PASSED == 1) { - github.rest.issues.addLabels({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - labels: ['code_style_test_passed'] - }) - - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: "Code Style Test Passed" - }) - } else { - github.rest.issues.addLabels({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - labels: ['code_style_test_failed'] - }) - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: "Code Style Test Failed" - }) - } - name: SignalFail env: CI_TEST_PASSED: ${{steps.run_ci_test.outputs.ci_test_passed}} diff --git a/easy_rec/python/model/dssm.py b/easy_rec/python/model/dssm.py index 334850d94..c80d49c95 100644 --- a/easy_rec/python/model/dssm.py +++ b/easy_rec/python/model/dssm.py @@ -61,12 +61,14 @@ def build_predict_graph(self): kernel_regularizer=self._l2_reg, name='item_dnn/dnn_%d' % (num_item_dnn_layer - 1)) - if self._loss_type == LossType.CLASSIFICATION: - if self._model_config.simi_func == Similarity.COSINE: + if self._model_config.simi_func == Similarity.COSINE: user_tower_emb = self.norm(user_tower_emb) item_tower_emb = self.norm(item_tower_emb) + temperature = self._model_config.temperature + else: + temperature = 1.0 - user_item_sim = self.sim(user_tower_emb, item_tower_emb) + user_item_sim = self.sim(user_tower_emb, item_tower_emb) / temperature if self._model_config.scale_simi: sim_w = tf.get_variable( 'sim_w', diff --git a/easy_rec/python/protos/dssm.proto b/easy_rec/python/protos/dssm.proto index 64c2744eb..c0015a28e 100644 --- 
a/easy_rec/python/protos/dssm.proto +++ b/easy_rec/python/protos/dssm.proto @@ -20,4 +20,6 @@ message DSSM { optional bool scale_simi = 5 [default = true]; optional string item_id = 9; required bool ignore_in_batch_neg_sam = 10 [default = false]; + // temperature coefficient that divides the user/item similarity; only applied when simi_func is COSINE + optional float temperature = 11 [default = 1.0]; } diff --git a/easy_rec/python/test/run.py b/easy_rec/python/test/run.py index e8ecc61d0..0c7ac4c79 100644 --- a/easy_rec/python/test/run.py +++ b/easy_rec/python/test/run.py @@ -80,7 +80,7 @@ def main(argv): (len(all_tests), test_dir)) max_num_port_per_proc = 3 - total_port_num = (max_num_port_per_proc + 2) * FLAGS.num_parallel + total_port_num = (max_num_port_per_proc + 2) * FLAGS.num_parallel * 10 all_available_ports = test_utils.get_ports_base(total_port_num).tolist() procs = {} diff --git a/samples/model_config/dssm_distribute_eval_reg_on_taobao.config b/samples/model_config/dssm_distribute_eval_reg_on_taobao.config index a9c8189e1..b8dd91fe4 100644 --- a/samples/model_config/dssm_distribute_eval_reg_on_taobao.config +++ b/samples/model_config/dssm_distribute_eval_reg_on_taobao.config @@ -271,9 +271,11 @@ model_config:{ } } l2_regularization: 1e-6 + simi_func: INNER_PRODUCT } embedding_regularization: 5e-5 loss_type: L2_LOSS + } export_config {