Skip to content

Commit

Permalink
feat: Support server to client requests (#11)
Browse files Browse the repository at this point in the history
* This is a breaking change, but supports Server -> Client request and notifications.
  • Loading branch information
kyleconroy authored Dec 11, 2024
1 parent f79c24e commit d661166
Show file tree
Hide file tree
Showing 15 changed files with 521 additions and 292 deletions.
34 changes: 34 additions & 0 deletions base.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package mcp

import (
"context"
"strconv"
)

type base struct {
router *router
stream Stream
interceptors []Interceptor
}

func (b *base) listen(ctx context.Context, handler func(ctx context.Context, msg *Message) error) error {
for {
msg, err := b.stream.Recv()
if err != nil {
return err
}
if msg.Method != nil {
go func() {
handler(ctx, msg)
}()
} else {
id, err := strconv.ParseUint(msg.ID.String(), 10, 64)
if err != nil {
continue
}
if inbox, ok := b.router.Remove(id); ok {
inbox <- msg
}
}
}
}
77 changes: 77 additions & 0 deletions call.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package mcp

import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
)

func call[P any, R any](ctx context.Context, c *base, method string, req *Request[P]) (*Response[R], error) {
id, inbox := c.router.Add()

var interceptor Interceptor
if len(c.interceptors) > 0 {
interceptor = newStack(c.interceptors)
} else {
interceptor = UnaryInterceptorFunc(
func(next UnaryFunc) UnaryFunc {
return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
return next(ctx, request)
})
},
)
}

inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
rawmsg, err := json.Marshal(req.Params)
if err != nil {
return nil, err
}

msgID := json.Number(request.ID())
msgVersion := "2.0"
msgParams := json.RawMessage(rawmsg)

msg := &Message{
ID: &msgID,
JsonRPC: &msgVersion,
Method: &method,
Params: &msgParams,
}

if err := c.stream.Send(msg); err != nil {
return nil, err
}

var result R

select {
case resp := <-inbox:
if resp.Error != nil {
return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message))
}
if resp.Result == nil {
return nil, fmt.Errorf("no result")
}
if err := json.Unmarshal(*resp.Result, &result); err != nil {
return nil, err
}
case <-ctx.Done():
return nil, ctx.Err()
}

return NewResponse(&result), nil
})

req.id = strconv.FormatUint(id, 10)
req.method = method

resp, err := interceptor.WrapUnary(inner)(ctx, req)
if err != nil {
return nil, err
}

return resp.(*Response[R]), nil
}
201 changes: 65 additions & 136 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,178 +1,107 @@
package mcp

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"

"github.com/riza-io/mcp-go/internal/jsonrpc"
)

type Client interface {
Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error)
ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error)
ListTools(ctx context.Context, req *Request[ListToolsRequest]) (*Response[ListToolsResponse], error)
CallTool(ctx context.Context, req *Request[CallToolRequest]) (*Response[CallToolResponse], error)
ListPrompts(ctx context.Context, req *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error)
GetPrompt(ctx context.Context, req *Request[GetPromptRequest]) (*Response[GetPromptResponse], error)
ReadResource(ctx context.Context, req *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error)
ListResourceTemplates(ctx context.Context, req *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error)
Completion(ctx context.Context, req *Request[CompletionRequest]) (*Response[CompletionResponse], error)
Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error)
SetLogLevel(ctx context.Context, req *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error)
}

type StdioClient struct {
in io.Reader
out io.Writer
scanner *bufio.Scanner
next int
lock sync.Mutex
type ClientHandler interface {
Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error)
Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error)
LogMessage(ctx context.Context, request *Request[LogMessageRequest])
}

type UnimplementedClient struct{}

func (u *UnimplementedClient) Sampling(ctx context.Context, request *Request[SamplingRequest]) (*Response[SamplingResponse], error) {
return nil, fmt.Errorf("not implemented")
}

func (u *UnimplementedClient) LogMessage(ctx context.Context, request *Request[LogMessageRequest]) {
}

