8000 Add more MCP stuff · coder/coder@363f86d · GitHub
[go: up one dir, main page]

Skip to content

Commit 363f86d

Browse files
kylecarbsjohnstcn
authored andcommitted
Add more MCP stuff
1 parent 0f31c3f commit 363f86d

File tree

14 files changed

+248
-70
lines changed

14 files changed

+248
-70
lines changed

coderd/ai/ai.go

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@ package ai
22

33
import (
44
"context"
5-
"fmt"
65

76
"github.com/anthropics/anthropic-sdk-go"
87
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
9-
"github.com/coder/coder/v2/codersdk"
108
"github.com/kylecarbs/aisdk-go"
119
"github.com/openai/openai-go"
1210
openaioption "github.com/openai/openai-go/option"
11+
"golang.org/x/xerrors"
1312
"google.golang.org/genai"
13+
14+
"github.com/coder/coder/v2/codersdk"
1415
)
1516

1617
type LanguageModel struct {
@@ -19,10 +20,11 @@ type LanguageModel struct {
1920
}
2021

2122
type StreamOptions struct {
22-
Model string
23-
Messages []aisdk.Message
24-
Thinking bool
25-
Tools []aisdk.Tool
23+
SystemPrompt string
24+
Model string
25+
Messages []aisdk.Message
26+
Thinking bool
27+
Tools []aisdk.Tool
2628
}
2729

2830
type StreamFunc func(ctx context.Context, options StreamOptions) (aisdk.DataStream, error)
@@ -45,6 +47,12 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
4547
return nil, err
4648
}
4749
tools := aisdk.ToolsToOpenAI(options.Tools)
50+
if options.SystemPrompt != "" {
51+
openaiMessages = append([]openai.ChatCompletionMessageParamUnion{
52+
openai.SystemMessage(options.SystemPrompt),
53+
}, openaiMessages...)
54+
}
55+
4856
return aisdk.OpenAIToDataStream(client.Chat.Completions.NewStreaming(ctx, openai.ChatCompletionNewParams{
4957
Messages: openaiMessages,
5058
Model: options.Model,
@@ -70,6 +78,11 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
7078
if err != nil {
7179
return nil, err
7280
}
81+
if options.SystemPrompt != "" {
82+
systemMessage = []anthropic.TextBlockParam{
83+
*anthropic.NewTextBlock(options.SystemPrompt).OfRequestTextBlock,
84+
}
85+
}
7386
return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
7487
Messages: anthropicMessages,
7588
Model: options.Model,
@@ -106,8 +119,18 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
106119
if err != nil {
107120
return nil, err
108121
}
122+
var systemInstruction *genai.Content
123+
if options.SystemPrompt != "" {
124+
systemInstruction = &genai.Content{
125+
Parts: []*genai.Part{
126+
genai.NewPartFromText(options.SystemPrompt),
127+
},
128+
Role: "model",
129+
}
130+
}
109131
return aisdk.GoogleToDataStream(client.Models.GenerateContentStream(ctx, options.Model, googleMessages, &genai.GenerateContentConfig{
110-
Tools: tools,
132+
SystemInstruction: systemInstruction,
133+
Tools: tools,
111134
})), nil
112135
}
113136
if config.Models == nil {
@@ -122,7 +145,7 @@ func ModelsFromConfig(ctx context.Context, configs []codersdk.AIProviderConfig)
122145
}
123146
break
124147
default:
125-
return nil, fmt.Errorf("unsupported model type: %s", config.Type)
148+
return nil, xerrors.Errorf("unsupported model type: %s", config.Type)
126149
}
127150

128151
for _, model := range config.Models {

coderd/chat.go

Lines changed: 69 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@ package coderd
22

33
import (
44
"encoding/json"
5+
"io"
56
"net/http"
67
"time"
78

9+
"github.com/google/uuid"
10+
"github.com/kylecarbs/aisdk-go"
11+
812
"github.com/coder/coder/v2/coderd/ai"
913
"github.com/coder/coder/v2/coderd/database"
1014
"github.com/coder/coder/v2/coderd/database/db2sdk"
1115
"github.com/coder/coder/v2/coderd/database/dbtime"
1216
"github.com/coder/coder/v2/coderd/httpapi"
1317
"github.com/coder/coder/v2/coderd/httpmw"
1418
"github.com/coder/coder/v2/codersdk"
15-
codermcp "github.com/coder/coder/v2/mcp"
16-
"github.com/google/uuid"
17-
"github.com/kylecarbs/aisdk-go"
18-
"github.com/mark3labs/mcp-go/mcp"
19-
"github.com/mark3labs/mcp-go/server"
19+
"github.com/coder/coder/v2/codersdk/toolsdk"
2020
)
2121

2222
// postChats creates a new chat.
@@ -142,9 +142,10 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
142142
Message: "Failed to get chat messages",
143143
Detail: err.Error(),
144144
})
145+
return
145146
}
146147

