From 74313816826b373170c1fa095a0cae5ab6217224 Mon Sep 17 00:00:00 2001
From: Penghui Li <penghui@apache.org>
Date: Wed, 11 Jan 2023 11:01:13 +0800
Subject: [PATCH] [improve][broker] Add ref count for sticky hash to optimize
 the performance of Key_Shared subscription (#19167)

---
 .../MessageRedeliveryController.java          | 72 ++++++++++++++-----
 .../MessageRedeliveryControllerTest.java      | 34 +++++++++
 2 files changed, 88 insertions(+), 18 deletions(-)

diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryController.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryController.java
index d8667def5526d..5bf3f5506fa81 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryController.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryController.java
@@ -23,22 +23,36 @@
 import java.util.List;
 import java.util.NavigableSet;
 import java.util.Set;
-import java.util.concurrent.atomic.AtomicBoolean;
+import javax.annotation.concurrent.NotThreadSafe;
 import org.apache.bookkeeper.mledger.impl.PositionImpl;
+import org.apache.bookkeeper.util.collections.ConcurrentLongLongHashMap;
 import org.apache.pulsar.common.util.collections.ConcurrentLongLongPairHashMap;
 import org.apache.pulsar.common.util.collections.ConcurrentLongLongPairHashMap.LongPair;
 import org.apache.pulsar.utils.ConcurrentBitmapSortedLongPairSet;
 
+/**
+ * The MessageRedeliveryController is a non-thread-safe container for maintaining the redelivery messages.
+ */
+@NotThreadSafe
 public class MessageRedeliveryController {
+
+    private final boolean allowOutOfOrderDelivery;
     private final ConcurrentBitmapSortedLongPairSet messagesToRedeliver;
     private final ConcurrentLongLongPairHashMap hashesToBeBlocked;
+    private final ConcurrentLongLongHashMap hashesRefCount;
 
     public MessageRedeliveryController(boolean allowOutOfOrderDelivery) {
+        this.allowOutOfOrderDelivery = allowOutOfOrderDelivery;
         this.messagesToRedeliver = new ConcurrentBitmapSortedLongPairSet();
-        this.hashesToBeBlocked = allowOutOfOrderDelivery
-                ? null
-                : ConcurrentLongLongPairHashMap
+        if (!allowOutOfOrderDelivery) {
+            this.hashesToBeBlocked = ConcurrentLongLongPairHashMap
+                    .newBuilder().concurrencyLevel(2).expectedItems(128).autoShrink(true).build();
+            this.hashesRefCount = ConcurrentLongLongHashMap
                     .newBuilder().concurrencyLevel(2).expectedItems(128).autoShrink(true).build();
+        } else {
+            this.hashesToBeBlocked = null;
+            this.hashesRefCount = null;
+        }
     }
 
     public void add(long ledgerId, long entryId) {
@@ -46,21 +60,43 @@ public void add(long ledgerId, long entryId) {
     }
 
     public void add(long ledgerId, long entryId, long stickyKeyHash) {
-        if (hashesToBeBlocked != null) {
-            hashesToBeBlocked.put(ledgerId, entryId, stickyKeyHash, 0);
+        if (!allowOutOfOrderDelivery) {
+            boolean inserted = hashesToBeBlocked.putIfAbsent(ledgerId, entryId, stickyKeyHash, 0);
+            if (!inserted) {
+                hashesToBeBlocked.put(ledgerId, entryId, stickyKeyHash, 0);
+            } else {
+                // Return -1 means the key was not present
+                long stored = hashesRefCount.get(stickyKeyHash);
+                hashesRefCount.put(stickyKeyHash, stored > 0 ? ++stored : 1);
+            }
         }
         messagesToRedeliver.add(ledgerId, entryId);
     }
 
     public void remove(long ledgerId, long entryId) {
-        if (hashesToBeBlocked != null) {
-            hashesToBeBlocked.remove(ledgerId, entryId);
+        if (!allowOutOfOrderDelivery) {
+            removeFromHashBlocker(ledgerId, entryId);
         }
         messagesToRedeliver.remove(ledgerId, entryId);
     }
 
+    private void removeFromHashBlocker(long ledgerId, long entryId) {
+        LongPair value = hashesToBeBlocked.get(ledgerId, entryId);
+        if (value != null) {
+            boolean removed = hashesToBeBlocked.remove(ledgerId, entryId, value.first, 0);
+            if (removed) {
+                long exists = hashesRefCount.get(value.first);
+                if (exists == 1) {
+                    hashesRefCount.remove(value.first, exists);
+                } else if (exists > 0) {
+                    hashesRefCount.put(value.first, exists - 1);
+                }
+            }
+        }
+    }
+
     public void removeAllUpTo(long markDeleteLedgerId, long markDeleteEntryId) {
-        if (hashesToBeBlocked != null) {
+        if (!allowOutOfOrderDelivery) {
             List<LongPair> keysToRemove = new ArrayList<>();
             hashesToBeBlocked.forEach((ledgerId, entryId, stickyKeyHash, none) -> {
                 if (ComparisonChain.start().compare(ledgerId, markDeleteLedgerId).compare(entryId, markDeleteEntryId)
@@ -68,7 +104,7 @@ public void removeAllUpTo(long markDeleteLedgerId, long markDeleteEntryId) {
                     keysToRemove.add(new LongPair(ledgerId, entryId));
                 }
             });
-            keysToRemove.forEach(longPair -> hashesToBeBlocked.remove(longPair.first, longPair.second));
+            keysToRemove.forEach(longPair -> removeFromHashBlocker(longPair.first, longPair.second));
             keysToRemove.clear();
         }
         messagesToRedeliver.removeUpTo(markDeleteLedgerId, markDeleteEntryId + 1);
@@ -79,8 +115,9 @@ public boolean isEmpty() {
     }
 
     public void clear() {
-        if (hashesToBeBlocked != null) {
+        if (!allowOutOfOrderDelivery) {
             hashesToBeBlocked.clear();
+            hashesRefCount.clear();
         }
         messagesToRedeliver.clear();
     }
@@ -90,15 +127,14 @@ public String toString() {
     }
 
     public boolean containsStickyKeyHashes(Set<Integer> stickyKeyHashes) {
-        final AtomicBoolean isContained = new AtomicBoolean(false);
-        if (hashesToBeBlocked != null) {
-            hashesToBeBlocked.forEach((ledgerId, entryId, stickyKeyHash, none) -> {
-                if (!isContained.get() && stickyKeyHashes.contains((int) stickyKeyHash)) {
-                    isContained.set(true);
+        if (!allowOutOfOrderDelivery) {
+            for (Integer stickyKeyHash : stickyKeyHashes) {
+                if (hashesRefCount.containsKey(stickyKeyHash)) {
+                    return true;
                 }
-            });
+            }
         }
-        return isContained.get();
+        return false;
     }
 
     public NavigableSet<PositionImpl> getMessagesToReplayNow(int maxMessagesToRead) {
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryControllerTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryControllerTest.java
index f19fff08de6a7..be5294d1c0f63 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryControllerTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/MessageRedeliveryControllerTest.java
@@ -28,6 +28,7 @@
 import java.util.Set;
 import java.util.TreeSet;
 import org.apache.bookkeeper.mledger.impl.PositionImpl;
+import org.apache.bookkeeper.util.collections.ConcurrentLongLongHashMap;
 import org.apache.pulsar.utils.ConcurrentBitmapSortedLongPairSet;
 import org.apache.pulsar.common.util.collections.ConcurrentLongLongPairHashMap;
 import org.testng.annotations.DataProvider;
@@ -54,16 +55,23 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
         ConcurrentLongLongPairHashMap hashesToBeBlocked = (ConcurrentLongLongPairHashMap) hashesToBeBlockedField
                 .get(controller);
 
+        Field hashesRefCountField = MessageRedeliveryController.class.getDeclaredField("hashesRefCount");
+        hashesRefCountField.setAccessible(true);
+        ConcurrentLongLongHashMap hashesRefCount = (ConcurrentLongLongHashMap) hashesRefCountField.get(controller);
+
         if (allowOutOfOrderDelivery) {
             assertNull(hashesToBeBlocked);
+            assertNull(hashesRefCount);
         } else {
             assertNotNull(hashesToBeBlocked);
+            assertNotNull(hashesRefCount);
         }
 
         assertTrue(controller.isEmpty());
         assertEquals(messagesToRedeliver.size(), 0);
         if (!allowOutOfOrderDelivery) {
             assertEquals(hashesToBeBlocked.size(), 0);
+            assertEquals(hashesRefCount.size(), 0);
         }
 
         controller.add(1, 1);
@@ -77,6 +85,7 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
             assertEquals(hashesToBeBlocked.size(), 0);
             assertFalse(hashesToBeBlocked.containsKey(1, 1));
             assertFalse(hashesToBeBlocked.containsKey(1, 2));
+            assertEquals(hashesRefCount.size(), 0);
         }
 
         controller.remove(1, 1);
@@ -88,6 +97,7 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
         assertFalse(messagesToRedeliver.contains(1, 2));
         if (!allowOutOfOrderDelivery) {
             assertEquals(hashesToBeBlocked.size(), 0);
+            assertEquals(hashesRefCount.size(), 0);
         }
 
         controller.add(2, 1, 100);
@@ -104,6 +114,20 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
             assertEquals(hashesToBeBlocked.get(2, 1).first, 100);
             assertEquals(hashesToBeBlocked.get(2, 2).first, 101);
             assertEquals(hashesToBeBlocked.get(2, 3).first, 101);
+            assertEquals(hashesRefCount.size(), 2);
+            assertEquals(hashesRefCount.get(100), 1);
+            assertEquals(hashesRefCount.get(101), 2);
+        }
+
+        controller.remove(2, 1);
+        controller.remove(2, 2);
+
+        if (!allowOutOfOrderDelivery) {
+            assertEquals(hashesToBeBlocked.size(), 1);
+            assertEquals(hashesToBeBlocked.get(2, 3).first, 101);
+            assertEquals(hashesRefCount.size(), 1);
+            assertEquals(hashesRefCount.get(100), -1);
+            assertEquals(hashesRefCount.get(101), 1);
         }
 
         controller.clear();
@@ -113,6 +137,8 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
         if (!allowOutOfOrderDelivery) {
             assertEquals(hashesToBeBlocked.size(), 0);
             assertTrue(hashesToBeBlocked.isEmpty());
+            assertEquals(hashesRefCount.size(), 0);
+            assertTrue(hashesRefCount.isEmpty());
         }
 
         controller.add(2, 2, 201);
@@ -135,6 +161,11 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
             assertEquals(hashesToBeBlocked.get(2, 2).first, 201);
             assertEquals(hashesToBeBlocked.get(3, 1).first, 300);
             assertEquals(hashesToBeBlocked.get(3, 2).first, 301);
+            assertEquals(hashesRefCount.size(), 4);
+            assertEquals(hashesRefCount.get(200), 1);
+            assertEquals(hashesRefCount.get(201), 1);
+            assertEquals(hashesRefCount.get(300), 1);
+            assertEquals(hashesRefCount.get(301), 1);
         }
 
         controller.removeAllUpTo(3, 1);
@@ -143,6 +174,8 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
         if (!allowOutOfOrderDelivery) {
             assertEquals(hashesToBeBlocked.size(), 1);
             assertEquals(hashesToBeBlocked.get(3, 2).first, 301);
+            assertEquals(hashesRefCount.size(), 1);
+            assertEquals(hashesRefCount.get(301), 1);
         }
 
         controller.removeAllUpTo(5, 10);
@@ -150,6 +183,7 @@ public void testAddAndRemove(boolean allowOutOfOrderDelivery) throws Exception {
         assertEquals(messagesToRedeliver.size(), 0);
         if (!allowOutOfOrderDelivery) {
             assertEquals(hashesToBeBlocked.size(), 0);
+            assertEquals(hashesRefCount.size(), 0);
         }
     }