func (c *UnimplementedClient) Ping(ctx context.Context, req *Request[PingRequest]) (*Response[PingResponse], error) {
return NewResponse(&PingResponse{}), nil
}

type Client struct {
handler ClientHandler
interceptors []Interceptor
base *base
}

func NewStdioClient(stdin io.Reader, stdout io.Writer, opts ...Option) Client {
c := &StdioClient{
in: stdin,
out: stdout,
scanner: bufio.NewScanner(stdin),
func NewClient(stream Stream, handler ClientHandler, opts ...Option) *Client {
c := &Client{
handler: handler,
}

for _, opt := range opts {
opt.applyToClient(c)
}

c.base = &base{
router: newRouter(),
interceptors: c.interceptors,
stream: stream,
}
return c
}

func clientCallUnary[P any, R any](ctx context.Context, c *StdioClient, method string, req *Request[P]) (*Response[R], error) {
// Ensure that we are not sending multiple requests at the same time
c.lock.Lock()
defer c.lock.Unlock()

defer func() {
// Increment the ID counter
c.next++
}()

var interceptor Interceptor
if len(c.interceptors) > 0 {
interceptor = newStack(c.interceptors)
} else {
interceptor = UnaryInterceptorFunc(
func(next UnaryFunc) UnaryFunc {
return UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
return next(ctx, request)
})
},
)
}

inner := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
rawmsg, err := json.Marshal(req.Params)
if err != nil {
return nil, err
}

msg := jsonrpc.Request{
ID: json.Number(request.ID()),
JsonRPC: "2.0",
Method: request.Method(),
Params: json.RawMessage(rawmsg),
}

bs, err := json.Marshal(msg)
if err != nil {
return nil, err
}

fmt.Fprintln(c.out, string(bs))

var result R

for c.scanner.Scan() {
line := c.scanner.Bytes()

var resp jsonrpc.Response

if err := json.Unmarshal(line, &resp); err != nil {
return nil, err
}

if resp.Error != nil {
return nil, NewError(resp.Error.Code, errors.New(resp.Error.Message))
}

if err := json.Unmarshal(resp.Result, &result); err != nil {
return nil, err
}

break
}

if err := c.scanner.Err(); err != nil {
return nil, err
}

return NewResponse(&result), nil
})

req.id = strconv.Itoa(c.next)
req.method = method
// sync.Once?
func (c *Client) Listen(ctx context.Context) error {
return c.base.listen(ctx, c.processMessage)
}

