准备工作
创建一个SpringBoot工程,引入相关依赖。
<!-- redis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis-reactive</artifactId>
</dependency>
<!-- aop -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- hutool -->
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.8.1</version>
</dependency>
<!-- web -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- lombok-->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.18</version>
</dependency>
<!-- fastJson -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.76</version>
</dependency>
接口限流一般是通过注解来标记,而注解是通过 AOP 来解析的,所以需要加上 AOP 的依赖
提前准备好一个 Redis 实例,启动一个Redis服务即可,然后在配置文件中配置:
server:
port: 5599
spring:
application:
name: redis-study
redis:
host: 127.0.0.1
port: 6379
这样准备工作就差不多了。
限流注解
接下来创建一个限流注解,将限流分为两种情况:
-
针对当前接口的全局性限流,例如该接口可以在 1 分钟内访问 100 次。 -
针对某一个 IP 地址的限流,例如某个 IP 地址可以在 1 分钟内访问 100 次。
针对这两种情况,创建一个枚举类:
package com.itjing.redis.enu;
/**
* @author lijing
* @date 2022年05月26日 9:30
* @description 限流使用枚举
*/
public enum LimitType {
/**
* 默认策略全局限流
*/
DEFAULT,
/**
* 根据请求者IP进行限流
*/
IP
}
接下来创建限流注解:
package com.itjing.redis.annotation;
import com.itjing.redis.enu.LimitType;
import java.lang.annotation.*;
/**
* @author lijing
* @date 2022年05月26日 9:31
* @description 限流使用注解
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
/**
* 限流key
*/
String key() default "rate_limit:";
/**
* 限流时间,单位秒
*/
int time() default 60;
/**
* 限流次数
*/
int count() default 100;
/**
* 限流类型
*/
LimitType limitType() default LimitType.DEFAULT;
}
其中,第一个参数key,仅仅是一个前缀,将来完整的 key 是这个前缀再加上接口方法的完整路径,共同组成限流 key,这个 key 将被存入到 Redis 中。
另外三个参数好理解,就不多说了。
好了,将来哪个接口需要限流,就在哪个接口上添加 @RateLimiter
注解,然后配置相关参数即可。
定制 RedisTemplate
package com.itjing.redis.config;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;
@Configuration
public class RedisConfig {
// 自己定义了一个 RedisTemplate
@Bean
@SuppressWarnings("all")
public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory factory) {
// 我们为了自己开发方便,一般直接使用 <String, Object>
RedisTemplate<Object, Object> template = new RedisTemplate<Object, Object>();
template.setConnectionFactory(factory);
// Json序列化配置
Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
ObjectMapper om = new ObjectMapper();
om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
// om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); // 已过期
// om.activateDefaultTyping(om.getPolymorphicTypeValidator());
//om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.WRAPPER_ARRAY);
jackson2JsonRedisSerializer.setObjectMapper(om);
// String 的序列化
StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
// key采用String的序列化方式
template.setKeySerializer(stringRedisSerializer);
// hash的key也采用String的序列化方式
template.setHashKeySerializer(stringRedisSerializer);
// value序列化方式采用jackson
template.setValueSerializer(jackson2JsonRedisSerializer);
// hash的value序列化方式采用jackson
template.setHashValueSerializer(jackson2JsonRedisSerializer);
template.afterPropertiesSet();
return template;
}
}
开发 Lua 脚本
Redis 中的一些原子操作可以借助 Lua 脚本来实现,想要调用 Lua 脚本,有两种不同的思路:
-
在 Redis 服务端定义好 Lua 脚本,然后计算出来一个散列值,在 Java 代码中,通过这个散列值锁定要执行哪个 Lua 脚本。 -
直接在 Java 代码中将 Lua 脚本定义好,然后发送到 Redis 服务端去执行。
Spring Data Redis 中也提供了操作 Lua 脚本的接口,还是比较方便的,所以这里就采用第二种方案。
在 resources 目录下新建 lua 文件夹专门用来存放 lua 脚本,脚本内容如下:
local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current)
这个脚本其实不难,大概瞅一眼就知道干啥用的。
KEYS 和 ARGV 都是一会调用时候传进来的参数,tonumber 就是把字符串转为数字,redis.call 就是执行具体的 redis 指令,具体流程是这样:
-
首先获取到传进来的 key 以及 限流的 count 和时间 time。 -
通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口可以访问多少次。 -
如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那就说明已经超过流量限制了,那么直接返回查询的结果即可。 -
如果拿到的结果为 nil,说明是第一次访问,此时就给当前 key 自增 1,然后设置一个过期时间。 -
最后把自增 1 后的值返回就可以了。
其实这段 Lua 脚本很好理解。
接下来在一个 Bean 中来加载这段 Lua 脚本,在RedisConfig中定义如下:
@Bean
public DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
redisScript.setResultType(Long.class);
return redisScript;
}
Lua 脚本现在就准备好了。
注解解析
接下来就需要自定义切面,来解析这个注解了,来看看切面的定义:
package com.itjing.redis.aspect;
import com.itjing.redis.annotation.RateLimiter;
import com.itjing.redis.enu.LimitType;
import com.itjing.redis.exception.ServiceException;
import com.itjing.redis.utils.IpUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
/**
* @author lijing
* @date 2022年05月26日 9:37
* @description
*/
@Aspect
@Component
public class RateLimiterAspect {
private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);
@Autowired
private RedisTemplate<Object, Object> redisTemplate;
@Autowired
private RedisScript<Long> limitScript;
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
String key = rateLimiter.key();
int time = rateLimiter.time();
int count = rateLimiter.count();
String combineKey = getCombineKey(rateLimiter, point);
List<Object> keys = Collections.singletonList(combineKey);
try {
Long number = redisTemplate.execute(limitScript, keys, count, time);
if (number==null || number.intValue() > count) {
throw new ServiceException("访问过于频繁,请稍候再试");
}
log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
} catch (ServiceException e) {
throw e;
} catch (Exception e) {
throw new RuntimeException("服务器限流异常,请稍候再试");
}
}
public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
if (rateLimiter.limitType() == LimitType.IP) {
stringBuffer.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest())).append("-");
}
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
return stringBuffer.toString();
}
}
package com.itjing.redis.utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
/**
* @author lijing
* @date 2022年05月26日 9:41
* @description 获取IP工具类
*/
public class IpUtils {
private static Logger logger = LoggerFactory.getLogger(IpUtils.class);
public static String getIpAddr(HttpServletRequest request) {
// HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
String ip = null;
try {
ip = request.getHeader("x-forwarded-for");
if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (StringUtils.isEmpty(ip) || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_CLIENT_IP");
}
if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("HTTP_X_FORWARDED_FOR");
}
if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
} catch (Exception e) {
logger.error("IPUtils ERROR ", e);
}
return ip;
}
}
这个切面就是拦截所有加了 @RateLimiter
注解的方法,在前置通知中对注解进行处理。
-
首先获取到注解中的 key、time 以及 count 三个参数。 -
获取一个组合的 key,所谓的组合的 key,就是在注解的 key 属性基础上,再加上方法的完整路径,如果是 IP 模式的话,就再加上 IP 地址。以 IP 模式为例,最终生成的 key 类似这样:rate_limit:127.0.0.1-com.itjing.redis.controller.TestLimitController-hello(如果不是 IP 模式,那么生成的 key 中就不包含 IP 地址)。 -
将生成的 key 放到集合中。 -
通过 redisTemplate.execute 方法取执行一个 Lua 脚本,第一个参数是脚本所封装的对象,第二个参数是 key,对应了脚本中的 KEYS,后面是可变长度的参数,对应了脚本中的 ARGV。 -
将 Lua 脚本执行的结果与 count 进行比较,如果大于 count,就说明过载了,抛异常就行了。
接口测试
接下来就进行接口的一个简单测试,如下:
package com.itjing.redis.controller;
import com.itjing.redis.annotation.RateLimiter;
import com.itjing.redis.enu.LimitType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.Date;
/**
* @author lijing
* @date 2022年05月26日 9:48
* @description 测试限流
*/
@RestController
@RequestMapping("/limit")
public class TestLimitController {
/**
* 每一个 IP 地址,在 5 秒内只能访问 3 次
* @return
*/
@GetMapping("/hello")
@RateLimiter(time = 5, count = 3, limitType = LimitType.IP)
public String hello() {
return "hello>>>" + new Date();
}
}
全局异常处理
由于过载的时候是抛异常出来,所以还需要一个全局异常处理器,如下:
package com.itjing.redis.exception;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import java.util.HashMap;
import java.util.Map;
/**
* @author lijing
* @date 2022年05月26日 9:47
* @description 全局异常捕获
*/
@RestControllerAdvice
public class GlobalException {
/**
* 业务异常
* @param e
* @return
*/
@ExceptionHandler(ServiceException.class)
public Map<String, Object> serviceException(ServiceException e) {
HashMap<String, Object> map = new HashMap<>();
map.put("status", e.getCode());
map.put("message", e.getMessage());
return map;
}
}
package com.itjing.redis.exception;
/**
* @author lijing
* @date 2022年05月26日 9:44
* @description 自定义业务异常
*/
public class ServiceException extends Exception {
private Integer code = 500;
public ServiceException() {
}
public ServiceException(String message) {
super(message);
}
public ServiceException(String message, Integer code) {
super(message);
this.code = code;
}
public Integer getCode() {
return code;
}
public void setCode(Integer code) {
this.code = code;
}
}
这是一个小 demo,就不去定义实体类了,直接用 Map 来返回 JSON 了。
这就是使用 Redis 做限流的方式。
不过这里只是简单使用,具体的实现还是要根据应用场景来设计。
原文始发于微信公众号(程序员阿晶):SpringBoot+Redis实现接口限流
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/19629.html