Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hardens PropertiesUtil against recursive property sources #3263

Merged
merged 5 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
import java.util.Map;
import java.util.Properties;
import java.util.stream.Stream;
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.test.ListStatusListener;
import org.apache.logging.log4j.test.junit.UsingStatusListener;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.ResourceAccessMode;
Expand Down Expand Up @@ -193,16 +196,56 @@ void testPublish() {
@Test
@ResourceLock(value = Resources.SYSTEM_PROPERTIES, mode = ResourceAccessMode.READ)
@Issue("https://github.com/spring-projects/spring-boot/issues/33450")
void testBadPropertySource() {
@UsingStatusListener
void testErrorPropertySource(ListStatusListener statusListener) {
final String key = "testKey";
final Properties props = new Properties();
props.put(key, "test");
final PropertiesUtil util = new PropertiesUtil(props);
final ErrorPropertySource source = new ErrorPropertySource();
util.addPropertySource(source);
try {
statusListener.clear();
assertEquals("test", util.getStringProperty(key));
assertTrue(source.exceptionThrown);
assertThat(statusListener.findStatusData(Level.WARN))
.anySatisfy(data ->
assertThat(data.getMessage().getFormattedMessage()).contains("Failed"));
} finally {
util.removePropertySource(source);
}
}

@Test
@ResourceLock(value = Resources.SYSTEM_PROPERTIES, mode = ResourceAccessMode.READ)
@Issue("https://github.com/apache/logging-log4j2/issues/3252")
@UsingStatusListener
void testRecursivePropertySource(ListStatusListener statusListener) {
final String key = "testKey";
final Properties props = new Properties();
props.put(key, "test");
final PropertiesUtil util = new PropertiesUtil(props);
final PropertySource source = new RecursivePropertySource(util);
util.addPropertySource(source);
try {
// We ignore the recursive source
statusListener.clear();
assertThat(util.getStringProperty(key)).isEqualTo("test");
assertThat(statusListener.findStatusData(Level.WARN))
.anySatisfy(data -> assertThat(data.getMessage().getFormattedMessage())
.contains("Recursive call", "getProperty"));

statusListener.clear();
// To check for existence, the sources are looked up in a random order.
assertThat(util.hasProperty(key)).isTrue();
// To find a missing key, all the sources must be used.
assertThat(util.hasProperty("noSuchKey")).isFalse();
assertThat(statusListener.findStatusData(Level.WARN))
.anySatisfy(data -> assertThat(data.getMessage().getFormattedMessage())
.contains("Recursive call", "containsProperty"));
// We check that the source is recursive
assertThat(source.getProperty(key)).isEqualTo("test");
assertThat(source.containsProperty(key)).isTrue();
} finally {
util.removePropertySource(source);
}
Expand Down Expand Up @@ -289,4 +332,28 @@ public boolean containsProperty(final String key) {
throw new IllegalStateException("Test");
}
}

private static class RecursivePropertySource implements PropertySource {

private final PropertiesUtil propertiesUtil;

private RecursivePropertySource(PropertiesUtil propertiesUtil) {
this.propertiesUtil = propertiesUtil;
}

@Override
public int getPriority() {
return Integer.MIN_VALUE;
}

@Override
public String getProperty(String key) {
return propertiesUtil.getStringProperty(key);
}

@Override
public boolean containsProperty(String key) {
return propertiesUtil.hasProperty(key);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ public void reload() {}
private static final class Environment {

private final Set<PropertySource> sources = ConcurrentHashMap.newKeySet();
private final ThreadLocal<PropertySource> CURRENT_PROPERTY_SOURCE = new ThreadLocal<>();

private Environment(final PropertySource propertySource) {
final PropertySource sysProps = new PropertyFilePropertySource(LOG4J_SYSTEM_PROPERTIES_FILE_NAME, false);
Expand Down Expand Up @@ -547,21 +548,35 @@ private String get(final String key) {
}

private boolean sourceContainsProperty(final PropertySource source, final String key) {
try {
return source.containsProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
return false;
PropertySource recursiveSource = CURRENT_PROPERTY_SOURCE.get();
if (recursiveSource == null) {
CURRENT_PROPERTY_SOURCE.set(source);
try {
return source.containsProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
} finally {
CURRENT_PROPERTY_SOURCE.remove();
}
}
LOGGER.warn("Recursive call to `containsProperty()` from property source {}.", recursiveSource);
return false;
}

private String sourceGetProperty(final PropertySource source, final String key) {
try {
return source.getProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
return null;
PropertySource recursiveSource = CURRENT_PROPERTY_SOURCE.get();
if (recursiveSource == null) {
CURRENT_PROPERTY_SOURCE.set(source);
try {
return source.getProperty(key);
} catch (final Exception e) {
LOGGER.warn("Failed to retrieve Log4j property {} from property source {}.", key, source, e);
} finally {
CURRENT_PROPERTY_SOURCE.remove();
}
}
LOGGER.warn("Recursive call to `getProperty()` from property source {}.", recursiveSource);
return null;
}

private boolean containsKey(final String key) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,52 +20,61 @@

import org.apache.logging.log4j.message.AbstractMessageFactory;
import org.apache.logging.log4j.message.DefaultFlowMessageFactory;
import org.apache.logging.log4j.message.FlowMessageFactory;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.MessageFactory;
import org.apache.logging.log4j.message.ParameterizedMessageFactory;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.ClearSystemProperty;
import org.junit.jupiter.api.TestInfo;
import org.junitpioneer.jupiter.SetSystemProperty;

@SetSystemProperty(
key = "log4j2.messageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$AlternativeTestMessageFactory")
@SetSystemProperty(
key = "log4j2.flowMessageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$AlternativeTestFlowMessageFactory")
class LoggerMessageFactoryCustomizationTest {

@Test
@ClearSystemProperty(key = "log4j2.messageFactory")
@ClearSystemProperty(key = "log4j2.flowMessageFactory")
void arguments_should_be_honored() {
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName());
final Logger logger = new Logger(
loggerContext, "arguments_should_be_honored", new TestMessageFactory(), new TestFlowMessageFactory());
assertTestMessageFactories(logger);
void arguments_should_be_honored(TestInfo testInfo) {
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName())) {
Logger logger = new Logger(
loggerContext, testInfo.getDisplayName(), new TestMessageFactory(), new TestFlowMessageFactory());
assertTestMessageFactories(logger, TestMessageFactory.class, TestFlowMessageFactory.class);
}
}

@Test
@SetSystemProperty(
key = "log4j2.messageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$TestMessageFactory")
@SetSystemProperty(
key = "log4j2.flowMessageFactory",
value = "org.apache.logging.log4j.core.LoggerMessageFactoryCustomizationTest$TestFlowMessageFactory")
void properties_should_be_honored() {
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName());
final Logger logger = new Logger(loggerContext, "properties_should_be_honored", null, null);
assertTestMessageFactories(logger);
void properties_should_be_honored(TestInfo testInfo) {
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryCustomizationTest.class.getSimpleName())) {
Logger logger = loggerContext.getLogger(testInfo.getDisplayName());
assertTestMessageFactories(
logger, AlternativeTestMessageFactory.class, AlternativeTestFlowMessageFactory.class);
}
}

private static void assertTestMessageFactories(Logger logger) {
assertThat((MessageFactory) logger.getMessageFactory()).isInstanceOf(TestMessageFactory.class);
assertThat(logger.getFlowMessageFactory()).isInstanceOf(TestFlowMessageFactory.class);
private static void assertTestMessageFactories(
Logger logger,
Class<? extends MessageFactory> messageFactoryClass,
Class<? extends FlowMessageFactory> flowMessageFactoryClass) {
assertThat((MessageFactory) logger.getMessageFactory()).isInstanceOf(messageFactoryClass);
assertThat(logger.getFlowMessageFactory()).isInstanceOf(flowMessageFactoryClass);
}

public static final class TestMessageFactory extends AbstractMessageFactory {
public static class TestMessageFactory extends AbstractMessageFactory {

@Override
public Message newMessage(final String message, final Object... params) {
return ParameterizedMessageFactory.INSTANCE.newMessage(message, params);
}
}

public static final class TestFlowMessageFactory extends DefaultFlowMessageFactory {}
public static class AlternativeTestMessageFactory extends TestMessageFactory {}

public static class TestFlowMessageFactory extends DefaultFlowMessageFactory {}

public static class AlternativeTestFlowMessageFactory extends TestFlowMessageFactory {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@
import org.apache.logging.log4j.message.MessageFactory;
import org.apache.logging.log4j.message.ParameterizedMessageFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junitpioneer.jupiter.SetSystemProperty;

class LoggerMessageFactoryDefaultsTlaDisabledTest {

@Test
@SetSystemProperty(key = "log4j2.enableThreadLocals", value = "false")
void defaults_should_match_when_thread_locals_disabled() {
void defaults_should_match_when_thread_locals_disabled(TestInfo testInfo) {
assertThat(Constants.ENABLE_THREADLOCALS).isFalse();
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaDisabledTest.class.getSimpleName());
final Logger logger =
new Logger(loggerContext, "defaults_should_match_when_thread_locals_disabled", null, null);
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ParameterizedMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaDisabledTest.class.getSimpleName())) {
final Logger logger = loggerContext.getLogger(testInfo.getDisplayName());
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ParameterizedMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@
import org.apache.logging.log4j.message.MessageFactory;
import org.apache.logging.log4j.message.ReusableMessageFactory;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junitpioneer.jupiter.SetSystemProperty;

class LoggerMessageFactoryDefaultsTlaEnabledTest {

@Test
@SetSystemProperty(key = "log4j2.is.webapp", value = "false")
@SetSystemProperty(key = "log4j2.enableThreadLocals", value = "true")
void defaults_should_match_when_thread_locals_enabled() {
@SetSystemProperty(key = "log4j2.isWebapp", value = "false")
@SetSystemProperty(key = "log4j2.enableThreadlocals", value = "true")
void defaults_should_match_when_thread_locals_enabled(TestInfo testInfo) {
assertThat(Constants.ENABLE_THREADLOCALS).isTrue();
final LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaEnabledTest.class.getSimpleName());
final Logger logger = new Logger(loggerContext, "defaults_should_match_when_thread_locals_enabled", null, null);
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ReusableMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
try (LoggerContext loggerContext =
new LoggerContext(LoggerMessageFactoryDefaultsTlaEnabledTest.class.getSimpleName())) {
Logger logger = loggerContext.getLogger(testInfo.getDisplayName());
assertThat((MessageFactory) logger.getMessageFactory()).isSameAs(ReusableMessageFactory.INSTANCE);
assertThat(logger.getFlowMessageFactory()).isSameAs(DefaultFlowMessageFactory.INSTANCE);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.apache.logging.log4j.ThreadContext;
import org.apache.logging.log4j.core.ContextDataInjector;
import org.apache.logging.log4j.spi.ThreadContextMap;
import org.apache.logging.log4j.util.PropertiesUtil;
import org.apache.logging.log4j.util.ProviderUtil;
import org.apache.logging.log4j.util.SortedArrayStringMap;
import org.apache.logging.log4j.util.StringMap;
Expand All @@ -59,7 +58,6 @@ public static Collection<String[]> threadContextMapClassNames() {
public String threadContextMapClassName;

private static void resetThreadContextMap() {
PropertiesUtil.getProperties().reload();
final Log4jProvider provider = (Log4jProvider) ProviderUtil.getProvider();
provider.resetThreadContextMap();
ThreadContext.init();
Expand Down
Loading
Loading