SpringBoot+Redis实现接口限流

准备工作

创建一个SpringBoot工程,引入相关依赖。

<!-- redis -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis-reactive</artifactId>
</dependency>

<!-- aop -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>

<!-- hutool -->
<dependency>
    <groupId>cn.hutool</groupId>
    <artifactId>hutool-all</artifactId>
    <version>5.8.1</version>
</dependency>

<!-- web -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

<!-- lombok-->
<dependency>
    <groupId>org.projectlombok</groupId>
    <artifactId>lombok</artifactId>
    <version>1.18.18</version>
</dependency>

<!-- fastJson -->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.76</version>
</dependency>

接口限流一般是通过注解来标记,而注解是通过 AOP 来解析的,所以需要加上 AOP 的依赖

提前准备好一个 Redis 实例,启动一个Redis服务即可,然后在配置文件中配置:

server:
  port: 5599

spring:
  application:
    name: redis-study
  redis:
    host: 127.0.0.1
    port: 6379

这样准备工作就差不多了。

限流注解

接下来创建一个限流注解,将限流分为两种情况:

  1. 针对当前接口的全局性限流,例如该接口可以在 1 分钟内访问 100 次。
  2. 针对某一个 IP 地址的限流,例如某个 IP 地址可以在 1 分钟内访问 100 次。

针对这两种情况,创建一个枚举类:

package com.itjing.redis.enu;

/**
 * @author lijing
 * @date 2022年05月26日 9:30
 * @description 限流使用枚举
 */

public enum LimitType {

    /**
     * 默认策略全局限流
     */

    DEFAULT,

    /**
     * 根据请求者IP进行限流
     */

    IP

}

接下来创建限流注解:

package com.itjing.redis.annotation;

import com.itjing.redis.enu.LimitType;

import java.lang.annotation.*;

/**
 * @author lijing
 * @date 2022年05月26日 9:31
 * @description 限流使用注解
 */

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
    /**
     * 限流key
     */

    String key() default "rate_limit:";

    /**
     * 限流时间,单位秒
     */

    int time() default 60;

    /**
     * 限流次数
     */

    int count() default 100;

    /**
     * 限流类型
     */

    LimitType limitType() default LimitType.DEFAULT;
}

其中,第一个参数key,仅仅是一个前缀,将来完整的 key 是这个前缀再加上接口方法的完整路径,共同组成限流 key,这个 key 将被存入到 Redis 中。

另外三个参数好理解,就不多说了。

好了,将来哪个接口需要限流,就在哪个接口上添加 @RateLimiter 注解,然后配置相关参数即可。

定制 RedisTemplate

package com.itjing.redis.config;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;

@Configuration
public class RedisConfig {

    // 自己定义了一个 RedisTemplate     
    @Bean
    @SuppressWarnings("all")
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory factory) {
        // 我们为了自己开发方便,一般直接使用 <String, Object>        
        RedisTemplate<Object, Object> template = new RedisTemplate<Object, Object>();
        template.setConnectionFactory(factory);
        // Json序列化配置
        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        // om.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); // 已过期
        // om.activateDefaultTyping(om.getPolymorphicTypeValidator());
        //om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.PROPERTY);
        om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL, JsonTypeInfo.As.WRAPPER_ARRAY);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        // String 的序列化        
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        // key采用String的序列化方式        
        template.setKeySerializer(stringRedisSerializer);
        // hash的key也采用String的序列化方式 
        template.setHashKeySerializer(stringRedisSerializer);
        // value序列化方式采用jackson        
        template.setValueSerializer(jackson2JsonRedisSerializer);
        // hash的value序列化方式采用jackson        
        template.setHashValueSerializer(jackson2JsonRedisSerializer);
        template.afterPropertiesSet();

        return template;
    }
}

开发 Lua 脚本

Redis 中的一些原子操作可以借助 Lua 脚本来实现,想要调用 Lua 脚本,有两种不同的思路:

  1. 在 Redis 服务端定义好 Lua 脚本,然后计算出来一个散列值,在 Java 代码中,通过这个散列值锁定要执行哪个 Lua 脚本。
  2. 直接在 Java 代码中将 Lua 脚本定义好,然后发送到 Redis 服务端去执行。

Spring Data Redis 中也提供了操作 Lua 脚本的接口,还是比较方便的,所以这里就采用第二种方案。

在 resources 目录下新建 lua 文件夹专门用来存放 lua 脚本,脚本内容如下:

local key = KEYS[1]
local count = tonumber(ARGV[1])
local time = tonumber(ARGV[2])
local current = redis.call('get', key)
if current and tonumber(current) > count then
  return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
  redis.call('expire', key, time)
end
return tonumber(current)

这个脚本其实不难,大概瞅一眼就知道干啥用的。

KEYS 和 ARGV 都是一会调用时候传进来的参数,tonumber 就是把字符串转为数字,redis.call 就是执行具体的 redis 指令,具体流程是这样:

  1. 首先获取到传进来的 key 以及 限流的 count 和时间 time。
  2. 通过 get 获取到这个 key 对应的值,这个值就是当前时间窗内这个接口可以访问多少次。
  3. 如果是第一次访问,此时拿到的结果为 nil,否则拿到的结果应该是一个数字,所以接下来就判断,如果拿到的结果是一个数字,并且这个数字还大于 count,那就说明已经超过流量限制了,那么直接返回查询的结果即可。
  4. 如果拿到的结果为 nil,说明是第一次访问,此时就给当前 key 自增 1,然后设置一个过期时间。
  5. 最后把自增 1 后的值返回就可以了。

其实这段 Lua 脚本很好理解。

接下来在一个 Bean 中来加载这段 Lua 脚本,在RedisConfig中定义如下:

