重构WebSocket Interceptor

This commit is contained in:
easonzhu 2025-01-28 23:26:45 +08:00
parent fcd7d3a511
commit ca9064a58e
4 changed files with 228 additions and 108 deletions

View File

@ -1,16 +1,5 @@
package com.upchina.common.interceptor;
import cn.hutool.core.util.StrUtil;
import com.hazelcast.map.IMap;
import com.upchina.common.constant.IsOrNot;
import com.upchina.common.filter.AuthFilter;
import com.upchina.common.handler.BizException;
import com.upchina.common.result.ResponseStatus;
import com.upchina.common.vo.BackendUserVO;
import com.upchina.common.vo.FrontUserVO;
import com.upchina.video.entity.OnlineUser;
import com.upchina.video.service.common.VideoCacheService;
import com.upchina.video.service.common.VideoMessageService;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.lang.NonNull;
@ -22,8 +11,6 @@ import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.time.LocalDateTime;
import java.util.Map;
/**
* <p>
@ -38,123 +25,46 @@ import java.util.Map;
public class WebSocketAuthChannelInterceptor implements ChannelInterceptor {
@Resource
private AuthFilter authFilter;
private WebSocketAuthHandler authHandler;
@Resource
private VideoCacheService videoCacheService;
private WebSocketSessionHandler sessionHandler;
@Resource
private VideoMessageService videoMessageService;
/**
* 发送前监听是否为登录用户
*/
@Override
public Message<?> preSend(@NonNull Message<?> message, @NonNull MessageChannel channel) {
StompHeaderAccessor header = StompHeaderAccessor.wrap(message);
if (header == null || header.getCommand() == null) {
if (!isValidHeader(header)) {
return message;
}
// 只处理连接消息
if (!header.getCommand().equals(StompCommand.CONNECT)) {
return message;
if (StompCommand.CONNECT.equals(header.getCommand())) {
return authHandler.handleConnect(message, header);
}
String sessionId = header.getFirstNativeHeader("sessionId");
if (StrUtil.isEmpty(sessionId)) {
throw new BizException(ResponseStatus.PARM_ERROR, "header里没有包含sessionId");
}
String groupId = header.getFirstNativeHeader("GroupId");
Integer videoId = StrUtil.isNotEmpty(groupId) ? Integer.parseInt(groupId) : null;
if (videoId == null) {
throw new BizException(ResponseStatus.PARM_ERROR, "header里没有包含videoId");
}
Map<String, Object> attributes = header.getSessionAttributes();
if (attributes == null) {
throw new BizException(ResponseStatus.PARM_ERROR, "header里没有包含attributes");
}
String userId = null;
String authorization = header.getFirstNativeHeader("Authorization");
if (StrUtil.isNotEmpty(authorization)) {
BackendUserVO backendUser = authFilter.parseBackendUser(authorization);
attributes.put("backendUser", backendUser);
userId = backendUser.getUserId().toString();
} else {
String token = header.getFirstNativeHeader("token");
if (StrUtil.isNotEmpty(token)) {
FrontUserVO frontUser = authFilter.parseFrontUser(token);
attributes.put("frontUser", frontUser);
userId = frontUser.getUserId();
}
}
if (userId == null) {
throw new BizException(ResponseStatus.SESSION_EXPIRY);
}
attributes.put("userId", userId);
attributes.put("sessionId", sessionId);
attributes.put("videoId", videoId);
attributes.put("sessionKey", userId + "-" + sessionId);
return message;
}
@Override
public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
StompHeaderAccessor header = StompHeaderAccessor.wrap(message);
if (header == null || header.getCommand() == null) {
if (!isValidHeader(header)) {
return;
}
Map<String, Object> attributes = header.getSessionAttributes();
FrontUserVO frontUser = (FrontUserVO) attributes.get("frontUser");
if (frontUser == null) {
return;
}
String userId = (String) attributes.get("userId");
Integer videoId = (Integer) attributes.get("videoId");
String sessionId = (String) attributes.get("sessionId");
String sessionKey = (String) attributes.get("sessionKey");
if (userId == null || videoId == null || sessionId == null || sessionKey == null) {
return;
}
LocalDateTime now = LocalDateTime.now();
switch (header.getCommand()) {
case CONNECT: {
IMap<String, OnlineUser> totalOnlineMap = videoCacheService.getTotalOnlineMap(videoId);
OnlineUser onlineUser = new OnlineUser(
videoId,
userId,
frontUser.getUserName(),
frontUser.getImgUrl(),
sessionId,
IsOrNot.IS.value,
IsOrNot.NOT.value,
now
);
totalOnlineMap.put(sessionKey, onlineUser);
//上线通知
// LoggerUtil.websocket.info("上线通知:" + JSONObject.toJSONString(onlineUser));
videoMessageService.memberNotify(videoId, onlineUser);
videoMessageService.publishEnterMessage(videoId, frontUser);
case CONNECT:
sessionHandler.handleConnect(header);
break;
}
case DISCONNECT: {
IMap<String, OnlineUser> totalOnlineMap = videoCacheService.getTotalOnlineMap(videoId);
OnlineUser onlineUser = totalOnlineMap.get(sessionKey);
if (onlineUser != null) {
onlineUser.setIsOnline(IsOrNot.NOT.value);
onlineUser.setIsPlay(IsOrNot.NOT.value);
onlineUser.setExitTime(now);
totalOnlineMap.put(sessionKey, onlineUser);
//下线通知
// LoggerUtil.websocket.info("下线通知:" + JSONObject.toJSONString(onlineUser));
videoMessageService.memberNotify(videoId, onlineUser);
}
case DISCONNECT:
sessionHandler.handleDisconnect(header);
break;
default:
break;
}
}
}
@Override
public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
private boolean isValidHeader(StompHeaderAccessor header) {
return header != null && header.getCommand() != null;
}
}

View File

@ -0,0 +1,86 @@
package com.upchina.common.interceptor;
import cn.hutool.core.util.StrUtil;
import com.upchina.common.filter.AuthFilter;
import com.upchina.common.handler.BizException;
import com.upchina.common.result.ResponseStatus;
import com.upchina.common.vo.BackendUserVO;
import com.upchina.common.vo.FrontUserVO;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.Message;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.util.Map;
@Slf4j
@Component
public class WebSocketAuthHandler {
@Resource
private AuthFilter authFilter;
public Message<?> handleConnect(Message<?> message, StompHeaderAccessor header) {
validateHeaders(header);
Map<String, Object> attributes = header.getSessionAttributes();
String userId = authenticateUser(header);
if (userId == null) {
throw new BizException(ResponseStatus.SESSION_EXPIRY);
}
String sessionId = header.getFirstNativeHeader("sessionId");
Integer videoId = getVideoId(header);
populateAttributes(attributes, userId, sessionId, videoId);
return message;
}
private void validateHeaders(StompHeaderAccessor header) {
String sessionId = header.getFirstNativeHeader("sessionId");
if (StrUtil.isEmpty(sessionId)) {
throw new BizException(ResponseStatus.PARM_ERROR, "header里没有包含sessionId");
}
String groupId = header.getFirstNativeHeader("GroupId");
if (StrUtil.isEmpty(groupId)) {
throw new BizException(ResponseStatus.PARM_ERROR, "header里没有包含videoId");
}
if (header.getSessionAttributes() == null) {
throw new BizException(ResponseStatus.PARM_ERROR, "header里没有包含attributes");
}
}
private String authenticateUser(StompHeaderAccessor header) {
String authorization = header.getFirstNativeHeader("Authorization");
if (StrUtil.isNotEmpty(authorization)) {
BackendUserVO backendUser = authFilter.parseBackendUser(authorization);
header.getSessionAttributes().put("backendUser", backendUser);
return backendUser.getUserId().toString();
}
String token = header.getFirstNativeHeader("token");
if (StrUtil.isNotEmpty(token)) {
FrontUserVO frontUser = authFilter.parseFrontUser(token);
header.getSessionAttributes().put("frontUser", frontUser);
return frontUser.getUserId();
}
return null;
}
private Integer getVideoId(StompHeaderAccessor header) {
String groupId = header.getFirstNativeHeader("GroupId");
return StrUtil.isNotEmpty(groupId) ? Integer.parseInt(groupId) : null;
}
private void populateAttributes(Map<String, Object> attributes, String userId,
String sessionId, Integer videoId) {
attributes.put("userId", userId);
attributes.put("sessionId", sessionId);
attributes.put("videoId", videoId);
attributes.put("sessionKey", userId + "-" + sessionId);
}
}

View File

@ -0,0 +1,106 @@
package com.upchina.common.interceptor;
import com.hazelcast.map.IMap;
import com.upchina.common.constant.IsOrNot;
import com.upchina.common.model.SessionInfo;
import com.upchina.common.vo.FrontUserVO;
import com.upchina.video.entity.OnlineUser;
import com.upchina.video.service.common.VideoCacheService;
import com.upchina.video.service.common.VideoMessageService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.time.LocalDateTime;
import java.util.Map;
@Component
public class WebSocketSessionHandler {
@Resource
private VideoCacheService videoCacheService;
@Resource
private VideoMessageService videoMessageService;
public void handleConnect(StompHeaderAccessor header) {
Map<String, Object> attributes = header.getSessionAttributes();
FrontUserVO frontUser = (FrontUserVO) attributes.get("frontUser");
if (frontUser == null) {
return;
}
SessionInfo sessionInfo = extractSessionInfo(attributes);
if (!sessionInfo.isValid()) {
return;
}
OnlineUser onlineUser = createOnlineUser(sessionInfo, frontUser);
updateOnlineStatus(sessionInfo.getVideoId(), sessionInfo.getSessionKey(), onlineUser);
notifyUserConnect(sessionInfo.getVideoId(), frontUser, onlineUser);
}
public void handleDisconnect(StompHeaderAccessor header) {
Map<String, Object> attributes = header.getSessionAttributes();
FrontUserVO frontUser = (FrontUserVO) attributes.get("frontUser");
if (frontUser == null) {
return;
}
SessionInfo sessionInfo = extractSessionInfo(attributes);
if (!sessionInfo.isValid()) {
return;
}
handleUserDisconnect(sessionInfo);
}
private SessionInfo extractSessionInfo(Map<String, Object> attributes) {
return SessionInfo.builder()
.userId((String) attributes.get("userId"))
.videoId((Integer) attributes.get("videoId"))
.sessionId((String) attributes.get("sessionId"))
.sessionKey((String) attributes.get("sessionKey"))
.build();
}
private OnlineUser createOnlineUser(SessionInfo sessionInfo, FrontUserVO frontUser) {
return new OnlineUser(
sessionInfo.getVideoId(),
sessionInfo.getUserId(),
frontUser.getUserName(),
frontUser.getImgUrl(),
sessionInfo.getSessionId(),
IsOrNot.IS.value,
IsOrNot.NOT.value,
LocalDateTime.now()
);
}
private void updateOnlineStatus(Integer videoId, String sessionKey, OnlineUser onlineUser) {
IMap<String, OnlineUser> totalOnlineMap = videoCacheService.getTotalOnlineMap(videoId);
totalOnlineMap.put(sessionKey, onlineUser);
}
private void notifyUserConnect(Integer videoId, FrontUserVO frontUser, OnlineUser onlineUser) {
videoMessageService.memberNotify(videoId, onlineUser);
videoMessageService.publishEnterMessage(videoId, frontUser);
}
private void handleUserDisconnect(SessionInfo sessionInfo) {
IMap<String, OnlineUser> totalOnlineMap = videoCacheService.getTotalOnlineMap(sessionInfo.getVideoId());
OnlineUser onlineUser = totalOnlineMap.get(sessionInfo.getSessionKey());
if (onlineUser != null) {
updateOfflineStatus(onlineUser);
totalOnlineMap.put(sessionInfo.getSessionKey(), onlineUser);
videoMessageService.memberNotify(sessionInfo.getVideoId(), onlineUser);
}
}
private void updateOfflineStatus(OnlineUser onlineUser) {
onlineUser.setIsOnline(IsOrNot.NOT.value);
onlineUser.setIsPlay(IsOrNot.NOT.value);
onlineUser.setExitTime(LocalDateTime.now());
}
}

View File

@ -0,0 +1,18 @@
package com.upchina.common.model;
import lombok.Builder;
import lombok.Data;
@Data
@Builder
public class SessionInfo {
private String userId;
private Integer videoId;
private String sessionId;
private String sessionKey;
public boolean isValid() {
return userId != null && videoId != null &&
sessionId != null && sessionKey != null;
}
}