diff --git a/cerberus-audit-logger-athena/src/main/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicy.java b/cerberus-audit-logger-athena/src/main/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicy.java index 42c97c976..2b1353fad 100644 --- a/cerberus-audit-logger-athena/src/main/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicy.java +++ b/cerberus-audit-logger-athena/src/main/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicy.java @@ -20,6 +20,7 @@ import com.nike.internal.util.StringUtils; import java.util.concurrent.LinkedBlockingQueue; import java.util.stream.Stream; +import lombok.Setter; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; @@ -30,7 +31,7 @@ public class AuditLogsS3TimeBasedRollingPolicy extends TimeBasedRollingPolicy private final String bucket; private final String bucketRegion; - private LinkedBlockingQueue logChunkFileS3Queue = new LinkedBlockingQueue<>(); + @Setter private LinkedBlockingQueue logChunkFileS3Queue = new LinkedBlockingQueue<>(); private S3LogUploaderService s3LogUploaderService = null; @Autowired @@ -54,7 +55,7 @@ private boolean isS3AuditLogCopyingEnabled() { @Override public void rollover() throws RolloverFailure { - super.rollover(); + superRollOver(); if (isS3AuditLogCopyingEnabled()) { String filename = timeBasedFileNamingAndTriggeringPolicy.getElapsedPeriodsFileName() + ".gz"; @@ -65,4 +66,9 @@ public void rollover() throws RolloverFailure { } } } + + void superRollOver() { + super.rollover(); + ; + } } diff --git a/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListener.java b/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListener.java index 310e75f34..dc34da812 100644 --- a/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListener.java +++ b/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListener.java @@ -71,17 +71,12 @@ public void onApplicationEvent(AuditableEvent event) { .put( "principal_type", cerberusPrincipal - .map(p -> cerberusPrincipal.get().getPrincipalType().getName()) + .map(p -> p.getPrincipalType().getName()) .orElse(AuditableEventContext.UNKNOWN)) .put( "principal_token_created", cerberusPrincipal - .map( - p -> - cerberusPrincipal - .get() - .getCreated() - .format(ATHENA_DATE_FORMATTER)) + .map(p -> p.getCreated().format(ATHENA_DATE_FORMATTER)) .orElseGet( () -> OffsetDateTime.parse(PARTY_LIKE_ITS_99, ISO_OFFSET_DATE_TIME) @@ -89,12 +84,7 @@ public void onApplicationEvent(AuditableEvent event) { .put( "principal_token_expires", cerberusPrincipal - .map( - p -> - cerberusPrincipal - .get() - .getExpires() - .format(ATHENA_DATE_FORMATTER)) + .map(p -> p.getExpires().format(ATHENA_DATE_FORMATTER)) .orElseGet( () -> OffsetDateTime.parse(PARTY_LIKE_ITS_99, ISO_OFFSET_DATE_TIME) diff --git a/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/service/S3LogUploaderService.java b/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/service/S3LogUploaderService.java index fad919816..fcdc87e0a 100644 --- a/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/service/S3LogUploaderService.java +++ b/cerberus-audit-logger-athena/src/main/java/com/nike/cerberus/audit/logger/service/S3LogUploaderService.java @@ -31,6 +31,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.annotation.PreDestroy; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -52,7 +53,7 @@ public class S3LogUploaderService { private static final String ATHENA_LOG_NAME = "athena-audit-logger"; private static final String ATHENA_LOG_APPENDER = "athena-log-appender"; - private final ExecutorService executor = Executors.newSingleThreadExecutor(); + @Setter private ExecutorService executor = Executors.newSingleThreadExecutor(); private final AmazonS3 amazonS3; private final String bucket; private final String bucketRegion; diff --git a/cerberus-audit-logger-athena/src/test/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicyTest.java b/cerberus-audit-logger-athena/src/test/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicyTest.java new file mode 100644 index 000000000..3cdb368b0 --- /dev/null +++ b/cerberus-audit-logger-athena/src/test/java/ch/qos/logback/core/rolling/AuditLogsS3TimeBasedRollingPolicyTest.java @@ -0,0 +1,76 @@ +package ch.qos.logback.core.rolling; + +import com.nike.cerberus.audit.logger.service.S3LogUploaderService; +import java.util.concurrent.LinkedBlockingQueue; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +public class AuditLogsS3TimeBasedRollingPolicyTest { + + private AuditLogsS3TimeBasedRollingPolicy auditLogsS3TimeBasedRollingPolicy; + + @Test + public void testLogUploaderServiceIfLogChunkFileS3QueueIsEmpty() { + auditLogsS3TimeBasedRollingPolicy = + new AuditLogsS3TimeBasedRollingPolicy("bucket", "bucketRegion"); + S3LogUploaderService s3LogUploaderService = Mockito.mock(S3LogUploaderService.class); + auditLogsS3TimeBasedRollingPolicy.setS3LogUploaderService(s3LogUploaderService); + Mockito.verify(s3LogUploaderService, Mockito.never()).ingestLog(Mockito.anyString()); + } + + @Test + public void testRollOverIfAuditCopyIsNotEnabled() { + auditLogsS3TimeBasedRollingPolicy = Mockito.spy(new AuditLogsS3TimeBasedRollingPolicy("", "")); + Mockito.doNothing().when(auditLogsS3TimeBasedRollingPolicy).superRollOver(); + S3LogUploaderService s3LogUploaderService = Mockito.mock(S3LogUploaderService.class); + auditLogsS3TimeBasedRollingPolicy.setS3LogUploaderService(s3LogUploaderService); + TimeBasedFileNamingAndTriggeringPolicy timeBasedFileNamingAndTriggeringPolicy = + Mockito.spy(TimeBasedFileNamingAndTriggeringPolicy.class); + auditLogsS3TimeBasedRollingPolicy.setTimeBasedFileNamingAndTriggeringPolicy( + timeBasedFileNamingAndTriggeringPolicy); + auditLogsS3TimeBasedRollingPolicy.rollover(); + Mockito.verify(timeBasedFileNamingAndTriggeringPolicy, Mockito.never()) + .getElapsedPeriodsFileName(); + Mockito.verify(s3LogUploaderService, Mockito.never()).ingestLog(Mockito.anyString()); + } + + @Test + public void testRollOverIfAuditCopyIsEnabled() { + auditLogsS3TimeBasedRollingPolicy = + Mockito.spy(new AuditLogsS3TimeBasedRollingPolicy("bucket", "region")); + Mockito.doNothing().when(auditLogsS3TimeBasedRollingPolicy).superRollOver(); + S3LogUploaderService s3LogUploaderService = Mockito.mock(S3LogUploaderService.class); + auditLogsS3TimeBasedRollingPolicy.setS3LogUploaderService(s3LogUploaderService); + TimeBasedFileNamingAndTriggeringPolicy timeBasedFileNamingAndTriggeringPolicy = + Mockito.spy(TimeBasedFileNamingAndTriggeringPolicy.class); + Mockito.when(timeBasedFileNamingAndTriggeringPolicy.getElapsedPeriodsFileName()) + .thenReturn("elapsedfilename"); + auditLogsS3TimeBasedRollingPolicy.setTimeBasedFileNamingAndTriggeringPolicy( + timeBasedFileNamingAndTriggeringPolicy); + LinkedBlockingQueue logChunkFileS3Queue = new LinkedBlockingQueue<>(); + auditLogsS3TimeBasedRollingPolicy.setLogChunkFileS3Queue(logChunkFileS3Queue); + auditLogsS3TimeBasedRollingPolicy.rollover(); + Mockito.verify(timeBasedFileNamingAndTriggeringPolicy).getElapsedPeriodsFileName(); + Mockito.verify(s3LogUploaderService).ingestLog("elapsedfilename.gz"); + Assert.assertTrue(logChunkFileS3Queue.size() == 0); + } + + @Test + public void testRollOverIfAuditCopyIsEnabledAndS3UploaderIsNull() { + auditLogsS3TimeBasedRollingPolicy = + Mockito.spy(new AuditLogsS3TimeBasedRollingPolicy("bucket", "region")); + Mockito.doNothing().when(auditLogsS3TimeBasedRollingPolicy).superRollOver(); + TimeBasedFileNamingAndTriggeringPolicy timeBasedFileNamingAndTriggeringPolicy = + Mockito.spy(TimeBasedFileNamingAndTriggeringPolicy.class); + Mockito.when(timeBasedFileNamingAndTriggeringPolicy.getElapsedPeriodsFileName()) + .thenReturn("elapsedfilename"); + auditLogsS3TimeBasedRollingPolicy.setTimeBasedFileNamingAndTriggeringPolicy( + timeBasedFileNamingAndTriggeringPolicy); + LinkedBlockingQueue logChunkFileS3Queue = new LinkedBlockingQueue<>(); + auditLogsS3TimeBasedRollingPolicy.setLogChunkFileS3Queue(logChunkFileS3Queue); + auditLogsS3TimeBasedRollingPolicy.rollover(); + Mockito.verify(timeBasedFileNamingAndTriggeringPolicy).getElapsedPeriodsFileName(); + Assert.assertTrue(logChunkFileS3Queue.size() > 0); + } +} diff --git a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/AthenaClientFactoryTest.java b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/AthenaClientFactoryTest.java new file mode 100644 index 000000000..5fbe03a59 --- /dev/null +++ b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/AthenaClientFactoryTest.java @@ -0,0 +1,28 @@ +package com.nike.cerberus.audit.logger; + +import com.amazonaws.services.athena.AmazonAthena; +import org.junit.Assert; +import org.junit.Test; + +public class AthenaClientFactoryTest { + + @Test + public void testGetClientAlwaysReturnsSameAthenaInstance() { + AthenaClientFactory athenaClientFactory = new AthenaClientFactory(); + AmazonAthena clientInstance1 = athenaClientFactory.getClient("region-2"); + AmazonAthena clientInstance2 = athenaClientFactory.getClient("region-2"); + Assert.assertSame(clientInstance1, clientInstance2); + } + + @Test + public void testGetClientDoesNotThrowNPEWhenRegionIsEmptyString() { + AthenaClientFactory athenaClientFactory = new AthenaClientFactory(); + athenaClientFactory.getClient(""); + } + + @Test(expected = NullPointerException.class) + public void testGetClientThrowsNPEWhenRegionIsNull() { + AthenaClientFactory athenaClientFactory = new AthenaClientFactory(); + athenaClientFactory.getClient(null); + } +} diff --git a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/S3ClientFactoryTest.java b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/S3ClientFactoryTest.java new file mode 100644 index 000000000..dac86381a --- /dev/null +++ b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/S3ClientFactoryTest.java @@ -0,0 +1,28 @@ +package com.nike.cerberus.audit.logger; + +import com.amazonaws.services.s3.AmazonS3; +import org.junit.Assert; +import org.junit.Test; + +public class S3ClientFactoryTest { + + @Test + public void testS3ClientFactoryAlwaysReturnSameInstance() { + S3ClientFactory s3ClientFactory = new S3ClientFactory(); + AmazonS3 s3Instance1 = s3ClientFactory.getClient("region-1"); + AmazonS3 s3Instance2 = s3ClientFactory.getClient("region-1"); + Assert.assertSame(s3Instance1, s3Instance2); + } + + @Test + public void testS3ClientFactoryDoesNotThrowsNPEWhenEmptyStringIsPassedAsRegion() { + S3ClientFactory s3ClientFactory = new S3ClientFactory(); + s3ClientFactory.getClient(""); + } + + @Test(expected = NullPointerException.class) + public void testS3ClientFactoryThrowsNPEWhenNullRegionIsPassed() { + S3ClientFactory s3ClientFactory = new S3ClientFactory(); + s3ClientFactory.getClient(null); + } +} diff --git a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListenerTest.java b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListenerTest.java new file mode 100644 index 000000000..912752af2 --- /dev/null +++ b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/listener/AthenaLoggingEventListenerTest.java @@ -0,0 +1,327 @@ +package com.nike.cerberus.audit.logger.listener; + +import com.nike.cerberus.event.AuditableEvent; +import com.nike.cerberus.event.AuditableEventContext; +import java.time.OffsetDateTime; +import java.util.Optional; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.slf4j.Logger; + +public class AthenaLoggingEventListenerTest { + + @Mock private Logger auditLogger; + + @InjectMocks private AthenaLoggingEventListener athenaLoggingEventListener; + + private AuditableEvent auditableEvent; + + private AuditableEventContext auditableEventContext; + + @Before + public void setup() { + auditableEventContext = Mockito.mock(AuditableEventContext.class); + auditableEvent = Mockito.mock(AuditableEvent.class); + MockitoAnnotations.initMocks(this); + Mockito.when(auditableEventContext.getTimestamp()) + .thenReturn(OffsetDateTime.parse("2007-12-03T10:15:30+01:00")); + Mockito.when(auditableEventContext.getPrincipalAsCerberusPrincipal()) + .thenReturn(Optional.empty()); + } + + private void mockAuditableEvent() { + Mockito.when(auditableEvent.getAuditableEventContext()).thenReturn(auditableEventContext); + } + + private void mockAuditableEventContextPrincipalName() { + Mockito.when(auditableEventContext.getPrincipalName()).thenReturn("pricinpleName"); + } + + private void mockAuditableEventContextIpAddress() { + Mockito.when(auditableEventContext.getIpAddress()).thenReturn("ipAddress"); + } + + private void mockAuditableEventContextXForwarded() { + Mockito.when(auditableEventContext.getXForwardedFor()).thenReturn("xforwarder"); + } + + private void mockAuditableEventContextClientVersion() { + Mockito.when(auditableEventContext.getClientVersion()).thenReturn("clientVersion"); + } + + private void mockAuditableEventContextCerberusVerion() { + Mockito.when(auditableEventContext.getVersion()).thenReturn("version"); + } + + private void mockAuditableEventContextPost() { + Mockito.when(auditableEventContext.getMethod()).thenReturn("post"); + } + + private void mockAuditableEventContextPath() { + Mockito.when(auditableEventContext.getPath()).thenReturn("path"); + } + + @Test + public void testIfAuditableContextIsMissingInAuditEvent() { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + Mockito.verify(auditLogger, Mockito.never()).info(Mockito.anyString()); + } + + @Test + public void testIfAuditableContextIsPresentInAuditEvent() { + mockAuditableEventContextPrincipalName(); + mockAuditableEvent(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + Mockito.when(auditableEventContext.getTraceId()).thenReturn("traceId"); + Mockito.when(auditableEventContext.getStatusCode()).thenReturn(200); + Mockito.when(auditableEventContext.getAction()).thenReturn("action"); + Mockito.when(auditableEventContext.isSuccess()).thenReturn(true); + Mockito.when(auditableEventContext.getEventName()).thenReturn("eventName"); + Mockito.when(auditableEventContext.getOriginatingClass()).thenReturn("originatingClass"); + + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + Mockito.verify(auditLogger).info(Mockito.anyString()); + } + + @Test(expected = NullPointerException.class) + public void testIfEventTimeStampIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } + + @Test + public void testIfPrincipleNameIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("principal_name")); + } + + @Test + public void testIfIpAddressIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("ip_address")); + } + + @Test + public void testIfXForwardedForIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("x_forwarded_for")); + } + + @Test + public void testIfCerberusVersionIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("cerberus_version")); + } + + @Test + public void testIfClientVersionIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("client_version")); + } + + @Test + public void testIfHttpMethodIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("http_method")); + } + + @Test + public void testIfPathIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("path")); + } + + @Test + public void testIfActionIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("action")); + } + + @Test + public void testIfEventNameIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + Mockito.when(auditableEventContext.getAction()).thenReturn("action"); + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("name")); + } + + @Test + public void testIfOriginatingClassIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + Mockito.when(auditableEventContext.getAction()).thenReturn("action"); + Mockito.when(auditableEventContext.getEventName()).thenReturn("eventName"); + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("originating_class")); + } + + @Test + public void testIfTraceIdIsMissingInContext() { + mockAuditableEvent(); + mockAuditableEventContextPrincipalName(); + mockAuditableEventContextIpAddress(); + mockAuditableEventContextXForwarded(); + mockAuditableEventContextClientVersion(); + mockAuditableEventContextCerberusVerion(); + mockAuditableEventContextPost(); + mockAuditableEventContextPath(); + + Mockito.when(auditableEventContext.getAction()).thenReturn("action"); + Mockito.when(auditableEventContext.getEventName()).thenReturn("eventName"); + Mockito.when(auditableEventContext.getOriginatingClass()).thenReturn("originatingClass"); + String exceptionMessage = ""; + try { + athenaLoggingEventListener.onApplicationEvent(auditableEvent); + } catch (NullPointerException nullPointerException) { + exceptionMessage = nullPointerException.getMessage(); + } + Assert.assertTrue(exceptionMessage.contains("trace_id")); + } +} diff --git a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/AthenaServiceTest.java b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/AthenaServiceTest.java new file mode 100644 index 000000000..0e64d38e5 --- /dev/null +++ b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/AthenaServiceTest.java @@ -0,0 +1,67 @@ +package com.nike.cerberus.audit.logger.service; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.*; + +import com.amazonaws.AmazonClientException; +import com.amazonaws.services.athena.AmazonAthena; +import com.amazonaws.services.athena.model.StartQueryExecutionResult; +import com.nike.cerberus.audit.logger.AthenaClientFactory; +import java.util.HashSet; +import java.util.Set; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.internal.util.reflection.Whitebox; + +public class AthenaServiceTest { + + private AthenaClientFactory athenaClientFactory; + + private AthenaService athenaService; + + @Mock private Set partitions = new HashSet<>(); + + @Before + public void before() { + athenaClientFactory = mock(AthenaClientFactory.class); + athenaService = new AthenaService("fake-bucket", athenaClientFactory); + Whitebox.setInternalState(athenaService, "partitions", partitions); + } + + @Test + public void test_that_addPartition_works() { + String awsRegion = "us-west-2"; + AmazonAthena athena = mock(AmazonAthena.class); + when(athenaClientFactory.getClient(awsRegion)).thenReturn(athena); + when(athena.startQueryExecution(Mockito.any())) + .thenReturn(new StartQueryExecutionResult().withQueryExecutionId("query-execution-id")); + athenaService.addPartitionIfMissing(awsRegion, "fake-bucket", "2018", "01", "29", "12"); + verify(athenaClientFactory, times(1)).getClient(anyString()); + assertEquals(1, partitions.size()); + } + + @Test + public void test_addPartition_already_exist() { + String awsRegion = "us-west-2"; + String partition = String.format("year=%s/month=%s/day=%s/hour=%s", "2018", "01", "29", "12"); + Set pars = new HashSet<>(); + pars.add(partition); + Whitebox.setInternalState(athenaService, "partitions", pars); + athenaService.addPartitionIfMissing(awsRegion, "fake-bucket", "2018", "01", "29", "12"); + verify(athenaClientFactory, never()).getClient(anyString()); + assertEquals(1, pars.size()); + } + + @Test + public void test_addPartition_fails_when_amazon_client_exception() { + String awsRegion = "us-west-2"; + AmazonAthena athena = mock(AmazonAthena.class); + when(athenaClientFactory.getClient(awsRegion)).thenReturn(athena); + when(athena.startQueryExecution(Mockito.any())).thenThrow(AmazonClientException.class); + athenaService.addPartitionIfMissing(awsRegion, "fake-bucket", "2018", "01", "29", "12"); + verify(athenaClientFactory, times(1)).getClient(anyString()); + assertEquals(0, partitions.size()); + } +} diff --git a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/S3LogUploaderServiceTest.java b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/S3LogUploaderServiceTest.java index ea594dbf8..fac2971ad 100644 --- a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/S3LogUploaderServiceTest.java +++ b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/audit/logger/service/S3LogUploaderServiceTest.java @@ -22,9 +22,12 @@ import ch.qos.logback.classic.Logger; import ch.qos.logback.classic.LoggerContext; import com.nike.cerberus.audit.logger.S3ClientFactory; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; +import org.mockito.Mockito; public class S3LogUploaderServiceTest { @@ -34,7 +37,7 @@ public class S3LogUploaderServiceTest { private S3LogUploaderService s3LogUploader; - private Logger logger = new LoggerContext().getLogger("test-logger"); + private final Logger logger = new LoggerContext().getLogger("test-logger"); @Before public void before() { @@ -51,4 +54,30 @@ public void test_that_getPartition_works() { String expected = "partitioned/year=2018/month=01/day=29/hour=12"; assertEquals(expected, actual); } + + @Test + public void test_executor_shutdown_works() throws InterruptedException { + + ExecutorService executorService = Mockito.mock(ExecutorService.class); + s3LogUploader.setExecutor(executorService); + s3LogUploader.executeServerShutdownHook(); + + Mockito.verify(executorService).shutdown(); + Mockito.verify(executorService).awaitTermination(10, TimeUnit.MINUTES); + } + + @Test + public void test_executor_shutdown_works_when_awaits_termination_throws_exception() + throws InterruptedException { + + ExecutorService executorService = Mockito.mock(ExecutorService.class); + Mockito.when(executorService.awaitTermination(10, TimeUnit.MINUTES)) + .thenThrow(new InterruptedException()); + s3LogUploader.setExecutor(executorService); + s3LogUploader.executeServerShutdownHook(); + + Mockito.verify(executorService).shutdown(); + Mockito.verify(executorService).shutdownNow(); + Mockito.verify(executorService).awaitTermination(10, TimeUnit.MINUTES); + } } diff --git a/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/config/AthenaAuditLoggerConfigurationTest.java b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/config/AthenaAuditLoggerConfigurationTest.java new file mode 100644 index 000000000..cc09bf176 --- /dev/null +++ b/cerberus-audit-logger-athena/src/test/java/com/nike/cerberus/config/AthenaAuditLoggerConfigurationTest.java @@ -0,0 +1,29 @@ +package com.nike.cerberus.config; + +import ch.qos.logback.classic.LoggerContext; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.rolling.AuditLogsS3TimeBasedRollingPolicy; +import ch.qos.logback.core.rolling.FiveMinuteRollingFileAppender; +import ch.qos.logback.core.util.FileSize; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.Mockito; + +public class AthenaAuditLoggerConfigurationTest { + + @Test + public void testRollingPolicyStarted() { + AuditLogsS3TimeBasedRollingPolicy auditLogsS3TimeBasedRollingPolicy = + Mockito.mock(AuditLogsS3TimeBasedRollingPolicy.class); + AthenaAuditLoggerConfiguration athenaAuditLoggerConfiguration = + new AthenaAuditLoggerConfiguration("", auditLogsS3TimeBasedRollingPolicy); + Mockito.verify(auditLogsS3TimeBasedRollingPolicy).setContext(Mockito.any(LoggerContext.class)); + Mockito.verify(auditLogsS3TimeBasedRollingPolicy).setFileNamePattern(Mockito.anyString()); + Mockito.verify(auditLogsS3TimeBasedRollingPolicy).setMaxHistory(100); + Mockito.verify(auditLogsS3TimeBasedRollingPolicy) + .setParent(Mockito.any(FiveMinuteRollingFileAppender.class)); + Mockito.verify(auditLogsS3TimeBasedRollingPolicy).setTotalSizeCap(Mockito.any(FileSize.class)); + Mockito.verify(auditLogsS3TimeBasedRollingPolicy).start(); + Assert.assertNotNull(athenaAuditLoggerConfiguration.getAthenaAuditLogger()); + } +} diff --git a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaApiClientHelperTest.java b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaApiClientHelperTest.java index 2612779b5..0f48dd833 100644 --- a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaApiClientHelperTest.java +++ b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaApiClientHelperTest.java @@ -72,6 +72,23 @@ public void getUserGroupsHappy() throws Exception { assertTrue(result.contains(group)); } + @Test + public void getUserGroupsWithLimit() throws Exception { + + String id = "id"; + UserGroup group = mock(UserGroup.class); + PagedResults res = mock(PagedResults.class); + when(res.getResult()).thenReturn(Lists.newArrayList(group)); + when(res.isLastPage()).thenReturn(true); + when(userGroupApiClient.getUserGroupsPagedResultsByUrl(anyString())).thenReturn(res); + + // do the call + List result = this.oktaApiClientHelper.getUserGroups(id, 1); + + // verify results + assertTrue(result.contains(group)); + } + @Test(expected = ApiException.class) public void getUserGroupsFails() throws Exception { diff --git a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/PushStateHandlerTest.java b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/PushStateHandlerTest.java new file mode 100644 index 000000000..da68789e7 --- /dev/null +++ b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/PushStateHandlerTest.java @@ -0,0 +1,92 @@ +package com.nike.cerberus.auth.connector.okta; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.*; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.initMocks; + +import com.nike.cerberus.auth.connector.AuthResponse; +import com.nike.cerberus.auth.connector.AuthStatus; +import com.nike.cerberus.auth.connector.okta.statehandlers.PushStateHandler; +import com.okta.authn.sdk.client.AuthenticationClient; +import com.okta.authn.sdk.resource.AuthenticationResponse; +import com.okta.authn.sdk.resource.AuthenticationStatus; +import com.okta.authn.sdk.resource.User; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; + +public class PushStateHandlerTest { + + private PushStateHandler pushStateHandler; + + @Mock private AuthenticationClient client; + private CompletableFuture authenticationResponseFuture; + + @Before + public void setup() { + + initMocks(this); + + authenticationResponseFuture = new CompletableFuture<>(); + + // create test object + this.pushStateHandler = new PushStateHandler(client, authenticationResponseFuture) {}; + } + + @Test + public void handleMfaChallengeHappy() + throws InterruptedException, ExecutionException, TimeoutException { + String email = "email"; + String id = "id"; + AuthStatus status = AuthStatus.MFA_CHALLENGE; + + AuthenticationResponse expectedResponse = mock(AuthenticationResponse.class); + + User user = mock(User.class); + when(user.getId()).thenReturn(id); + when(user.getLogin()).thenReturn(email); + when(expectedResponse.getUser()).thenReturn(user); + when(expectedResponse.getStatus()).thenReturn(AuthenticationStatus.MFA_CHALLENGE); + + // do the call + pushStateHandler.handleMfaChallenge(expectedResponse); + + AuthResponse actualResponse = authenticationResponseFuture.get(1, TimeUnit.SECONDS); + + // verify results + assertEquals(id, actualResponse.getData().getUserId()); + assertEquals(email, actualResponse.getData().getUsername()); + assertEquals(status, actualResponse.getStatus()); + } + + @Test + public void handleMfaSuccessHappy() + throws InterruptedException, ExecutionException, TimeoutException { + String email = "email"; + String id = "id"; + AuthStatus status = AuthStatus.SUCCESS; + + AuthenticationResponse expectedResponse = mock(AuthenticationResponse.class); + + User user = mock(User.class); + when(user.getId()).thenReturn(id); + when(user.getLogin()).thenReturn(email); + when(expectedResponse.getUser()).thenReturn(user); + when(expectedResponse.getStatus()).thenReturn(AuthenticationStatus.SUCCESS); + + // do the call + pushStateHandler.handleSuccess(expectedResponse); + + AuthResponse actualResponse = authenticationResponseFuture.get(1, TimeUnit.SECONDS); + + // verify results + assertEquals(id, actualResponse.getData().getUserId()); + assertEquals(email, actualResponse.getData().getUsername()); + assertEquals(status, actualResponse.getStatus()); + } +} diff --git a/cerberus-core/src/test/java/com/nike/cerberus/event/AuditableEventContextTest.java b/cerberus-core/src/test/java/com/nike/cerberus/event/AuditableEventContextTest.java new file mode 100644 index 000000000..a38949e0f --- /dev/null +++ b/cerberus-core/src/test/java/com/nike/cerberus/event/AuditableEventContextTest.java @@ -0,0 +1,94 @@ +package com.nike.cerberus.event; + +import com.nike.cerberus.domain.CerberusAuthToken; +import java.time.OffsetDateTime; +import java.util.Optional; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class AuditableEventContextTest { + + private AuditableEventContext auditableEventContext; + + @Before + public void setup() { + auditableEventContext = new AuditableEventContext.AuditableEventContextBuilder().build(); + } + + @Test + public void testCheckAuthTokenIsEmptyIfPrincipleIsNotInstanceOfCerberusAuthToken() { + auditableEventContext.setPrincipal(new Object()); + Optional principalAsCerberusPrincipal = + auditableEventContext.getPrincipalAsCerberusPrincipal(); + Assert.assertFalse(principalAsCerberusPrincipal.isPresent()); + } + + @Test + public void testCheckAuthTokenIsEmptyIfPrincipleIsInstanceOfCerberusAuthToken() { + CerberusAuthToken cerberusAuthToken = new CerberusAuthToken(); + auditableEventContext.setPrincipal(cerberusAuthToken); + Optional principalAsCerberusPrincipal = + auditableEventContext.getPrincipalAsCerberusPrincipal(); + Assert.assertTrue(principalAsCerberusPrincipal.isPresent()); + Assert.assertSame(cerberusAuthToken, principalAsCerberusPrincipal.get()); + } + + @Test + public void testGetPrincipalNameIfPrincipleIsInstanceOfCerberusAuthToken() { + CerberusAuthToken cerberusAuthToken = new CerberusAuthToken(); + String cerberusPrinciple = "cerberusPrinciple"; + cerberusAuthToken.setPrincipal(cerberusPrinciple); + auditableEventContext.setPrincipal(cerberusAuthToken); + String principalName = auditableEventContext.getPrincipalName(); + Assert.assertEquals(cerberusPrinciple, principalName); + } + + @Test + public void testGetPrincipalNameIfPrincipleIsInstanceOfString() { + String stringPrinciple = "stringPrinciple"; + auditableEventContext.setPrincipal(stringPrinciple); + String principalName = auditableEventContext.getPrincipalName(); + Assert.assertEquals(stringPrinciple, principalName); + } + + @Test + public void testGetPrincipalNameIfPrincipleIsNull() { + String principalName = auditableEventContext.getPrincipalName(); + Assert.assertEquals("Unknown", principalName); + } + + @Test + public void testGetPrincipalNameIfPrincipleIsInstanceOfObject() { + Object principle = new Object(); + auditableEventContext.setPrincipal(principle); + String principalName = auditableEventContext.getPrincipalName(); + Assert.assertEquals(principle.toString(), principalName); + } + + @Test + public void testGetEventsAsString() { + auditableEventContext = + new AuditableEventContext.AuditableEventContextBuilder() + .eventName("eventName") + .principal("principal") + .action("action") + .method("method") + .statusCode(0) + .success(true) + .path("path") + .ipAddress("ipaddress") + .xForwardedFor("xforwarder") + .clientVersion("1") + .version("0") + .originatingClass("originating") + .sdbNameSlug("sdbname") + .traceId("traceId") + .timestamp(OffsetDateTime.parse("2007-12-03T10:15:30+01:00")) + .build(); + String eventAsString = auditableEventContext.getEventAsString(); + String events = + "eventName, Principal: principal, Action: 'action', Method: method, Status Code: 0, Was Success: true, Path: path, IP Address: ipaddress, X-Forwarded-For: xforwarder, Client Version: 1, Cerberus Version: 0, Originating Class: originating, SDB Name Slug: sdbname, Trace ID: traceId, Event Timestamp: Dec 3 2007, 10:15:30 AM +0100"; + Assert.assertEquals(events, eventAsString); + } +} diff --git a/cerberus-core/src/test/java/com/nike/cerberus/util/CustomApiErrorTest.java b/cerberus-core/src/test/java/com/nike/cerberus/util/CustomApiErrorTest.java new file mode 100644 index 000000000..a18186fa0 --- /dev/null +++ b/cerberus-core/src/test/java/com/nike/cerberus/util/CustomApiErrorTest.java @@ -0,0 +1,24 @@ +package com.nike.cerberus.util; + +import static junit.framework.TestCase.assertEquals; + +import com.nike.backstopper.apierror.ApiError; +import com.nike.cerberus.error.DefaultApiError; +import org.junit.Test; + +public class CustomApiErrorTest { + + @Test(expected = NullPointerException.class) + public void test_creat_custom_api_error_fails() { + CustomApiError.createCustomApiError(null, "message"); + } + + @Test + public void test_creat_custom_api_error_works() { + DefaultApiError err = DefaultApiError.SDB_UNIQUE_NAME; + ApiError error = CustomApiError.createCustomApiError(err, "message"); + String actual = error.getMessage(); + String expected = err.getMessage() + " " + "message"; + assertEquals(expected, actual); + } +} diff --git a/cerberus-domain/src/test/java/com/nike/cerberus/validation/PatternListAnyMatchValidatorTest.java b/cerberus-domain/src/test/java/com/nike/cerberus/validation/PatternListAnyMatchValidatorTest.java new file mode 100644 index 000000000..26558e746 --- /dev/null +++ b/cerberus-domain/src/test/java/com/nike/cerberus/validation/PatternListAnyMatchValidatorTest.java @@ -0,0 +1,24 @@ +package com.nike.cerberus.validation; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mockito; + +public class PatternListAnyMatchValidatorTest { + PatternListAnyMatch patternListAnyMatch; + PatternListAnyMatchValidator patternListAnyMatchValidator = new PatternListAnyMatchValidator(); + + @Before + public void setup() { + patternListAnyMatch = Mockito.mock(PatternListAnyMatch.class); + } + + @Test + public void test_isValid() { + Mockito.when(patternListAnyMatch.value()).thenReturn(new String[] {"\\d"}); + patternListAnyMatchValidator.initialize(patternListAnyMatch); + Assert.assertTrue(patternListAnyMatchValidator.isValid("1", null)); + Assert.assertFalse(patternListAnyMatchValidator.isValid("s", null)); + } +} diff --git a/gradle.properties b/gradle.properties index dc0c92b66..b363f9dbd 100644 --- a/gradle.properties +++ b/gradle.properties @@ -14,5 +14,5 @@ # limitations under the License. # -version=4.10.2 +version=4.10.3 group=com.nike.cerberus