147-
messages := make([]aisdk.Message, len(dbMessages))
148+
messages := make([]aisdk.Message, 0, len(dbMessages))
148149
for i, message := range dbMessages {
149150
err = json.Unmarshal(message.Content, &messages[i])
150151
if err != nil {
@@ -157,31 +158,17 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
157158
}
158159
messages = append(messages, req.Message)
159160

160-
toolMap := codermcp.AllTools()
161-
toolsByName := make(map[string]server.ToolHandlerFunc)
162161
client := codersdk.New(api.AccessURL)
163162
client.SetSessionToken(httpmw.APITokenFromRequest(r))
164-
toolDeps := codermcp.ToolDeps{
165-
Client: client,
166-
Logger: &api.Logger,
167-
}
168-
for _, tool := range toolMap {
169-
toolsByName[tool.Tool.Name] = tool.MakeHandler(toolDeps)
170-
}
171-
convertedTools := make([]aisdk.Tool, len(toolMap))
172-
for i, tool := range toolMap {
173-
schema := aisdk.Schema{
174-
Required: tool.Tool.InputSchema.Required,
175-
Properties: tool.Tool.InputSchema.Properties,
176-
}
177-
if tool.Tool.InputSchema.Required == nil {
178-
schema.Required = []string{}
179-
}
180-
convertedTools[i] = aisdk.Tool{
181-
Name: tool.Tool.Name,
182-
Description: tool.Tool.Description,
183-
Schema: schema,
163+
164+
tools := make([]aisdk.Tool, len(toolsdk.All))
165+
handlers := map[string]toolsdk.GenericHandlerFunc{}
166+
for i, tool := range toolsdk.All {
167+
if tool.Tool.Schema.Required == nil {
168+
tool.Tool.Schema.Required = []string{}
184169
}
170+
tools[i] = tool.Tool
171+
handlers[tool.Tool.Name] = tool.Handler
185172
}
186173

187174
provider, ok := api.LanguageModels[req.Model]
@@ -192,6 +179,44 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
192179
return
193180
}
194181

182+
// If it's the user's first message, generate a title for the chat.
183+
if len(messages) == 1 {
184+
var acc aisdk.DataStreamAccumulator
185+
stream, err := provider.StreamFunc(ctx, ai.StreamOptions{
186+
Model: req.Model,
187+
SystemPrompt: `- You will generate a short title based on the user's message.
188+
- It should be maximum of 40 characters.
189+
- Do not use quotes, colons, special characters, or emojis.`,
190+
Messages: messages,
191+
Tools: tools,
192+
})
193+
if err != nil {
194+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
195+
Message: "Failed to create stream",
196+
Detail: err.Error(),
197+
})
198+
}
199+
stream = stream.WithAccumulator(&acc)
200+
err = stream.Pipe(io.Discard)
201+
if err != nil {
202+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
203+
Message: "Failed to pipe stream",
204+
Detail: err.Error(),
205+
})
206+
}
207+
err = api.Database.UpdateChatByID(ctx, database.UpdateChatByIDParams{
208+
ID: chat.ID,
209+
Title: acc.Messages()[0].Content,
210+
})
211+
if err != nil {
212+
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
213+
Message: "Failed to update chat title",
214+
Detail: err.Error(),
215+
})
216+
return
217+
}
218+
}
219+
195220
// Write headers for the data stream!
196221
aisdk.WriteDataStreamHeaders(w)
197222