@Bean
public DefaultRedisScript<Long> limitScript() {
    DefaultRedisScript<Long> redisScript = new DefaultRedisScript<>();
    redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
    redisScript.setResultType(Long.class);
    return redisScript;
}

Lua 脚本现在就准备好了。

注解解析

接下来就需要自定义切面,来解析这个注解了,来看看切面的定义:

package com.itjing.redis.aspect;

import com.itjing.redis.annotation.RateLimiter;
import com.itjing.redis.enu.LimitType;
import com.itjing.redis.exception.ServiceException;
import com.itjing.redis.utils.IpUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;

/**
 * @author lijing
 * @date 2022年05月26日 9:37
 * @description
 */

@Aspect
@Component
public class RateLimiterAspect {

    private static final Logger log = LoggerFactory.getLogger(RateLimiterAspect.class);

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    @Autowired
    private RedisScript<Long> limitScript;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
        String key = rateLimiter.key();
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<Object> keys = Collections.singletonList(combineKey);
        try {
            Long number = redisTemplate.execute(limitScript, keys, count, time);
            if (number==null || number.intValue() > count) {
                throw new ServiceException("访问过于频繁,请稍候再试");
            }
            log.info("限制请求'{}',当前请求'{}',缓存key'{}'", count, number.intValue(), key);
        } catch (ServiceException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("服务器限流异常,请稍候再试");
        }
    }

    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuffer stringBuffer = new StringBuffer(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            stringBuffer.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest())).append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        stringBuffer.append(targetClass.getName()).append("-").append(method.getName());
        return stringBuffer.toString();
    }
}
package com.itjing.redis.utils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;

/**
 * @author lijing
 * @date 2022年05月26日 9:41
 * @description 获取IP工具类
 */

public class IpUtils {

    private static Logger logger = LoggerFactory.getLogger(IpUtils.class);
    public static String getIpAddr(HttpServletRequest request) {
//        HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        String ip = null;
        try {
            ip = request.getHeader("x-forwarded-for");
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("Proxy-Client-IP");
            }
            if (StringUtils.isEmpty(ip) || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("WL-Proxy-Client-IP");
            }
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("HTTP_CLIENT_IP");
            }
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("HTTP_X_FORWARDED_FOR");
            }
            if (StringUtils.isEmpty(ip) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getRemoteAddr();
            }
        } catch (Exception e) {
            logger.error("IPUtils ERROR ", e);
        }
        return ip;
    }
}

这个切面就是拦截所有加了 @RateLimiter 注解的方法,在前置通知中对注解进行处理。

  1. 首先获取到注解中的 key、time 以及 count 三个参数。
  2. 获取一个组合的 key,所谓的组合的 key,就是在注解的 key 属性基础上,再加上方法的完整路径,如果是 IP 模式的话,就再加上 IP 地址。以 IP 模式为例,最终生成的 key 类似这样:rate_limit:127.0.0.1-com.itjing.redis.controller.TestLimitController-hello(如果不是 IP 模式,那么生成的 key 中就不包含 IP 地址)。
  3. 将生成的 key 放到集合中。
  4. 通过 redisTemplate.execute 方法取执行一个 Lua 脚本,第一个参数是脚本所封装的对象,第二个参数是 key,对应了脚本中的 KEYS,后面是可变长度的参数,对应了脚本中的 ARGV。
  5. 将 Lua 脚本执行的结果与 count 进行比较,如果大于 count,就说明过载了,抛异常就行了。

接口测试

接下来就进行接口的一个简单测试,如下:

package com.itjing.redis.controller;

import com.itjing.redis.annotation.RateLimiter;
import com.itjing.redis.enu.LimitType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.Date;

/**
 * @author lijing
 * @date 2022年05月26日 9:48
 * @description 测试限流
 */

@RestController
@RequestMapping("/limit")
public class TestLimitController {

    /**
     * 每一个 IP 地址,在 5 秒内只能访问 3 次
     * @return
     */

    @GetMapping("/hello")
    @RateLimiter(time = 5, count = 3, limitType = LimitType.IP)
    public String hello() {
        return "hello>>>" + new Date();
    }

}

全局异常处理

由于过载的时候是抛异常出来,所以还需要一个全局异常处理器,如下:

package com.itjing.redis.exception;

import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import java.util.HashMap;
import java.util.Map;

/**
 * @author lijing
 * @date 2022年05月26日 9:47
 * @description 全局异常捕获
 */

@RestControllerAdvice
public class GlobalException {

    /**
     * 业务异常
     * @param e
     * @return
     */

    @ExceptionHandler(ServiceException.class)
    public Map<StringObjectserviceException(ServiceException e
{
        HashMap<String, Object> map = new HashMap<>();
        map.put("status", e.getCode());
        map.put("message", e.getMessage());
        return map;
    }
}
package com.itjing.redis.exception;

/**
 * @author lijing
 * @date 2022年05月26日 9:44
 * @description 自定义业务异常
 */

public class ServiceException extends Exception {

    private Integer code = 500;

    public ServiceException() {
    }

    public ServiceException(String message) {
        super(message);
    }

    public ServiceException(String message, Integer code) {
        super(message);
        this.code = code;
    }


    public Integer getCode() {
        return code;
    }

    public void setCode(Integer code) {
        this.code = code;
    }
}

这是一个小 demo,就不去定义实体类了,直接用 Map 来返回 JSON 了。

这就是使用 Redis 做限流的方式。

不过这里只是简单使用,具体的实现还是要根据应用场景来设计。


原文始发于微信公众号(程序员阿晶):SpringBoot+Redis实现接口限流

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/19629.html

(0)
小半的头像小半

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!