WebSocket cluster forwarding in Java: using Redis publish/subscribe
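Scenario
The backend service is deployed on multiple nodes and exposed to clients through an elastic load balancer.
A client (a browser), client 1, holds a WebSocket connection to server A.
A later request from client 1 is routed by the load balancer to server B. For example, a computation running on server B emits progress output meant for client 1, but server B has no WebSocket connection for client 1.
Requirement
Server B must forward the message to server A, which looks up client 1's connection and sends the message over it.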
Diagram
Code
Code: https://github.com/ioufev/websocket-cluster-forward
Backup: Lanzou Cloud (蓝奏云)
Redis publisher class
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;

@Component
public class RedisPublisher {

    // The subscriber strips surrounding quotes and Base64-decodes the payload, which suggests this
    // template is configured with a JSON (Jackson) value serializer; that configuration is not shown here.
    @Resource
    private RedisTemplate<String, byte[]> redisTemplate;

    /** Publishes a raw payload to a Redis channel; every node subscribed to the channel receives it. */
    public void publishMessage(String channel, byte[] message) {
        redisTemplate.convertAndSend(channel, message);
    }
}
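`RedisPublisher` hands a raw `byte[]` to `convertAndSend`, but the subscriber expects a quoted Base64 string. The framework-free sketch below shows the likely wire format. It assumes the project's `RedisTemplate` uses a Jackson-based value serializer, which writes a `byte[]` as a Base64 JSON string; the article does not show that configuration. `PayloadFormat`, `frame` and `asJsonByteArray` are hypothetical names used only for illustration.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

public class PayloadFormat {

    /** Frames the target key and the text as "key::message" in UTF-8, the same framing sendOneMessage uses. */
    static byte[] frame(String key, String message) {
        return (key + "::" + message).getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Approximates what a Jackson-based value serializer writes for a byte[]:
     * standard Base64 wrapped in double quotes. (Assumption about the project's serializer config.)
     */
    static String asJsonByteArray(byte[] payload) {
        return "\"" + Base64.getEncoder().encodeToString(payload) + "\"";
    }

    public static void main(String[] args) {
        // Prints the quoted Base64 form of "client1::step 1 done"
        System.out.println(asJsonByteArray(frame("client1", "step 1 done")));
    }
}
```

If the template used a different serializer, the payload would look different and the quote-stripping step on the subscriber side would need to change to match.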
Redis subscriber class
import com.ioufev.wsforward.consts.RedisConst;
import com.ioufev.wsforward.ws.WebSocketServer;
import org.springframework.context.annotation.Bean;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.stereotype.Component;

import java.nio.charset.StandardCharsets;
import java.util.Base64;

@Component
public class RedisMessageListener implements MessageListener {

    private final WebSocketServer webSocket;

    // Constructor injection is enough; no separate field annotation is needed
    public RedisMessageListener(WebSocketServer webSocket) {
        this.webSocket = webSocket;
    }

    @Override
    public void onMessage(Message message, byte[] pattern) {
        // Channel name
        String channel = new String(message.getChannel(), StandardCharsets.UTF_8);
        // Only handle the forwarding channel
        if (channel.equals(RedisConst.PUB_SUB_TOPIC)) {
            // Message body
            byte[] body = message.getBody();
            String contentBase64WithQuotes = new String(body, StandardCharsets.UTF_8); // Base64 wrapped in quotes
            String contentBase64 = removeQuotes(contentBase64WithQuotes);             // plain Base64
            String content = new String(Base64.getDecoder().decode(contentBase64), StandardCharsets.UTF_8); // original "key::message"
            // Split at the first "::" only, so the message text itself may contain "::"
            int sep = content.indexOf("::");
            if (sep < 0) {
                return; // malformed payload: nothing to forward
            }
            String key = content.substring(0, sep);
            String wsContent = content.substring(sep + 2);
            webSocket.sendOneMessageForRedisMessage(key, wsContent);
        }
    }

    @Bean
    public RedisMessageListenerContainer container(RedisConnectionFactory factory,
                                                   RedisMessageListener listener) {
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        container.setConnectionFactory(factory);
        container.addMessageListener(listener, new ChannelTopic(RedisConst.PUB_SUB_TOPIC));
        return container;
    }

    /**
     * Removes the leading and trailing double quotes from a value read from Redis.
     * @param input the raw value
     * @return the value without its surrounding quotes
     */
    private String removeQuotes(String input) {
        if (input != null && input.length() >= 2 && input.startsWith("\"") && input.endsWith("\"")) {
            return input.substring(1, input.length() - 1);
        }
        return input;
    }
}
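The decoding steps in `onMessage` can be pulled out of Spring and exercised on their own. The sketch below uses a hypothetical class, `PayloadDecoder`, and follows the same order as the listener:

1. Strip the JSON quotes.
2. Base64-decode the result.
3. Split at the first `"::"`.

It returns an empty `Optional` instead of throwing when the payload is not valid Base64 or has no separator.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Optional;

public class PayloadDecoder {

    /** Target key and WebSocket text extracted from one Redis message. */
    static final class Forward {
        final String key;
        final String text;

        Forward(String key, String text) {
            this.key = key;
            this.text = text;
        }
    }

    /** Strips JSON quotes, Base64-decodes, and splits "key::message" at the first "::". */
    static Optional<Forward> decode(byte[] body) {
        String base64 = new String(body, StandardCharsets.UTF_8);
        if (base64.length() >= 2 && base64.startsWith("\"") && base64.endsWith("\"")) {
            base64 = base64.substring(1, base64.length() - 1);
        }
        String content;
        try {
            content = new String(Base64.getDecoder().decode(base64), StandardCharsets.UTF_8);
        } catch (IllegalArgumentException e) {
            return Optional.empty(); // not valid Base64
        }
        int sep = content.indexOf("::");
        if (sep < 0) {
            return Optional.empty(); // no "::" separator
        }
        return Optional.of(new Forward(content.substring(0, sep), content.substring(sep + 2)));
    }

    public static void main(String[] args) {
        byte[] raw = "user1::a::b".getBytes(StandardCharsets.UTF_8);
        String wire = "\"" + Base64.getEncoder().encodeToString(raw) + "\"";
        decode(wire.getBytes(StandardCharsets.UTF_8))
                .ifPresent(f -> System.out.println(f.key + " -> " + f.text)); // user1 -> a::b
    }
}
```

Because only the first `"::"` separates the key, the message text itself may contain `"::"`. The key must not contain it.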
WebSocket server endpoint class
import com.ioufev.wsforward.consts.RedisConst;
import com.ioufev.wsforward.redis.RedisPublisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.websocket.OnClose;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;

@Component
@ServerEndpoint("/websocket/{key}")
public class WebSocketServer {

    private static final Logger log = LoggerFactory.getLogger(WebSocketServer.class);

    private String sessionId;
    private Session session;

    // The WebSocket container creates a new instance per connection, outside Spring, so the
    // publisher is kept in a static field that is set once through the Spring-managed instance.
    private static RedisPublisher redisPublisher;

    @Autowired
    public void setRedisPublisher(RedisPublisher redisPublisher) {
        WebSocketServer.redisPublisher = redisPublisher;
    }

    // Connections held by THIS node only
    private static final CopyOnWriteArraySet<WebSocketServer> webSockets = new CopyOnWriteArraySet<>();
    private static final Map<String, Session> sessionPool = new ConcurrentHashMap<>();

    @OnOpen
    public void onOpen(Session session, @PathParam(value = "key") String key) {
        this.sessionId = key;
        this.session = session;
        webSockets.add(this);
        sessionPool.put(key, session);
        log.info("{} [websocket] new connection, total: {}, session count: {}", key, webSockets.size(), sessionPool.size());
        for (WebSocketServer webSocket : webSockets) {
            log.info("[websocket] key: {}", webSocket.sessionId);
        }
    }

    @OnClose
    public void onClose() {
        sessionPool.remove(this.sessionId);
        webSockets.remove(this);
        log.info("[websocket] connection closed, total: {}", webSockets.size());
    }

    @OnMessage
    public void onMessage(@PathParam(value = "key") String key, String message) {
        log.info("[websocket] received message: {}", message);
        sendOneMessage(key, message);
    }

    /**
     * Broadcast to every connection on this node.
     */
    public void sendAllMessage(String message) {
        for (WebSocketServer webSocket : webSockets) {
            log.info("[websocket] broadcast: {}", message);
            try {
                webSocket.session.getAsyncRemote().sendText(message);
            } catch (Exception e) {
                log.error("[websocket] broadcast failed", e);
            }
        }
    }

    /**
     * Send to a single client: deliver locally if this node holds the connection,
     * otherwise publish "key::message" to Redis so the node that holds it can deliver it.
     */
    public void sendOneMessage(String key, String message) {
        Session session = getSession(key);
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (Exception e) {
                log.error("[websocket] send failed, key: {}", key, e);
            }
        } else {
            redisPublisher.publishMessage(RedisConst.PUB_SUB_TOPIC, (key + "::" + message).getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Used by the Redis subscriber: deliver only if the connection is on this node.
     * It never republishes; otherwise a message whose key is on no node would circulate forever.
     */
    public void sendOneMessageForRedisMessage(String key, String message) {
        Session session = getSession(key);
        if (session != null) {
            try {
                session.getBasicRemote().sendText(message);
            } catch (Exception e) {
                log.error("[websocket] send failed, key: {}", key, e);
            }
        }
    }

    private static Session getSession(String key) {
        for (WebSocketServer webSocket : webSockets) {
            if (webSocket.sessionId.equals(key)) {
                return webSocket.session;
            }
        }
        return null;
    }
}
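The core of `WebSocketServer` is the split between two methods:

- `sendOneMessage` delivers locally or publishes.
- `sendOneMessageForRedisMessage` delivers locally or drops.

The in-memory sketch below makes that rule runnable without a server. The class names are hypothetical: `Bus` stands in for the Redis channel, and `Node` stands in for a server. Each WebSocket session is modeled as a list of the messages sent to it.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ForwardingSketch {

    /** Stand-in for a Redis channel: every subscribed node receives every published message. */
    static final class Bus {
        private final List<Node> subscribers = new ArrayList<>();

        void subscribe(Node node) { subscribers.add(node); }

        void publish(String key, String message) {
            for (Node node : subscribers) {
                node.onBusMessage(key, message);
            }
        }
    }

    /** Stand-in for one server; sessions maps key -> messages "sent" over that client's socket. */
    static final class Node {
        final Bus bus;
        final Map<String, List<String>> sessions = new HashMap<>();

        Node(Bus bus) {
            this.bus = bus;
            bus.subscribe(this);
        }

        void connect(String key) { sessions.put(key, new ArrayList<>()); }

        /** Like sendOneMessage: deliver locally, otherwise hand off to the bus. */
        void sendOne(String key, String message) {
            List<String> socket = sessions.get(key);
            if (socket != null) {
                socket.add(message);
            } else {
                bus.publish(key, message);
            }
        }

        /** Like sendOneMessageForRedisMessage: deliver locally or drop, never republish. */
        void onBusMessage(String key, String message) {
            List<String> socket = sessions.get(key);
            if (socket != null) {
                socket.add(message);
            }
        }
    }

    public static void main(String[] args) {
        Bus bus = new Bus();
        Node a = new Node(bus);
        Node b = new Node(bus);
        a.connect("client1");                          // client 1 is connected to server A
        b.sendOne("client1", "step 1 done");           // server B produces output for client 1
        System.out.println(a.sessions.get("client1")); // [step 1 done]
    }
}
```

Every node receives each published message, including the node that published it. Because `onBusMessage` never calls `sendOne`, a message whose key no node holds is dropped rather than bounced around the cluster.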
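References
💧 WebSocket cluster solutions
👉 The diagrams are good and make the idea easy to understand.
💧 A WebSocket cluster solution without MQ
👉 Building on the idea above, it gives each server an identifier and records which server each user is connected to. I had a similar idea, but I have not yet worked out where to store the mapping between user IDs and server IDs.
💧 Spring Cloud: a WebSocket cluster solution with a single configuration annotation
👉 A bolder idea: since this is forwarding within the cluster anyway, nothing prevents using WebSocket itself between the servers.
💧 A distributed WebSocket cluster solution
👉 Stores the mapping between user connections and servers using consistent hashing.
💧 6 ways to integrate WebSocket with Spring Boot
👉 I like the title; for the content, skimming the table of contents is enough.
💧 Design and practice of a general-purpose WebSocket push gateway
👉 Worth consulting for production, but as an introduction it does not clearly explain the key points and difficulties.
💧 How does Shimo Docs (石墨文档) support a million long-lived WebSocket connections?
👉 Also worth consulting for production and, like the previous one, not a clear introduction to the key points and difficulties; it is more detailed than the article above, though, and clearly actionable.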
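Summary
1. The cluster needs one shared place that records which server each user connection is on, for example Redis, an MQ, ZooKeeper, or the service discovery of a microservice framework.
2. Redis pub/sub is a very simple way to forward messages within a cluster and suits real-time, fire-and-forget messages, such as the step-by-step output of a computation. Redis does not store published messages: a node that is not subscribed at the moment a message is published never receives it.
3. If messages must not be lost, or delivery must be as reliable as possible, use an MQ instead.
4. The best approach: give each server an ID and each user connection an ID. When forwarding, look up the server that holds the connection and forward the message once, to that server only.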
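Points 1 and 4 can be sketched together. A shared registry records which server holds each user, and each server subscribes only to its own channel, so each message is published once, to one channel. The sketch makes three assumptions, none of which comes from the linked repository:

- The registry is an in-memory map standing in for Redis or ZooKeeper.
- The channel name `"ws:server:<id>"` is a made-up convention.
- `TargetedForwarding` is a hypothetical class name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class TargetedForwarding {

    // Hypothetical shared registry: userId -> serverId (in production this would live in Redis, ZooKeeper, ...)
    private final Map<String, String> userToServer = new HashMap<>();
    // Stand-in for Redis channels: channel name -> messages published to it
    private final Map<String, List<String>> published = new HashMap<>();

    /** Made-up per-server channel naming convention. */
    static String channelFor(String serverId) { return "ws:server:" + serverId; }

    /** Called by the server that accepted the WebSocket connection. */
    void register(String userId, String serverId) { userToServer.put(userId, serverId); }

    /** Called on disconnect; stale entries would route messages to the wrong server. */
    void unregister(String userId) { userToServer.remove(userId); }

    /** Publishes "userId::message" once, to the owning server's channel; false if the user is offline. */
    boolean forward(String userId, String message) {
        String serverId = userToServer.get(userId);
        if (serverId == null) {
            return false;
        }
        published.computeIfAbsent(channelFor(serverId), ch -> new ArrayList<>())
                 .add(userId + "::" + message);
        return true;
    }

    List<String> messagesOn(String channel) {
        return published.getOrDefault(channel, Collections.emptyList());
    }

    public static void main(String[] args) {
        TargetedForwarding tf = new TargetedForwarding();
        tf.register("client1", "A");
        tf.forward("client1", "step 1 done");
        System.out.println(tf.messagesOn(channelFor("A"))); // [client1::step 1 done]
        System.out.println(tf.messagesOn(channelFor("B"))); // []
    }
}
```

On the receiving side, server A would subscribe only to `channelFor("A")` and deliver each message with a local session lookup, as `sendOneMessageForRedisMessage` does. The other servers never see the message.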