resp, err := interceptor.WrapUnary(inner)(ctx, req)
if err != nil {
return nil, err
func (c *Client) processMessage(ctx context.Context, msg *Message) error {
srv := c.handler
switch m := *msg.Method; m {
case "ping":
return process(ctx, c.base, msg, srv.Ping)
case "notifications/message":
return process(ctx, c.base, msg, noop(srv.LogMessage))
default:
return fmt.Errorf("unknown method: %s", m)
}

return resp.(*Response[R]), nil
}

func (c *StdioClient) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) {
return clientCallUnary[InitializeRequest, InitializeResponse](ctx, c, "initialize", request)
func (c *Client) Initialize(ctx context.Context, request *Request[InitializeRequest]) (*Response[InitializeResponse], error) {
return call[InitializeRequest, InitializeResponse](ctx, c.base, "initialize", request)
}

func (c *StdioClient) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) {
return clientCallUnary[ListResourcesRequest, ListResourcesResponse](ctx, c, "resources/list", request)
func (c *Client) ListResources(ctx context.Context, request *Request[ListResourcesRequest]) (*Response[ListResourcesResponse], error) {
return call[ListResourcesRequest, ListResourcesResponse](ctx, c.base, "resources/list", request)
}

func (c *StdioClient) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) {
return clientCallUnary[ListToolsRequest, ListToolsResponse](ctx, c, "tools/list", request)
func (c *Client) ListTools(ctx context.Context, request *Request[ListToolsRequest]) (*Response[ListToolsResponse], error) {
return call[ListToolsRequest, ListToolsResponse](ctx, c.base, "tools/list", request)
}

func (c *StdioClient) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) {
return clientCallUnary[CallToolRequest, CallToolResponse](ctx, c, "tools/call", request)
func (c *Client) CallTool(ctx context.Context, request *Request[CallToolRequest]) (*Response[CallToolResponse], error) {
return call[CallToolRequest, CallToolResponse](ctx, c.base, "tools/call", request)
}

func (c *StdioClient) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) {
return clientCallUnary[ListPromptsRequest, ListPromptsResponse](ctx, c, "prompts/list", request)
func (c *Client) ListPrompts(ctx context.Context, request *Request[ListPromptsRequest]) (*Response[ListPromptsResponse], error) {
return call[ListPromptsRequest, ListPromptsResponse](ctx, c.base, "prompts/list", request)
}

func (c *StdioClient) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) {
return clientCallUnary[GetPromptRequest, GetPromptResponse](ctx, c, "prompts/get", request)
func (c *Client) GetPrompt(ctx context.Context, request *Request[GetPromptRequest]) (*Response[GetPromptResponse], error) {
return call[GetPromptRequest, GetPromptResponse](ctx, c.base, "prompts/get", request)
}

func (c *StdioClient) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) {
return clientCallUnary[ReadResourceRequest, ReadResourceResponse](ctx, c, "resources/read", request)
func (c *Client) ReadResource(ctx context.Context, request *Request[ReadResourceRequest]) (*Response[ReadResourceResponse], error) {
return call[ReadResourceRequest, ReadResourceResponse](ctx, c.base, "resources/read", request)
}

func (c *StdioClient) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) {
return clientCallUnary[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c, "resources/templates/list", request)
func (c *Client) ListResourceTemplates(ctx context.Context, request *Request[ListResourceTemplatesRequest]) (*Response[ListResourceTemplatesResponse], error) {
return call[ListResourceTemplatesRequest, ListResourceTemplatesResponse](ctx, c.base, "resources/templates/list", request)
}

func (c *StdioClient) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) {
return clientCallUnary[CompletionRequest, CompletionResponse](ctx, c, "completion", request)
func (c *Client) Completion(ctx context.Context, request *Request[CompletionRequest]) (*Response[CompletionResponse], error) {
return call[CompletionRequest, CompletionResponse](ctx, c.base, "completion", request)
}

func (c *StdioClient) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) {
return clientCallUnary[PingRequest, PingResponse](ctx, c, "ping", request)
func (c *Client) Ping(ctx context.Context, request *Request[PingRequest]) (*Response[PingResponse], error) {
return call[PingRequest, PingResponse](ctx, c.base, "ping", request)
}

func (c *StdioClient) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) {
return clientCallUnary[SetLogLevelRequest, SetLogLevelResponse](ctx, c, "logging/setLevel", request)
func (c *Client) SetLogLevel(ctx context.Context, request *Request[SetLogLevelRequest]) (*Response[SetLogLevelResponse], error) {
return call[SetLogLevelRequest, SetLogLevelResponse](ctx, c.base, "logging/setLevel", request)
}
5 changes: 3 additions & 2 deletions examples/fs/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strings"

"github.com/riza-io/mcp-go"
"github.com/riza-io/mcp-go/stdio"
)

type FSServer struct {
Expand Down Expand Up @@ -73,11 +74,11 @@ func main() {
root = "/"
}

server := mcp.NewStdioServer(&FSServer{
server := mcp.NewServer(stdio.NewStream(os.Stdin, os.Stdout), &FSServer{
fs: os.DirFS(root),
})

if err := server.Listen(context.Background(), os.Stdin, os.Stdout); err != nil {
if err := server.Listen(context.Background()); err != nil {
log.Fatal(err)
}
}
Loading

0 comments on commit d661166

Please sign in to comment.