[GenAI] Introduce CausalLMPipelineChatClient for MEAI.IChatClient (#7270) · dotnet/machinelearning@3659a48 · GitHub

Commit 3659a48

[GenAI] Introduce CausalLMPipelineChatClient for MEAI.IChatClient (#7270)
* leverage MEAI abstraction
* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
* Update src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs
* Update src/Microsoft.ML.GenAI.Phi/Phi3/Phi3CausalLMChatClient.cs
* fix comments
* Update Microsoft.ML.GenAI.Core.csproj

Co-authored-by: Stephen Toub <stoub@microsoft.com>
1 parent 5b4981a commit 3659a48

16 files changed: +534 −82 lines
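For context, the Microsoft.Extensions.AI (MEAI) `IChatClient` abstraction this commit targets exposes roughly the surface below. This is an abridged sketch inferred from the members the new `CausalLMPipelineChatClient` implements further down in this diff; the authoritative definition lives in the `Microsoft.Extensions.AI.Abstractions` package (version 9.0.0-preview.9.24507.7, pinned in eng/Versions.props below).

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

// Abridged sketch only: ChatClientMetadata, ChatMessage, ChatOptions,
// ChatCompletion, and StreamingChatCompletionUpdate all come from the
// Microsoft.Extensions.AI.Abstractions package.
public interface IChatClient : IDisposable
{
    ChatClientMetadata Metadata { get; }

    // One-shot completion over a chat history.
    Task<ChatCompletion> CompleteAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default);

    // Incremental, token-by-token completion.
    IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default);

    // Optional hook for resolving additional services from the client.
    TService? GetService<TService>(object? key = null) where TService : class;
}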
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using AutoGen.Core;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.Core.Extension;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using TorchSharp;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Llama3_1
{
    public static async Task RunAsync(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        var device = "cuda";
        if (device == "cuda")
        {
            torch.InitializeDeviceType(DeviceType.CUDA);
        }

        var defaultType = ScalarType.BFloat16;
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
        var configName = "config.json";
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var stopWatch = System.Diagnostics.Stopwatch.StartNew();
        stopWatch.Start();
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);

        var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

        var client = new Llama3CausalLMChatClient(pipeline);

        var task = """
            Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
            """;
        var chatMessage = new ChatMessage(ChatRole.User, task);

        await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
        {
            Console.Write(response.Text);
        }
    }
}
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static TorchSharp.torch;
using TorchSharp;
using Microsoft.ML.GenAI.Phi;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;
using Microsoft.Extensions.AI;

namespace Microsoft.ML.GenAI.Samples.MEAI;

internal class Phi3
{
    public static async Task RunAsync(string weightFolder)
    {
        var device = "cuda";
        if (device == "cuda")
        {
            torch.InitializeDeviceType(DeviceType.CUDA);
        }

        var defaultType = ScalarType.Float16;
        torch.manual_seed(1);
        torch.set_default_dtype(defaultType);
        var tokenizerPath = Path.Combine(weightFolder, "tokenizer.model");
        var tokenizer = Phi3TokenizerHelper.FromPretrained(tokenizerPath);
        var model = Phi3ForCasualLM.FromPretrained(weightFolder, "config.json", layersOnTargetDevice: -1, quantizeToInt8: true);
        var pipeline = new CausalLMPipeline<LlamaTokenizer, Phi3ForCasualLM>(tokenizer, model, device);
        var client = new Phi3CausalLMChatClient(pipeline);

        var task = """
            Write a C# program to print the sum of two numbers. Use top-level statement, put code between ```csharp and ```.
            """;
        var chatMessage = new ChatMessage(ChatRole.User, task);

        await foreach (var response in client.CompleteStreamingAsync([chatMessage]))
        {
            Console.Write(response.Text);
        }
    }
}
Lines changed: 3 additions & 1 deletion
@@ -1,4 +1,6 @@
 // See https://aka.ms/new-console-template for more information
 using Microsoft.ML.GenAI.Samples.Llama;
+using Microsoft.ML.GenAI.Samples.MEAI;
 
-await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
+//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
+await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");

eng/Versions.props

Lines changed: 2 additions & 1 deletion
@@ -34,7 +34,7 @@
     <SystemRuntimeCompilerServicesUnsafeVersion>6.0.0</SystemRuntimeCompilerServicesUnsafeVersion>
     <SystemSecurityPrincipalWindows>5.0.0</SystemSecurityPrincipalWindows>
     <SystemTextEncodingsWebVersion>8.0.0</SystemTextEncodingsWebVersion>
-    <SystemTextJsonVersion>8.0.4</SystemTextJsonVersion>
+    <SystemTextJsonVersion>8.0.5</SystemTextJsonVersion>
     <SystemThreadingChannelsVersion>8.0.0</SystemThreadingChannelsVersion>
     <!-- Other product dependencies -->
     <ApacheArrowVersion>14.0.2</ApacheArrowVersion>
@@ -47,6 +47,7 @@
     <MicrosoftDotNetInteractiveVersion>1.0.0-beta.24375.2</MicrosoftDotNetInteractiveVersion>
     <MicrosoftMLOnnxRuntimeVersion>1.18.1</MicrosoftMLOnnxRuntimeVersion>
     <MlNetMklDepsVersion>0.0.0.12</MlNetMklDepsVersion>
+    <MicrosoftExtensionsAIVersion>9.0.0-preview.9.24507.7</MicrosoftExtensionsAIVersion>
     <!--
       @("inteltbb.devel", "win", "2021.7.1.15305")
     -->
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core;

public abstract class CausalLMPipelineChatClient<TTokenizer, TCausalLMModel> : IChatClient
    where TTokenizer : Tokenizer
    where TCausalLMModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
    private readonly ICausalLMPipeline<TTokenizer, TCausalLMModel> _pipeline;
    private readonly IMEAIChatTemplateBuilder _chatTemplateBuilder;

    public CausalLMPipelineChatClient(
        ICausalLMPipeline<TTokenizer, TCausalLMModel> pipeline,
        IMEAIChatTemplateBuilder chatTemplateBuilder,
        ChatClientMetadata? metadata = null)
    {
        var classNameWithType = $"{nameof(CausalLMPipelineChatClient<TTokenizer, TCausalLMModel>)}<{typeof(TTokenizer).Name}, {typeof(TCausalLMModel).Name}>";
        Metadata ??= new ChatClientMetadata(providerName: classNameWithType, modelId: typeof(TCausalLMModel).Name);
        _chatTemplateBuilder = chatTemplateBuilder;
        _pipeline = pipeline;
    }

    public ChatClientMetadata Metadata { get; }

    public virtual Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
    {
        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        var output = _pipeline.Generate(
            prompt,
            maxLen: options?.MaxOutputTokens ?? 1024,
            temperature: options?.Temperature ?? 0.7f,
            stopSequences: stopSequences.ToArray()) ?? throw new InvalidOperationException("Failed to generate a reply.");

        var chatMessage = new ChatMessage(ChatRole.Assistant, output);
        return Task.FromResult(new ChatCompletion([chatMessage])
        {
            CreatedAt = DateTime.UtcNow,
            FinishReason = ChatFinishReason.Stop,
        });
    }

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    public virtual async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        var prompt = _chatTemplateBuilder.BuildPrompt(chatMessages, options);
        var stopSequences = options?.StopSequences ?? Array.Empty<string>();

        foreach (var output in _pipeline.GenerateStreaming(
            prompt,
            maxLen: options?.MaxOutputTokens ?? 1024,
            temperature: options?.Temperature ?? 0.7f,
            stopSequences: stopSequences.ToArray()))
        {
            yield return new StreamingChatCompletionUpdate
            {
                Role = ChatRole.Assistant,
                Text = output,
                CreatedAt = DateTime.UtcNow,
            };
        }
    }

    public virtual void Dispose()
    {
    }

    public virtual TService? GetService<TService>(object? key = null) where TService : class
    {
        return null;
    }
}
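Because the class above implements `IChatClient`, a pipeline-backed model can be driven purely through the MEAI abstraction. A minimal non-streaming sketch, assuming a `pipeline` constructed as in the Llama3_1 sample earlier in this diff, and assuming `ChatCompletion.Message` exposes the single assistant reply as in this MEAI preview:

using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.LLaMA;

// `pipeline` is assumed to be the ICausalLMPipeline<Tokenizer, LlamaForCausalLM>
// built in the Llama3_1 sample above (tokenizer + model + device).
IChatClient client = new Llama3CausalLMChatClient(pipeline);

// CompleteAsync maps ChatOptions onto the pipeline's maxLen, temperature,
// and stop sequences, as implemented in the base class above.
var completion = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "Hello!")],
    new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });

Console.WriteLine(completion.Message.Text);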

src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@
 
   <ItemGroup>
     <PackageReference Include="AutoGen.Core" Version="$(AutoGenVersion)" />
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="$(MicrosoftExtensionsAIVersion)" />
     <PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="$(SemanticKernelVersion)" />
     <PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
     <PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />

src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 using System.Text;
 using System.Threading.Tasks;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.SemanticKernel.ChatCompletion;
 
 namespace Microsoft.ML.GenAI.Core;
@@ -22,6 +23,11 @@ public interface IAutoGenChatTemplateBuilder
     string BuildPrompt(IEnumerable<IMessage> messages, IEnumerable<FunctionContract>? tools = null);
 }
 
+public interface IMEAIChatTemplateBuilder
+{
+    string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
+}
+
 public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
 {
 }
src/Microsoft.ML.GenAI.LLaMA/Llama3CausalLMChatClient.cs

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.Tokenizers;

namespace Microsoft.ML.GenAI.LLaMA;

public class Llama3CausalLMChatClient : CausalLMPipelineChatClient<Tokenizer, LlamaForCausalLM>
{
    private readonly string _eotToken = "<|eot_id|>";

    public Llama3CausalLMChatClient(
        ICausalLMPipeline<Tokenizer, LlamaForCausalLM> pipeline,
        IMEAIChatTemplateBuilder? chatTemplateBuilder = null,
        ChatClientMetadata? metadata = null)
        : base(
            pipeline,
            chatTemplateBuilder ?? Llama3_1ChatTemplateBuilder.Instance,
            metadata ?? new ChatClientMetadata(modelId: nameof(Llama3CausalLMChatClient)))
    {
    }

    public override Task<ChatCompletion> CompleteAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= new ChatOptions();

        if (options.StopSequences != null)
        {
            options.StopSequences.Add(_eotToken);
        }
        else
        {
            options.StopSequences = new List<string> { _eotToken };
        }

        return base.CompleteAsync(chatMessages, options, cancellationToken);
    }

    public override IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages,
        ChatOptions? options = null,
        CancellationToken cancellationToken = default)
    {
        options ??= new ChatOptions();
        options.StopSequences ??= [];
        options.StopSequences.Add(_eotToken);

        return base.CompleteStreamingAsync(chatMessages, options, cancellationToken);
    }
}

src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs

Lines changed: 37 additions & 1 deletion
@@ -4,13 +4,15 @@
 
 using System.Text;
 using AutoGen.Core;
+using Microsoft.Extensions.AI;
 using Microsoft.ML.GenAI.Core;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
+using TextContent = Microsoft.SemanticKernel.TextContent;
 
 namespace Microsoft.ML.GenAI.LLaMA;
 #pragma warning disable MSML_GeneralName // This name should be PascalCased
-public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder
+public class Llama3_1ChatTemplateBuilder : IChatTemplateBuilder, IMEAIChatTemplateBuilder
 #pragma warning restore MSML_GeneralName // This name should be PascalCased
 {
     private const char Newline = '\n';
@@ -86,5 +88,39 @@ public string BuildPrompt(ChatHistory chatHistory)
         return sb.ToString();
     }
 
+    public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
+    {
+        var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
+        if (messages.Any(m => m.Text is null))
+        {
+            throw new InvalidOperationException("Please provide a message with content.");
+        }
+
+        if (messages.Any(m => availableRoles.Any(availableRole => availableRole == m.Role) == false))
+        {
+            throw new InvalidOperationException("Please provide a message with a valid role. The valid roles are System, User, and Assistant.");
+        }
+
+        var sb = new StringBuilder();
+        sb.Append("<|begin_of_text|>");
+        foreach (var message in messages)
+        {
+            var role = message.Role.Value;
+            var content = message.Text!;
+            sb.Append(message switch
+            {
+                _ when message.Role == ChatRole.System => $"<|start_header_id|>system<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ when message.Role == ChatRole.User => $"<|start_header_id|>user<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ when message.Role == ChatRole.Assistant => $"<|start_header_id|>assistant<|end_header_id|>{Newline}{content.Trim()}<|eot_id|>{Newline}",
+                _ => throw new InvalidOperationException("Invalid role.")
+            });
+        }
+
+        sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
+        var input = sb.ToString();
+
+        return input;
+    }
+
     public static Llama3_1ChatTemplateBuilder Instance { get; } = new Llama3_1ChatTemplateBuilder();
 }
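For illustration, the `IMEAIChatTemplateBuilder.BuildPrompt` overload added above renders a single user message into the Llama 3.1 chat format roughly as follows (traced by hand from the template logic, not captured output):

<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Hello!<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

The trailing assistant header cues the model to generate the reply, and <|eot_id|> is the same stop token that Llama3CausalLMChatClient appends to StopSequences so generation halts at the end of the turn.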

0 commit comments