@@ -219,12 +244,20 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
219244
return
220245
}
221246

247+
deps := toolsdk.Deps{
248+
CoderClient: client,
249+
}
250+
222251
for {
223252
var acc aisdk.DataStreamAccumulator
224253
stream, err := provider.StreamFunc(ctx, ai.StreamOptions{
225254
Model: req.Model,
226255
Messages: messages,
227-
Tools: convertedTools,
256+
Tools: tools,
257+
SystemPrompt: `You are a chat assistant for Coder. You will attempt to resolve the user's
258+
request to the maximum utilization of your tools.
259+
260+
Try your best to not ask the user for help - solve the task with your tools!`,
228261
})
229262
if err != nil {
230263
httpapi.Write(ctx, w, http.StatusInternalServerError, codersdk.Response{
@@ -234,28 +267,21 @@ func (api *API) postChatMessages(w http.ResponseWriter, r *http.Request) {
234267
return
235268
}
236269
stream = stream.WithToolCalling(func(toolCall aisdk.ToolCall) any {
237-
tool, ok := toolsByName[toolCall.Name]
270+
tool, ok := handlers[toolCall.Name]
238271
if !ok {
239272
return nil
240273
}
241-
result, err := tool(ctx, mcp.CallToolRequest{
242-
Params: struct {
243-
Name string "json:\"name\""
244-
Arguments map[string]interface{} "json:\"arguments,omitempty\""
245-
Meta *struct {
246-
ProgressToken mcp.ProgressToken "json:\"progressToken,omitempty\""
247-
} "json:\"_meta,omitempty\""
248-
}{
249-
Name: toolCall.Name,
250-
Arguments: toolCall.Args,
251-
},
252-
})
274+
toolArgs, err := json.Marshal(toolCall.Args)
275+
if err != nil {
276+
return nil
277+
}
278+
result, err := tool(ctx, deps, toolArgs)
253279
if err != nil {
254280
return map[string]any{
255281
"error": err.Error(),
256282
}
257283
}
258-
return result.Content
284+
return result
259285
}).WithAccumulator(&acc)
260286

261287
err = stream.Pipe(w)

coderd/database/dbauthz/dbauthz.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4000,7 +4000,10 @@ func (q *querier) UpdateAPIKeyByID(ctx context.Context, arg database.UpdateAPIKe
40004000
}
40014001

40024002
func (q *querier) UpdateChatByID(ctx context.Context, arg database.UpdateChatByIDParams) error {
4003-
panic("not implemented")
4003+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat.WithID(arg.ID)); err != nil {
4004+
return err
4005+
}
4006+
return q.db.UpdateChatByID(ctx, arg)
40044007
}
40054008

40064009
func (q *querier) UpdateCryptoKeyDeletesAt(ctx context.Context, arg database.UpdateCryptoKeyDeletesAtParams) (database.CryptoKey, error) {

coderd/database/dbmem/dbmem.go

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8447,13 +8447,7 @@ func (q *FakeQuerier) InsertChat(ctx context.Context, arg database.InsertChatPar
84478447
q.mutex.Lock()
84488448
defer q.mutex.Unlock()
84498449

8450-
chat := database.Chat{
8451-
ID: arg.ID,
8452-
CreatedAt: arg.CreatedAt,
8453-
UpdatedAt: arg.UpdatedAt,
8454-
OwnerID: arg.OwnerID,
8455-
Title: arg.Title,
8456-
}
8450+
chat := database.Chat(arg)
84578451
q.chats = append(q.chats, chat)
84588452

84598453
return chat, nil

coderd/deployment.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"context"
55
"net/http"
66

7+
"github.com/kylecarbs/aisdk-go"
8+
79
"github.com/coder/coder/v2/coderd/httpapi"
810
"github.com/coder/coder/v2/coderd/rbac"
911
"github.com/coder/coder/v2/coderd/rbac/policy"
1012
"github.com/coder/coder/v2/codersdk"
11-
"github.com/kylecarbs/aisdk-go"
1213
)
1314

1415
// @Summary Get deployment config

coderd/httpmw/chat.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@ import (
44
"context"
55
"net/http"
66

7+
"github.com/go-chi/chi/v5"
8+
"github.com/google/uuid"
9+
710
"github.com/coder/coder/v2/coderd/database"
811
"github.com/coder/coder/v2/coderd/httpapi"
912
"github.com/coder/coder/v2/codersdk"
10-
"github.com/go-chi/chi/v5"
11-
"github.com/google/uuid"
1213
)
1314

1415
type chatContextKey struct{}

coderd/httpmw/chat_test.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@ import (
77
"testing"
88
"time"
99

10+
"github.com/go-chi/chi/v5"
11+
"github.com/google/uuid"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
1015
"github.com/coder/coder/v2/coderd/database"
1116
"github.com/coder/coder/v2/coderd/database/dbgen"
1217
"github.com/coder/coder/v2/coderd/database/dbmem"
1318
"github.com/coder/coder/v2/coderd/database/dbtime"
1419
"github.com/coder/coder/v2/coderd/httpmw"
1520
"github.com/coder/coder/v2/codersdk"
16-
"github.com/go-chi/chi/v5"
17-
"github.com/google/uuid"
18-
"github.com/stretchr/testify/assert"
19-
"github.com/stretchr/testify/require"
2021
)
2122

2223
func TestExtractChat(t *testing.T) {

go.mod

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,13 @@ require (
487487
)
488488

489489
require (
490+
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3
490491
github.com/coder/preview v0.0.1
491492
github.com/fsnotify/fsnotify v1.9.0
492493
github.com/kylecarbs/aisdk-go v0.0.5
493494
github.com/mark3labs/mcp-go v0.23.1
495+
github.com/openai/openai-go v0.1.0-beta.6
496+
google.golang.org/genai v0.7.0
494497
)
495498

496499
require (
@@ -502,7 +505,6 @@ require (
502505
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.26.0 // indirect
503506
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.50.0 // indirect
504507
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.50.0 // indirect
505-
github.com/anthropics/anthropic-sdk-go v0.2.0-beta.3 // indirect
506508
github.com/aquasecurity/go-version v0.0.1 // indirect
507509
github.com/aquasecurity/trivy v0.58.2 // indirect
508510
github.com/aws/aws-sdk-go v1.55.6 // indirect
@@ -516,7 +518,6 @@ require (
516518
github.com/hashicorp/go-safetemp v1.0.0 // indirect
517519
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
518520
github.com/moby/sys/user v0.3.0 // indirect
519-
github.com/openai/openai-go v0.1.0-beta.6 // indirect
520521
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
521522
github.com/samber/lo v1.49.1 // indirect
522523
github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect
@@ -527,6 +528,5 @@ require (
527528
go.opentelemetry.io/contrib/detectors/gcp v1.34.0 // indirect
528529
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.60.0 // indirect
529530
go.opentelemetry.io/otel/sdk/metric v1.35.0 // indirect
530-
google.golang.org/genai v0.7.0 // indirect
531531
k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 // indirect
532532
)

site/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
},
3737
"dependencies": {
3838
"@ai-sdk/react": "1.2.6",
39+
"@ai-sdk/ui-utils": "1.2.7",
3940
"@emoji-mart/data": "1.2.1",
4041
"@emoji-mart/react": "1.1.1",
4142
"@emotion/cache": "11.14.0",

0 commit comments

Comments
 (0)
0