diff --git a/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java b/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java index 9720ca0a61..b2cea86fbb 100644 --- a/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java +++ b/spring-cloud-gateway-server-webflux/src/main/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiter.java @@ -24,6 +24,7 @@ import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; +import jakarta.validation.constraints.DecimalMin; import jakarta.validation.constraints.Min; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -139,7 +140,7 @@ public RedisRateLimiter(ReactiveStringRedisTemplate redisTemplate, RedisScript isAllowed(String routeId, String id) { Config routeConfig = loadConfiguration(routeId); // How many requests per second do you want a user to be allowed to do? - int replenishRate = routeConfig.getReplenishRate(); + double replenishRate = routeConfig.getReplenishRate(); // How much bursting do you want to allow? long burstCapacity = routeConfig.getBurstCapacity(); @@ -311,11 +312,18 @@ public Mono isAllowed(String routeId, String id) { return routeConfig; } + private static String formatReplenishRate(double rate) { + if (rate == Math.floor(rate)) { + return String.valueOf((long) rate); + } + return String.valueOf(rate); + } + public Map getHeaders(Config config, Long tokensLeft) { Map headers = new HashMap<>(); if (isIncludeHeaders()) { headers.put(this.remainingHeader, tokensLeft.toString()); - headers.put(this.replenishRateHeader, String.valueOf(config.getReplenishRate())); + headers.put(this.replenishRateHeader, formatReplenishRate(config.getReplenishRate())); headers.put(this.burstCapacityHeader, String.valueOf(config.getBurstCapacity())); headers.put(this.requestedTokensHeader, String.valueOf(config.getRequestedTokens())); } @@ -325,8 +333,8 @@ public Map getHeaders(Config config, Long tokensLeft) { @Validated public static class Config { - @Min(1) - private int replenishRate; + @DecimalMin(value = "0.0", inclusive = false) + private double replenishRate; @Min(0) private long burstCapacity = 1; @@ -334,11 +342,11 @@ public static class Config { @Min(1) private int requestedTokens = 1; - public int getReplenishRate() { + public double getReplenishRate() { return replenishRate; } - public Config setReplenishRate(int replenishRate) { + public Config setReplenishRate(double replenishRate) { this.replenishRate = replenishRate; return this; } diff --git a/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiterConfigTests.java b/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiterConfigTests.java index ff693df232..35b47d1af0 100644 --- a/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiterConfigTests.java +++ b/spring-cloud-gateway-server-webflux/src/test/java/org/springframework/cloud/gateway/filter/ratelimit/RedisRateLimiterConfigTests.java @@ -53,6 +53,14 @@ public void init() { routeLocator.getRoutes().collectList().block(); } + private void assertConfigOnly(String key, double replenishRate, int burstCapacity, int requestedTokens) { + RedisRateLimiter.Config config = rateLimiter.getConfig().get(key); + assertThat(config).isNotNull(); + assertThat(config.getReplenishRate()).isEqualTo(replenishRate); + assertThat(config.getBurstCapacity()).isEqualTo(burstCapacity); + assertThat(config.getRequestedTokens()).isEqualTo(requestedTokens); + } + @Test public void shouldThrowAnErrorWhenReplenishRateIsHigherThanBurstCapacity() { Assertions.assertThatThrownBy(() -> new RedisRateLimiter(10, 5)).isInstanceOf(IllegalArgumentException.class); @@ -86,7 +94,24 @@ public void redisRateConfiguredFromJavaAPIDirectBean() { assertFilter("alt_custom_redis_rate_limiter", 30, 60, 20, true); } - private void assertFilter(String key, int replenishRate, int burstCapacity, int requestedTokens, + @Test + public void fractionalReplenishRateTest() { + String key = "fractionalKey"; + double replenishRate = 0.5; + int burstCapacity = 1; + + RedisRateLimiter.Config fractionalConfig = new RedisRateLimiter.Config().setReplenishRate(replenishRate) + .setBurstCapacity(burstCapacity) + .setRequestedTokens(1); + + // Add this config manually for the test + rateLimiter.getConfig().put(key, fractionalConfig); + + // Only check config, skip route check + assertConfigOnly(key, replenishRate, burstCapacity, 1); + } + + private void assertFilter(String key, double replenishRate, int burstCapacity, int requestedTokens, boolean useDefaultConfig) { RedisRateLimiter.Config config;