GRPC 自定义流拦截器
参考链接
服务端实现流拦截器(拦截客户端的流的时候)
自定义 grpc.ServerStream -> 重写 sendMsg & Recv 方法! -> 在拦截器中传递给handler 自定义 grpc.ServerStream 供下层使用!
proto 定义 服务端接受 客户端的流!
// stream_example.proto
syntax = "proto3";
option go_package = "./customStream";
service StreamExample {
rpc ClientStreaming (stream StreamRequest) returns (StreamResponse) {}
}
message StreamRequest {
bytes message = 1;
}
message StreamResponse {
string message = 1;
}
自定义拦截器实现的
StreamServerInterceptor
方法
源码位置 go/pkg/mod/google.golang.org/grpc@v1.55.0/interceptor.go
// StreamServerInterceptor provides a hook to intercept the execution of a streaming RPC on the server.
// info contains all the information of this RPC the interceptor can operate on. And handler is the
// service method implementation. It is the responsibility of the interceptor to invoke handler to
// complete the RPC.
type StreamServerInterceptor func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error
因为流拦截器我们期望在每次接收到流数据的时候,都能进行一些拦截操作!
所以我们在这里自定义了 ServerStream 并且继承 grpc.ServerStream
重写其 RecvMsg
& SendMsg
方法!
然后一些拦截自定义的操作写入到重写的 RecvMsg
& SendMsg
方法!
// WrappedServerStream 自定义的 WrappedServerStream 继承 `grpc.ServerStream`
type WrappedServerStream struct {
grpc.ServerStream
}
// NewWrappedServerStream WrappedServerStream构造函数!
func NewWrappedServerStream(ss grpc.ServerStream) *WrappedServerStream {
return &WrappedServerStream{
ServerStream: ss,
}
}
func (w *WrappedServerStream) RecvMsg(m interface{}) error {
// 在消息接收之前执行
log.Println("Performing data validation before receiving message")
// 读取消息
err := w.ServerStream.RecvMsg(m)
if err != nil {
log.Printf("Error reading message: %v", err)
return err
}
// 在消息接收之后执行
log.Println("Performing data validation after receiving message")
return nil
}
func (w *WrappedServerStream) SendMsg(m interface{}) error {
// 在消息发送之前执行
log.Println("Performing data validation before sending message")
// 发送消息
err := w.ServerStream.SendMsg(m)
if err != nil {
log.Printf("Error sending message: %v", err)
return err
}
// 在消息发送之后执行
log.Println("Performing data validation after sending message")
return nil
}
然后自定义拦截器,并且在拦截器方法中!替换我们自定义的 ServerStream
type CustomInterceptor struct{}
func (i *CustomInterceptor) StreamServerInterceptor(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
log.Printf("CustomInterceptor: Before handling the streaming request: %s", info.FullMethod)
// 创建一个新的 ServerStream,用于拦截消息的读取和发送
wrappedStream := NewWrappedServerStream(ss)
// 调用下一个拦截器或流处理程序,传递包装后的 ServerStream
err := handler(srv, wrappedStream)
if err != nil {
return err
}
log.Println("CustomInterceptor: After handling the streaming request")
return nil
}
实现服务端服务获取客户端流方法!
//CustomStreamService 服务端服务
type CustomStreamService struct {
customStream.UnimplementedStreamExampleServer
}
// ClientStreaming 服务端方法 持续获取 客户端推送的流,直到结束!
func (c CustomStreamService) ClientStreaming(server customStream.StreamExample_ClientStreamingServer) error {
for {
rec, err := server.Recv()
if err != nil || err == io.EOF {
fmt.Println("done", err)
break
}
fmt.Println("rec", rec)
}
_ = server.SendAndClose(&customStream.StreamResponse{Message: "Server Done !"})
return nil
}
客户端调用测试!
func main() {
// 建立 grpc与服务端套接字通信
conn, err := grpc.Dial(":50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return
}
//defer 关闭链接
defer func(conn *grpc.ClientConn) {
err := conn.Close()
if err != nil {
log.Fatalf("关闭 con 套接字失败 err:%s", err.Error())
}
}(conn)
// 创建一个客户端实例
client := customStream.NewStreamExampleClient(conn)
//调用客户端封装好的 访问服务方法!
c, err := client.ClientStreaming(context.Background())
file, _ := os.ReadFile("/Users/xiaohao/GolandProjects/microway/client/test.txt")
for i := 0; i < len(file); i++ {
time.Sleep(time.Millisecond * 100)
if err != nil {
log.Println(err)
break
}
// 持续发送读取的文件流给服务端!
err = c.Send(&customStream.StreamRequest{Message: []byte{file[i]}})
if err != nil {
continue
}
}
_ = c.CloseSend()
var res customStream.StreamResponse
for {
err := c.RecvMsg(&res)
if err != nil {
break
}
if len(res.Message) > 0 {
break
}
}
fmt.Println(res.Message)
}
客户端实现流拦截器(拦截服务端发送的流)
定义proto
// stream_example.proto
syntax = "proto3";
option go_package = "./customStream";
service StreamExample {
rpc ServerStreaming ( StreamRequest) returns (stream StreamResponse) {}
}
message StreamRequest {
string message = 1;
}
message StreamResponse {
bytes message = 1;
}
其实套路和服务端差不多!
在 gRPC 的 ClientStreamInterceptor
中重写 SendMsg
和 RecvMsg
方法,需要创建一个包装了 grpc.ClientStream
的结构体,并在其中实现这两个方法。
type ClientWrapperServerStream struct {
grpc.ClientStream
}
// NewWrappedClientStream ClientWrapperServerStream构造函数!
func NewWrappedClientStream(cs grpc.ClientStream) *ClientWrapperServerStream {
return &ClientWrapperServerStream{
ClientStream: cs,
}
}
func (cw *ClientWrapperServerStream) RecvMsg(m interface{}) error {
// 在消息接收之前执行
log.Println("Performing data validation before receiving message")
// 读取消息
err := cw.ClientStream.RecvMsg(m)
if err != nil {
log.Printf("Error reading message: %v", err)
return err
}
// 在消息接收之后执行
log.Println("Performing data validation after receiving message")
return nil
}
func (cw *ClientWrapperServerStream) SendMsg(m interface{}) error {
// 在消息发送之前执行
log.Println("Performing data validation before sending message")
// 发送消息
err := cw.ClientStream.SendMsg(m)
if err != nil {
log.Printf("Error sending message: %v", err)
return err
}
// 在消息发送之后执行
log.Println("Performing data validation after sending message")
return nil
}
然后自定义拦截器,实现StreamClientInterceptor
方法
// StreamClientInterceptor intercepts the creation of a ClientStream. Stream
// interceptors can be specified as a DialOption, using WithStreamInterceptor()
// or WithChainStreamInterceptor(), when creating a ClientConn. When a stream
// interceptor(s) is set on the ClientConn, gRPC delegates all stream creations
// to the interceptor, and it is the responsibility of the interceptor to call
// streamer.
//
// desc contains a description of the stream. cc is the ClientConn on which the
// RPC was invoked. streamer is the handler to create a ClientStream and it is
// the responsibility of the interceptor to call it. opts contain all applicable
// call options, including defaults from the ClientConn as well as per-call
// options.
//
// StreamClientInterceptor may return a custom ClientStream to intercept all I/O
// operations. The returned error must be compatible with the status package.
type StreamClientInterceptor func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error)
StreamClientInterceptor
方法 内部使用我们自定义的 clientStream (已经重新了Send&Recv方法!)
type CustomInterceptor struct{}
func (i *CustomInterceptor) ClientStreamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
log.Printf("CustomInterceptor: Before handling the streaming request: %s", method)
// 创建客户端流
clientStream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, err
}
// 包装客户端流
wrappedStream := NewWrappedClientStream(clientStream)
log.Println("CustomInterceptor: After handling the streaming request")
return wrappedStream, nil
}
服务端服务方法实现!
type CustomStreamService struct {
customStream.UnimplementedStreamExampleServer
}
// 服务端推流服务!
func (c CustomStreamService) ServerStreaming(request *customStream.StreamRequest, server customStream.StreamExample_ServerStreamingServer) error {
fmt.Println(request.Message)
file, _ := os.ReadFile("/Users/xiaohao/GolandProjects/microway/server/test.txt")
for i := 0; i < len(file); i++ {
time.Sleep(time.Millisecond * 100)
_ = server.Send(&customStream.StreamResponse{Message: []byte{file[i]}})
}
return nil
}
客户端添加拦截器并调用服务端服务
func main() {
csi := new(CustomInterceptor)
// 建立 grpc与服务端套接字通信
conn, err := grpc.Dial(":50051", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithStreamInterceptor(csi.ClientStreamInterceptor))
if err != nil {
return
}
//defer 关闭链接
defer func(conn *grpc.ClientConn) {
err := conn.Close()
if err != nil {
log.Fatalf("关闭 con 套接字失败 err:%s", err.Error())
}
}(conn)
// 创建一个客户端实例
client := customStream.NewStreamExampleClient(conn)
//调用客户端封装好的 访问服务方法!
c, err := client.ServerStreaming(context.Background(), &customStream.StreamRequest{Message: "1"})
for {
rec, err := c.Recv()
if err != nil || err == io.EOF {
break
}
fmt.Println(rec)
}
err = c.CloseSend()
if err != nil {
log.Println(err)
}
}
来运行测试一下!可以看到服务端和客户端两边的拦截器都生效了!
「这样我们服务端和客户端流拦截器的实现就完成了!」
评论区