diff --git a/core-api/src/main/java/com/optimizely/ab/notification/NotificationManager.java b/core-api/src/main/java/com/optimizely/ab/notification/NotificationManager.java index 5254d76b8..7415e6b23 100644 --- a/core-api/src/main/java/com/optimizely/ab/notification/NotificationManager.java +++ b/core-api/src/main/java/com/optimizely/ab/notification/NotificationManager.java @@ -19,6 +19,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -33,7 +34,7 @@ public class NotificationManager { private static final Logger logger = LoggerFactory.getLogger(NotificationManager.class); - private final Map> handlers = new LinkedHashMap<>(); + private final Map> handlers = Collections.synchronizedMap(new LinkedHashMap<>()); private final AtomicInteger counter; public NotificationManager() { @@ -47,10 +48,12 @@ public NotificationManager(AtomicInteger counter) { public int addHandler(NotificationHandler newHandler) { // Prevent registering a duplicate listener. - for (NotificationHandler handler: handlers.values()) { - if (handler.equals(newHandler)) { - logger.warn("Notification listener was already added"); - return -1; + synchronized (handlers) { + for (NotificationHandler handler : handlers.values()) { + if (handler.equals(newHandler)) { + logger.warn("Notification listener was already added"); + return -1; + } } } @@ -61,11 +64,13 @@ public int addHandler(NotificationHandler newHandler) { } public void send(final T message) { - for (Map.Entry> handler: handlers.entrySet()) { - try { - handler.getValue().handle(message); - } catch (Exception e) { - logger.warn("Catching exception sending notification for class: {}, handler: {}", message.getClass(), handler.getKey()); + synchronized (handlers) { + for (Map.Entry> handler: handlers.entrySet()) { + try { + handler.getValue().handle(message); + } catch (Exception e) { + logger.warn("Catching exception sending notification for class: {}, handler: {}", message.getClass(), handler.getKey()); + } } } } diff --git a/core-api/src/test/java/com/optimizely/ab/notification/NotificationManagerTest.java b/core-api/src/test/java/com/optimizely/ab/notification/NotificationManagerTest.java index c51a84e3f..58767ac7a 100644 --- a/core-api/src/test/java/com/optimizely/ab/notification/NotificationManagerTest.java +++ b/core-api/src/test/java/com/optimizely/ab/notification/NotificationManagerTest.java @@ -20,6 +20,11 @@ import org.junit.Test; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.*; @@ -70,4 +75,32 @@ public void testSendWithError() { assertEquals(1, messages.size()); assertEquals("message1", messages.get(0).getMessage()); } + + @Test + public void testThreadSafety() throws InterruptedException { + int numThreads = 10; + int numRepeats = 2; + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + CountDownLatch latch = new CountDownLatch(numThreads); + AtomicBoolean failedAlready = new AtomicBoolean(false); + + for(int i = 0; i < numThreads; i++) { + executor.execute(() -> { + try { + for (int j = 0; j < numRepeats; j++) { + if(!failedAlready.get()) { + notificationManager.addHandler(new TestNotificationHandler<>()); + notificationManager.send(new TestNotification("message1")); + } + } + } catch (Exception e) { + failedAlready.set(true); + } finally { + latch.countDown(); + } + }); + } + assertTrue(latch.await(10, TimeUnit.SECONDS)); + assertEquals(numThreads * numRepeats, notificationManager.size()); + } }