diff --git a/Mathlib/Data/Nat/Bitwise.lean b/Mathlib/Data/Nat/Bitwise.lean index 2a2306e5f1f71..ea2eb8e119049 100644 --- a/Mathlib/Data/Nat/Bitwise.lean +++ b/Mathlib/Data/Nat/Bitwise.lean @@ -344,40 +344,39 @@ theorem xor_eq_zero {n m : ℕ} : n ^^^ m = 0 ↔ n = m := by theorem xor_ne_zero {n m : ℕ} : n ^^^ m ≠ 0 ↔ n ≠ m := xor_eq_zero.not -theorem xor_trichotomy {a b c : ℕ} (h : a ≠ b ^^^ c) : - b ^^^ c < a ∨ a ^^^ c < b ∨ a ^^^ b < c := by - set v := a ^^^ (b ^^^ c) with hv +theorem xor_trichotomy {a b c : ℕ} (h : a ^^^ b ^^^ c ≠ 0) : + b ^^^ c < a ∨ c ^^^ a < b ∨ a ^^^ b < c := by + set v := a ^^^ b ^^^ c with hv -- The xor of any two of `a`, `b`, `c` is the xor of `v` and the third. have hab : a ^^^ b = c ^^^ v := by - rw [hv] - conv_rhs => - rw [Nat.xor_comm] - simp [Nat.xor_assoc] - have hac : a ^^^ c = b ^^^ v := by - rw [hv] - conv_rhs => - right - rw [← Nat.xor_comm] - rw [← Nat.xor_assoc, ← Nat.xor_assoc, xor_self, zero_xor, Nat.xor_comm] - have hbc : b ^^^ c = a ^^^ v := by simp [hv, ← Nat.xor_assoc] + rw [Nat.xor_comm c, xor_cancel_right] + have hbc : b ^^^ c = a ^^^ v := by + rw [← Nat.xor_assoc, xor_cancel_left] + have hca : c ^^^ a = b ^^^ v := by + rw [hv, Nat.xor_assoc, Nat.xor_comm a, ← Nat.xor_assoc, xor_cancel_left] -- If `i` is the position of the most significant bit of `v`, then at least one of `a`, `b`, `c` -- has a one bit at position `i`. - obtain ⟨i, ⟨hi, hi'⟩⟩ := exists_most_significant_bit (xor_ne_zero.2 h) - have : testBit a i = true ∨ testBit b i = true ∨ testBit c i = true := by + obtain ⟨i, ⟨hi, hi'⟩⟩ := exists_most_significant_bit h + have : testBit a i ∨ testBit b i ∨ testBit c i := by contrapose! hi - simp only [Bool.eq_false_eq_not_eq_true, Ne, testBit_xor, Bool.bne_eq_xor] at hi ⊢ - rw [hi.1, hi.2.1, hi.2.2, Bool.xor_false, Bool.xor_false] + simp_rw [Bool.eq_false_eq_not_eq_true] at hi ⊢ + rw [testBit_xor, testBit_xor, hi.1, hi.2.1, hi.2.2] + rfl -- If, say, `a` has a one bit at position `i`, then `a xor v` has a zero bit at position `i`, but -- the same bits as `a` in positions greater than `j`, so `a xor v < a`. - rcases this with (h | h | h) + obtain h | h | h := this on_goal 1 => left; rw [hbc] - on_goal 2 => right; left; rw [hac] + on_goal 2 => right; left; rw [hca] on_goal 3 => right; right; rw [hab] all_goals - exact lt_of_testBit i (by simp [h, hi]) h fun j hj => by simp [hi' _ hj] - -theorem lt_xor_cases {a b c : ℕ} (h : a < b ^^^ c) : a ^^^ c < b ∨ a ^^^ b < c := - (or_iff_right fun h' => (h.asymm h').elim).1 <| xor_trichotomy h.ne + refine lt_of_testBit i ?_ h fun j hj => ?_ + · rw [testBit_xor, h, hi] + rfl + · simp only [testBit_xor, hi' _ hj, Bool.bne_false] + +theorem lt_xor_cases {a b c : ℕ} (h : a < b ^^^ c) : a ^^^ c < b ∨ a ^^^ b < c := by + obtain ha | hb | hc := xor_trichotomy <| Nat.xor_assoc _ _ _ ▸ xor_ne_zero.2 h.ne + exacts [(h.asymm ha).elim, Or.inl <| Nat.xor_comm _ _ ▸ hb, Or.inr hc] @[simp] theorem bit_lt_two_pow_succ_iff {b x n} : bit b x < 2 ^ (n + 1) ↔ x < 2 ^ n := by cases b <;> simp <;> omega