diff --git a/Cnblogs.DashScope.Sdk.sln.DotSettings b/Cnblogs.DashScope.Sdk.sln.DotSettings index 3ad7f9e..9866f59 100644 --- a/Cnblogs.DashScope.Sdk.sln.DotSettings +++ b/Cnblogs.DashScope.Sdk.sln.DotSettings @@ -1,4 +1,8 @@  + True True True - True \ No newline at end of file + True + True + True + True \ No newline at end of file diff --git a/sample/Cnblogs.DashScope.Sample/Cnblogs.DashScope.Sample.csproj b/sample/Cnblogs.DashScope.Sample/Cnblogs.DashScope.Sample.csproj index ba80da7..76001b4 100644 --- a/sample/Cnblogs.DashScope.Sample/Cnblogs.DashScope.Sample.csproj +++ b/sample/Cnblogs.DashScope.Sample/Cnblogs.DashScope.Sample.csproj @@ -1,26 +1,26 @@  - - Exe - net8.0 - enable - enable - false - + + Exe + net8.0 + enable + enable + false + - - - - + + + + - - - Always - - + + + Always + + - - - + + + diff --git a/sample/Cnblogs.DashScope.Sample/Program.cs b/sample/Cnblogs.DashScope.Sample/Program.cs index 87bfe69..f9f74d7 100644 --- a/sample/Cnblogs.DashScope.Sample/Program.cs +++ b/sample/Cnblogs.DashScope.Sample/Program.cs @@ -9,7 +9,8 @@ using Microsoft.Extensions.AI; Console.WriteLine("Reading key from environment variable DASHSCOPE_KEY"); -var apiKey = Environment.GetEnvironmentVariable("DASHSCOPE_API_KEY"); +var apiKey = Environment.GetEnvironmentVariable("DASHSCOPE_KEY", EnvironmentVariableTarget.Process) + ?? Environment.GetEnvironmentVariable("DASHSCOPE_KEY", EnvironmentVariableTarget.User); if (string.IsNullOrEmpty(apiKey)) { Console.Write("ApiKey > "); @@ -63,6 +64,25 @@ userInput = Console.ReadLine()!; await ApplicationCallAsync(applicationId, userInput); break; + case SampleType.TextToSpeech: + { + using var tts = await dashScopeClient.CreateSpeechSynthesizerSocketSessionAsync("cosyvoice-v2"); + var taskId = await tts.RunTaskAsync( + new SpeechSynthesizerParameters() { Voice = "longxiaochun_v2", Format = "mp3" }); + await tts.ContinueTaskAsync(taskId, "博客园"); + await tts.ContinueTaskAsync(taskId, "代码改变世界"); + await tts.FinishTaskAsync(taskId); + var file = new FileInfo("tts.mp3"); + var writer = file.OpenWrite(); + await foreach (var b in tts.GetAudioAsync()) + { + writer.WriteByte(b); + } + + writer.Close(); + Console.WriteLine($"audio saved to {file.FullName}"); + break; + } } return; diff --git a/sample/Cnblogs.DashScope.Sample/SampleType.cs b/sample/Cnblogs.DashScope.Sample/SampleType.cs index feddf79..138ed9a 100644 --- a/sample/Cnblogs.DashScope.Sample/SampleType.cs +++ b/sample/Cnblogs.DashScope.Sample/SampleType.cs @@ -16,5 +16,7 @@ public enum SampleType MicrosoftExtensionsAiToolCall, - ApplicationCall + ApplicationCall, + + TextToSpeech, } diff --git a/sample/Cnblogs.DashScope.Sample/SampleTypeDescriptor.cs b/sample/Cnblogs.DashScope.Sample/SampleTypeDescriptor.cs index 26988a5..a6d1b94 100644 --- a/sample/Cnblogs.DashScope.Sample/SampleTypeDescriptor.cs +++ b/sample/Cnblogs.DashScope.Sample/SampleTypeDescriptor.cs @@ -14,6 +14,7 @@ public static string GetDescription(this SampleType sampleType) SampleType.MicrosoftExtensionsAi => "Use with Microsoft.Extensions.AI", SampleType.MicrosoftExtensionsAiToolCall => "Use tool call with Microsoft.Extensions.AI interfaces", SampleType.ApplicationCall => "Call pre-defined application", + SampleType.TextToSpeech => "TTS task", _ => throw new ArgumentOutOfRangeException(nameof(sampleType), sampleType, "Unsupported sample option") }; } diff --git a/src/Cnblogs.DashScope.AI/Cnblogs.DashScope.AI.csproj b/src/Cnblogs.DashScope.AI/Cnblogs.DashScope.AI.csproj index 8de1ebd..c34f277 100644 --- a/src/Cnblogs.DashScope.AI/Cnblogs.DashScope.AI.csproj +++ b/src/Cnblogs.DashScope.AI/Cnblogs.DashScope.AI.csproj @@ -11,7 +11,7 @@ - + diff --git a/src/Cnblogs.DashScope.AspNetCore/Assembly.cs b/src/Cnblogs.DashScope.AspNetCore/Assembly.cs new file mode 100644 index 0000000..b79c4ed --- /dev/null +++ b/src/Cnblogs.DashScope.AspNetCore/Assembly.cs @@ -0,0 +1,4 @@ +using System.Runtime.CompilerServices; + +[assembly:InternalsVisibleTo("Cnblogs.DashScope.Sdk.UnitTests")] +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2")] diff --git a/src/Cnblogs.DashScope.AspNetCore/Cnblogs.DashScope.AspNetCore.csproj b/src/Cnblogs.DashScope.AspNetCore/Cnblogs.DashScope.AspNetCore.csproj index 1ab06a5..ba30d5f 100644 --- a/src/Cnblogs.DashScope.AspNetCore/Cnblogs.DashScope.AspNetCore.csproj +++ b/src/Cnblogs.DashScope.AspNetCore/Cnblogs.DashScope.AspNetCore.csproj @@ -3,7 +3,8 @@ Cnblogs.DashScopeSDK true - Cnblogs;Dashscope;AI;Sdk;Embedding;AspNetCore + Cnblogs;Dashscope;AI;Sdk;Embedding;AspNetCore;Bailian + Cnblogs.DashScope.AspNetCore diff --git a/src/Cnblogs.DashScope.AspNetCore/DashScopeAspNetCoreDefaults.cs b/src/Cnblogs.DashScope.AspNetCore/DashScopeAspNetCoreDefaults.cs new file mode 100644 index 0000000..1cdbb63 --- /dev/null +++ b/src/Cnblogs.DashScope.AspNetCore/DashScopeAspNetCoreDefaults.cs @@ -0,0 +1,6 @@ +namespace Cnblogs.DashScope.AspNetCore; + +internal static class DashScopeAspNetCoreDefaults +{ + public const string DefaultHttpClientName = "Cnblogs.DashScope.Http"; +} diff --git a/src/Cnblogs.DashScope.AspNetCore/DashScopeClientAspNetCore.cs b/src/Cnblogs.DashScope.AspNetCore/DashScopeClientAspNetCore.cs new file mode 100644 index 0000000..be4a685 --- /dev/null +++ b/src/Cnblogs.DashScope.AspNetCore/DashScopeClientAspNetCore.cs @@ -0,0 +1,20 @@ +using Cnblogs.DashScope.Core; + +namespace Cnblogs.DashScope.AspNetCore; + +/// +/// The with DI and options pattern support. +/// +public class DashScopeClientAspNetCore + : DashScopeClientCore +{ + /// + /// The with DI and options pattern support. + /// + /// The factory to create . + /// The socket pool for WebSocket API calls. + public DashScopeClientAspNetCore(IHttpClientFactory factory, DashScopeClientWebSocketPool pool) + : base(factory.CreateClient(DashScopeAspNetCoreDefaults.DefaultHttpClientName), pool) + { + } +} diff --git a/src/Cnblogs.DashScope.AspNetCore/ServiceCollectionInjector.cs b/src/Cnblogs.DashScope.AspNetCore/ServiceCollectionInjector.cs index ee6b00a..39a388a 100644 --- a/src/Cnblogs.DashScope.AspNetCore/ServiceCollectionInjector.cs +++ b/src/Cnblogs.DashScope.AspNetCore/ServiceCollectionInjector.cs @@ -1,6 +1,9 @@ using System.Net.Http.Headers; +using Cnblogs.DashScope.AspNetCore; using Cnblogs.DashScope.Core; +using Cnblogs.DashScope.Core.Internals; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Options; // ReSharper disable once CheckNamespace namespace Microsoft.Extensions.DependencyInjection; @@ -37,9 +40,10 @@ public static IHttpClientBuilder AddDashScopeClient(this IServiceCollection serv { var apiKey = section["apiKey"] ?? throw new InvalidOperationException("There is no apiKey provided in given section"); - var baseAddress = section["baseAddress"]; + var baseAddress = section["baseAddress"] ?? DashScopeDefaults.HttpApiBaseAddress; var workspaceId = section["workspaceId"]; - return services.AddDashScopeClient(apiKey, baseAddress, workspaceId); + services.Configure(section); + return services.AddDashScopeHttpClient(apiKey, baseAddress, workspaceId); } /// @@ -48,16 +52,46 @@ public static IHttpClientBuilder AddDashScopeClient(this IServiceCollection serv /// The service collection to add service to. /// The DashScope api key. /// The DashScope api base address, you may change this value if you are using proxy. + /// The DashScope websocket base address, you may want to change this value if use are using proxy. /// Default workspace id to use. /// public static IHttpClientBuilder AddDashScopeClient( this IServiceCollection services, string apiKey, string? baseAddress = null, + string? baseWebsocketAddress = null, string? workspaceId = null) { - baseAddress ??= "https://dashscope.aliyuncs.com/api/v1/"; - return services.AddHttpClient( + services.Configure(o => + { + o.ApiKey = apiKey; + if (baseAddress != null) + { + o.BaseAddress = baseAddress; + } + + if (baseWebsocketAddress != null) + { + o.BaseWebsocketAddress = baseWebsocketAddress; + } + + o.WorkspaceId = workspaceId; + }); + + return services.AddDashScopeHttpClient(apiKey, baseAddress, workspaceId); + } + + private static IHttpClientBuilder AddDashScopeHttpClient( + this IServiceCollection services, + string apiKey, + string? baseAddress, + string? workspaceId) + { + services.AddSingleton(sp + => new DashScopeClientWebSocketPool(sp.GetRequiredService>().Value)); + services.AddScoped(); + return services.AddHttpClient( + DashScopeAspNetCoreDefaults.DefaultHttpClientName, h => { h.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", apiKey); @@ -66,7 +100,7 @@ public static IHttpClientBuilder AddDashScopeClient( h.DefaultRequestHeaders.Add("X-DashScope-WorkSpace", workspaceId); } - h.BaseAddress = new Uri(baseAddress); + h.BaseAddress = new Uri(baseAddress ?? DashScopeDefaults.HttpApiBaseAddress); }); } } diff --git a/src/Cnblogs.DashScope.Core/Cnblogs.DashScope.Core.csproj b/src/Cnblogs.DashScope.Core/Cnblogs.DashScope.Core.csproj index b164ed0..f337ae6 100644 --- a/src/Cnblogs.DashScope.Core/Cnblogs.DashScope.Core.csproj +++ b/src/Cnblogs.DashScope.Core/Cnblogs.DashScope.Core.csproj @@ -3,7 +3,7 @@ Cnblogs.DashScopeSDK true - Cnblogs;Dashscope;AI;Sdk;Embedding; + Cnblogs;Dashscope;AI;Sdk;Embedding;Bailian; Provide pure api access to DashScope without extra references. Cnblogs.DashScope.Sdk should be used for general purpose. @@ -15,5 +15,5 @@ - + diff --git a/src/Cnblogs.DashScope.Core/DashScopeClient.cs b/src/Cnblogs.DashScope.Core/DashScopeClient.cs index dea20f9..11b20e2 100644 --- a/src/Cnblogs.DashScope.Core/DashScopeClient.cs +++ b/src/Cnblogs.DashScope.Core/DashScopeClient.cs @@ -9,14 +9,17 @@ namespace Cnblogs.DashScope.Core; public class DashScopeClient : DashScopeClientCore { private static readonly Dictionary ClientPools = new(); + private static readonly Dictionary SocketPools = new(); /// /// Creates a DashScopeClient for further api call. /// /// The DashScope api key. /// The timeout for internal http client, defaults to 2 minute. - /// The base address for dashscope api call. + /// The base address for DashScope api call. + /// The base address for DashScope websocket api call. /// The workspace id. + /// Maximum size of socket pool. /// /// The underlying httpclient is cached by constructor parameter list. /// Client created with same parameter value will share same underlying instance. @@ -24,10 +27,41 @@ public class DashScopeClient : DashScopeClientCore public DashScopeClient( string apiKey, TimeSpan? timeout = null, - string? baseAddress = null, + string baseAddress = DashScopeDefaults.HttpApiBaseAddress, + string baseWebsocketAddress = DashScopeDefaults.WebsocketApiBaseAddress, + string? workspaceId = null, + int socketPoolSize = 32) + : base( + GetConfiguredClient(apiKey, timeout, baseAddress, workspaceId), + GetConfiguredSocketPool(apiKey, baseWebsocketAddress, socketPoolSize, workspaceId)) + { + } + + private static DashScopeClientWebSocketPool GetConfiguredSocketPool( + string apiKey, + string baseAddress, + int socketPoolSize, string? workspaceId = null) - : base(GetConfiguredClient(apiKey, timeout, baseAddress, workspaceId)) { + var key = GetCacheKey(); + + var pool = SocketPools.GetValueOrDefault(key); + if (pool is null) + { + pool = new DashScopeClientWebSocketPool( + new DashScopeOptions() + { + ApiKey = apiKey, + BaseWebsocketAddress = baseAddress, + SocketPoolSize = socketPoolSize, + WorkspaceId = workspaceId + }); + SocketPools.Add(key, pool); + } + + return pool; + + string GetCacheKey() => $"{apiKey}-{socketPoolSize}-{baseAddress}-{workspaceId}"; } private static HttpClient GetConfiguredClient( @@ -41,7 +75,7 @@ private static HttpClient GetConfiguredClient( { client = new HttpClient { - BaseAddress = new Uri(baseAddress ?? DashScopeDefaults.DashScopeApiBaseAddress), + BaseAddress = new Uri(baseAddress ?? DashScopeDefaults.HttpApiBaseAddress), Timeout = timeout ?? TimeSpan.FromMinutes(2) }; diff --git a/src/Cnblogs.DashScope.Core/DashScopeClientCore.cs b/src/Cnblogs.DashScope.Core/DashScopeClientCore.cs index e0f3159..a2ce231 100644 --- a/src/Cnblogs.DashScope.Core/DashScopeClientCore.cs +++ b/src/Cnblogs.DashScope.Core/DashScopeClientCore.cs @@ -4,7 +4,6 @@ using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; -using System.Text.Json.Serialization; using Cnblogs.DashScope.Core.Internals; namespace Cnblogs.DashScope.Core; @@ -14,22 +13,18 @@ namespace Cnblogs.DashScope.Core; /// public class DashScopeClientCore : IDashScopeClient { - private static readonly JsonSerializerOptions SerializationOptions = - new() - { - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, - PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower - }; - private readonly HttpClient _httpClient; + private readonly DashScopeClientWebSocketPool _socketPool; /// /// For DI container to inject pre-configured httpclient. /// /// Pre-configured httpclient. - public DashScopeClientCore(HttpClient httpClient) + /// Websocket pool. + public DashScopeClientCore(HttpClient httpClient, DashScopeClientWebSocketPool pool) { _httpClient = httpClient; + _socketPool = pool; } /// @@ -283,6 +278,15 @@ public async Task DeleteFileAsync( return (await SendCompatibleAsync(request, cancellationToken))!; } + /// + public async Task CreateSpeechSynthesizerSocketSessionAsync( + string modelId, + CancellationToken cancellationToken = default) + { + var socket = await _socketPool.RentSocketAsync(cancellationToken); + return new SpeechSynthesizerSocketSession(socket, modelId); + } + private static HttpRequestMessage BuildSseRequest(HttpMethod method, string url, TPayload payload) where TPayload : class { @@ -304,7 +308,9 @@ private static HttpRequestMessage BuildRequest( { var message = new HttpRequestMessage(method, url) { - Content = payload != null ? JsonContent.Create(payload, options: SerializationOptions) : null + Content = payload != null + ? JsonContent.Create(payload, options: DashScopeDefaults.SerializationOptions) + : null }; if (sse) @@ -340,7 +346,9 @@ private static HttpRequestMessage BuildRequest( }, HttpCompletionOption.ResponseContentRead, cancellationToken); - return await response.Content.ReadFromJsonAsync(SerializationOptions, cancellationToken); + return await response.Content.ReadFromJsonAsync( + DashScopeDefaults.SerializationOptions, + cancellationToken); } private async Task SendAsync(HttpRequestMessage message, CancellationToken cancellationToken) @@ -350,7 +358,9 @@ private static HttpRequestMessage BuildRequest( message, HttpCompletionOption.ResponseContentRead, cancellationToken); - return await response.Content.ReadFromJsonAsync(SerializationOptions, cancellationToken); + return await response.Content.ReadFromJsonAsync( + DashScopeDefaults.SerializationOptions, + cancellationToken); } private async IAsyncEnumerable StreamAsync( @@ -373,7 +383,8 @@ private async IAsyncEnumerable StreamAsync( var data = line["data:".Length..]; if (data.StartsWith("{\"code\":")) { - var error = JsonSerializer.Deserialize(data, SerializationOptions)!; + var error = + JsonSerializer.Deserialize(data, DashScopeDefaults.SerializationOptions)!; throw new DashScopeException( message.RequestUri?.ToString(), (int)response.StatusCode, @@ -381,7 +392,7 @@ private async IAsyncEnumerable StreamAsync( error.Message); } - yield return JsonSerializer.Deserialize(data, SerializationOptions)!; + yield return JsonSerializer.Deserialize(data, DashScopeDefaults.SerializationOptions)!; } } } @@ -418,7 +429,9 @@ private async Task GetSuccessResponseAsync( DashScopeError? error = null; try { - var r = await response.Content.ReadFromJsonAsync(SerializationOptions, cancellationToken); + var r = await response.Content.ReadFromJsonAsync( + DashScopeDefaults.SerializationOptions, + cancellationToken); error = r == null ? null : errorMapper.Invoke(r); } catch (Exception) diff --git a/src/Cnblogs.DashScope.Core/DashScopeClientWebSocket.cs b/src/Cnblogs.DashScope.Core/DashScopeClientWebSocket.cs new file mode 100644 index 0000000..e24a375 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeClientWebSocket.cs @@ -0,0 +1,246 @@ +using System.Net.WebSockets; +using System.Text; +using System.Text.Json; +using System.Threading.Channels; +using Cnblogs.DashScope.Core.Internals; + +namespace Cnblogs.DashScope.Core; + +/// +/// A websocket client for DashScope websocket API. +/// +public sealed class DashScopeClientWebSocket : IDisposable +{ + private static readonly UnboundedChannelOptions UnboundedChannelOptions = + new() + { + SingleWriter = true, + SingleReader = true, + AllowSynchronousContinuations = true + }; + + private readonly IClientWebSocket _socket; + private Task? _receiveTask; + private TaskCompletionSource _taskStartedSignal = new(); + private Channel? _binaryOutput; + + /// + /// The binary output. + /// + public ChannelReader BinaryOutput + => _binaryOutput?.Reader + ?? throw new InvalidOperationException("Please call ResetOutput() before accessing output"); + + /// + /// A task that completed when received task-started event. + /// + public Task TaskStarted => _taskStartedSignal.Task; + + /// + /// Current state for this websocket. + /// + public DashScopeWebSocketState State { get; private set; } + + /// + /// Initialize a configured web socket client. + /// + /// The api key to use. + /// Optional workspace id. + public DashScopeClientWebSocket(string apiKey, string? workspaceId = null) + { + _socket = new ClientWebSocketWrapper(new ClientWebSocket()); + _socket.Options.SetRequestHeader("X-DashScope-DataInspection", "enable"); + _socket.Options.SetRequestHeader("Authorization", $"bearer {apiKey}"); + if (string.IsNullOrEmpty(workspaceId) == false) + { + _socket.Options.SetRequestHeader("X-DashScope-WorkspaceId", workspaceId); + } + } + + /// + /// Initiate a with a pre-configured . + /// + /// Pre-configured . + internal DashScopeClientWebSocket(IClientWebSocket socket) + { + _socket = socket; + } + + /// + /// Start a websocket connection. + /// + /// Websocket API uri. + /// The cancellation token to use. + /// + /// The type of the response content. + /// When was request. + public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken = default) + where TOutput : class + { + await _socket.ConnectAsync(uri, cancellationToken); + _receiveTask = ReceiveMessagesAsync(cancellationToken); + State = DashScopeWebSocketState.Ready; + } + + /// + /// Reset binary output. + /// + public void ResetOutput() + { + _binaryOutput?.Writer.TryComplete(); + _binaryOutput = Channel.CreateUnbounded(UnboundedChannelOptions); + _taskStartedSignal.TrySetResult(false); + _taskStartedSignal = new TaskCompletionSource(); + } + + /// + /// Send message to server. + /// + /// Request to send. + /// A cancellation token used to propagate notification that this operation should be canceled. + /// + /// Type of the input. + /// Type of the parameter. + /// The is requested. + /// Websocket is not connected or already closed. + /// The underlying websocket has already been closed. + public Task SendMessageAsync( + DashScopeWebSocketRequest request, + CancellationToken cancellationToken = default) + where TInput : class, new() + where TParameter : class + { + if (State == DashScopeWebSocketState.Closed) + { + throw new InvalidOperationException("Socket is already closed."); + } + + var json = JsonSerializer.Serialize(request, DashScopeDefaults.SerializationOptions); + return _socket.SendAsync( + new ArraySegment(Encoding.UTF8.GetBytes(json)), + WebSocketMessageType.Text, + true, + cancellationToken); + } + + private async Task?> ReceiveMessageAsync( + CancellationToken cancellationToken = default) + where TOutput : class + { + var buffer = new byte[1024 * 4]; + var segment = new ArraySegment(buffer); + + try + { + var result = await _socket.ReceiveAsync(segment, cancellationToken); + if (result.MessageType == WebSocketMessageType.Close) + { + await CloseAsync(cancellationToken); + return null; + } + + if (result.MessageType == WebSocketMessageType.Binary) + { + for (var i = 0; i < result.Count; i++) + { + await _binaryOutput!.Writer.WriteAsync(buffer[i], cancellationToken); + } + + return null; + } + + var message = Encoding.UTF8.GetString(buffer, 0, result.Count); + var jsonResponse = + JsonSerializer.Deserialize>( + message, + DashScopeDefaults.SerializationOptions); + return jsonResponse; + } + catch + { + // close socket when exception happens. + await CloseAsync(cancellationToken); + } + + return null; + } + + /// + /// Wait for server response. + /// + /// A cancellation token used to propagate notification that this operation should be canceled. + /// Type of the response content. + /// The task was failed. + public async Task ReceiveMessagesAsync(CancellationToken cancellationToken = default) + where TOutput : class + { + while (State != DashScopeWebSocketState.Closed && _socket.CloseStatus == null) + { + var json = await ReceiveMessageAsync(cancellationToken); + if (json == null) + { + continue; + } + + var eventStr = json.Header.Event; + switch (eventStr) + { + case "task-started": + State = DashScopeWebSocketState.RunningTask; + _taskStartedSignal.TrySetResult(true); + break; + case "task-finished": + State = DashScopeWebSocketState.Ready; + _binaryOutput?.Writer.Complete(); + break; + case "task-failed": + await CloseAsync(cancellationToken); + throw new DashScopeException( + null, + 400, + new DashScopeError() + { + Code = json.Header.ErrorCode ?? string.Empty, + Message = json.Header.ErrorMessage ?? string.Empty, + RequestId = json.Header.Attributes.RequestUuid ?? string.Empty + }, + json.Header.ErrorMessage ?? "The task was failed"); + default: + break; + } + } + + await CloseAsync(cancellationToken); + } + + /// + /// Close the underlying websocket. + /// + /// A cancellation token used to propagate notification that this operation should be canceled. + /// + public async Task CloseAsync(CancellationToken cancellationToken = default) + { + await _socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", cancellationToken); + State = DashScopeWebSocketState.Closed; + if (_receiveTask != null) + { + await _receiveTask; + } + } + + private void Dispose(bool disposing) + { + if (disposing) + { + // Dispose managed resources. + _socket.Dispose(); + _binaryOutput?.Writer.TryComplete(); + } + } + + /// + public void Dispose() + { + Dispose(true); + } +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeClientWebSocketPool.cs b/src/Cnblogs.DashScope.Core/DashScopeClientWebSocketPool.cs new file mode 100644 index 0000000..cfa4ed2 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeClientWebSocketPool.cs @@ -0,0 +1,123 @@ +using System.Collections.Concurrent; + +namespace Cnblogs.DashScope.Core; + +/// +/// Socket pool for DashScope API. +/// +public sealed class DashScopeClientWebSocketPool : IDisposable +{ + private readonly ConcurrentBag _available = new(); + private readonly ConcurrentBag _active = new(); + private readonly DashScopeOptions _options; + + /// + /// Socket pool for DashScope API. + /// + /// Options for DashScope sdk. + public DashScopeClientWebSocketPool(DashScopeOptions options) + { + _options = options; + } + + internal DashScopeClientWebSocketPool(IEnumerable sockets) + { + _options = new DashScopeOptions(); + foreach (var socket in sockets) + { + _available.Add(socket); + } + } + + internal void ReturnSocketAsync(DashScopeClientWebSocket socket) + { + if (socket.State != DashScopeWebSocketState.Ready) + { + // not returnable, disposing. + socket.Dispose(); + return; + } + + _available.Add(socket); + } + + /// + /// Rent or create a socket connection from pool. + /// + /// + /// The output type. + /// + public async Task RentSocketAsync( + CancellationToken cancellationToken = default) + where TOutput : class + { + var found = false; + DashScopeClientWebSocket? socket = null; + while (found == false) + { + if (_available.IsEmpty == false) + { + found = _available.TryTake(out socket); + if (socket?.State != DashScopeWebSocketState.Ready) + { + // expired + found = false; + socket?.Dispose(); + } + } + else + { + socket = await InitializeNewSocketAsync(_options.BaseWebsocketAddress, cancellationToken); + found = true; + } + } + + return ActivateSocket(socket!); + } + + private DashScopeClientWebSocketWrapper ActivateSocket(DashScopeClientWebSocket socket) + { + _active.Add(socket); + return new DashScopeClientWebSocketWrapper(socket, this); + } + + private async Task InitializeNewSocketAsync( + string url, + CancellationToken cancellationToken = default) + where TOutput : class + { + if (_available.Count + _active.Count >= _options.SocketPoolSize) + { + throw new InvalidOperationException("[DashScopeSDK] Socket pool is full"); + } + + var socket = new DashScopeClientWebSocket(_options.ApiKey, _options.WorkspaceId); + await socket.ConnectAsync(new Uri(url), cancellationToken); + return socket; + } + + private void Dispose(bool disposing) + { + if (disposing) + { + // Dispose managed resources. + while (_available.IsEmpty == false) + { + _available.TryTake(out var socket); + socket?.Dispose(); + } + + while (_active.IsEmpty == false) + { + _active.TryTake(out var socket); + socket?.Dispose(); + } + } + } + + /// + public void Dispose() + { + Dispose(true); + } +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeClientWebSocketWrapper.cs b/src/Cnblogs.DashScope.Core/DashScopeClientWebSocketWrapper.cs new file mode 100644 index 0000000..23b78fc --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeClientWebSocketWrapper.cs @@ -0,0 +1,58 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Represents a transient wrapper for rented websocket, should be transient. +/// +/// The rented websocket +/// The pool to return the socket to. +public sealed record DashScopeClientWebSocketWrapper(DashScopeClientWebSocket Socket, DashScopeClientWebSocketPool Pool) + : IDisposable +{ + /// + /// The binary output. + /// + public IAsyncEnumerable BinaryOutput => Socket.BinaryOutput.ReadAllAsync(); + + /// + /// The task that completes when received task-started event from server. + /// + public Task TaskStarted => Socket.TaskStarted; + + /// + /// Reset task signal and output cannel. + /// + public void ResetTask() => Socket.ResetOutput(); + + /// + /// Send message to server. + /// + /// Request to send. + /// A cancellation token used to propagate notification that this operation should be canceled. + /// + /// Type of the input. + /// Type of the parameter. + /// The is requested. + /// Websocket is not connected. + /// The underlying websocket has already been closed. + public Task SendMessageAsync( + DashScopeWebSocketRequest request, + CancellationToken cancellationToken = default) + where TInput : class, new() + where TParameter : class + => Socket.SendMessageAsync(request, cancellationToken); + + /// + public void Dispose() + { + Pool.ReturnSocketAsync(Socket); + GC.SuppressFinalize(this); + } + + /// + /// Finalizer. + /// + ~DashScopeClientWebSocketWrapper() + { + Dispose(); + } +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeOptions.cs b/src/Cnblogs.DashScope.Core/DashScopeOptions.cs new file mode 100644 index 0000000..cf2f3c3 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeOptions.cs @@ -0,0 +1,34 @@ +using Cnblogs.DashScope.Core.Internals; + +namespace Cnblogs.DashScope.Core; + +/// +/// Options for DashScope client. +/// +public class DashScopeOptions +{ + /// + /// The api key used to access DashScope api + /// + public string ApiKey { get; set; } = string.Empty; + + /// + /// Base address for DashScope HTTP API. + /// + public string BaseAddress { get; set; } = DashScopeDefaults.HttpApiBaseAddress; + + /// + /// Base address for DashScope websocket API. + /// + public string BaseWebsocketAddress { get; set; } = DashScopeDefaults.WebsocketApiBaseAddress; + + /// + /// Default workspace Id. + /// + public string? WorkspaceId { get; set; } + + /// + /// Default socket pool size. + /// + public int SocketPoolSize { get; set; } = 32; +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequest.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequest.cs new file mode 100644 index 0000000..b1845e0 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequest.cs @@ -0,0 +1,21 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Represents a websocket request to DashScope. +/// +/// Type of the input. +/// Type of the parameter. +public class DashScopeWebSocketRequest + where TInput : class, new() + where TParameter : class +{ + /// + /// Metadata of the request. + /// + public DashScopeWebSocketRequestHeader Header { get; set; } = new(); + + /// + /// Payload of the request. + /// + public DashScopeWebSocketRequestPayload Payload { get; set; } = new(); +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequestHeader.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequestHeader.cs new file mode 100644 index 0000000..5b6e5cc --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequestHeader.cs @@ -0,0 +1,22 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Metadata for websocket request. +/// +public class DashScopeWebSocketRequestHeader +{ + /// + /// Action name. + /// + public string Action { get; set; } = string.Empty; + + /// + /// UUID for task. + /// + public string TaskId { get; set; } = string.Empty; + + /// + /// Streaming type. + /// + public string Streaming { get; set; } = "duplex"; +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequestPayload.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequestPayload.cs new file mode 100644 index 0000000..e994239 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketRequestPayload.cs @@ -0,0 +1,41 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Payload for websocket request. +/// +/// Type of the input. +/// Type of the parameter. +public class DashScopeWebSocketRequestPayload + where TInput : class, new() // Input's default value must be empty object(not null or omitted). + where TParameter : class +{ + /// + /// Group name of task. + /// + public string? TaskGroup { get; set; } + + /// + /// Requesting task name. + /// + public string? Task { get; set; } + + /// + /// Requesting function name. + /// + public string? Function { get; set; } + + /// + /// Model id. + /// + public string? Model { get; set; } + + /// + /// Optional parameters. + /// + public TParameter? Parameters { get; set; } + + /// + /// The input of the request. + /// + public TInput Input { get; set; } = new(); +} diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponse.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponse.cs new file mode 100644 index 0000000..9409f92 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponse.cs @@ -0,0 +1,12 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Response from websocket API. +/// +/// Response metadatas. +/// Response body. +/// Output type of the response. +public record DashScopeWebSocketResponse( + DashScopeWebSocketResponseHeader Header, + DashScopeWebSocketResponsePayload Payload) + where TOutput : class; diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseHeader.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseHeader.cs new file mode 100644 index 0000000..26bbb9f --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseHeader.cs @@ -0,0 +1,16 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Metadata of the websocket response. +/// +/// TaskId of the task. +/// Event name. +/// Error code when is task-failed. +/// Error message when is task-failed. +/// Optional attributes +public record DashScopeWebSocketResponseHeader( + string TaskId, + string Event, + string? ErrorCode, + string? ErrorMessage, + DashScopeWebSocketResponseHeaderAttributes Attributes); diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseHeaderAttributes.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseHeaderAttributes.cs new file mode 100644 index 0000000..1ee663e --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseHeaderAttributes.cs @@ -0,0 +1,7 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Attributes field in websocket response header. +/// +/// UUID for current request. +public record DashScopeWebSocketResponseHeaderAttributes(string? RequestUuid); diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponsePayload.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponsePayload.cs new file mode 100644 index 0000000..9f39256 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponsePayload.cs @@ -0,0 +1,10 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Payload field of websocket API response. +/// +/// Content of the response. +/// Task usage info. +/// Type of the response content. +public record DashScopeWebSocketResponsePayload(TOutput? Output, DashScopeWebSocketResponseUsage? Usage) + where TOutput : class; diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseUsage.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseUsage.cs new file mode 100644 index 0000000..7e7f734 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketResponseUsage.cs @@ -0,0 +1,7 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Usage info of websocket task. +/// +/// Current character usage count. +public record DashScopeWebSocketResponseUsage(int Characters); diff --git a/src/Cnblogs.DashScope.Core/DashScopeWebSocketState.cs b/src/Cnblogs.DashScope.Core/DashScopeWebSocketState.cs new file mode 100644 index 0000000..2cb73b1 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/DashScopeWebSocketState.cs @@ -0,0 +1,27 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// The state of . +/// +public enum DashScopeWebSocketState +{ + /// + /// The socket has been created but not connected yet. + /// + Created, + + /// + /// The socket has been connected and ready. + /// + Ready, + + /// + /// The socket has a running task waiting to be finished. + /// + RunningTask, + + /// + /// The socket has been closed. + /// + Closed +} diff --git a/src/Cnblogs.DashScope.Core/IDashScopeClient.cs b/src/Cnblogs.DashScope.Core/IDashScopeClient.cs index a123050..cb61eb5 100644 --- a/src/Cnblogs.DashScope.Core/IDashScopeClient.cs +++ b/src/Cnblogs.DashScope.Core/IDashScopeClient.cs @@ -247,4 +247,14 @@ public Task UploadFileAsync( public Task DeleteFileAsync( DashScopeFileId id, CancellationToken cancellationToken = default); + + /// + /// Start a speech synthesizer session. Related model: cosyvoice + /// + /// The model to use. + /// Cancellation token. + /// + public Task CreateSpeechSynthesizerSocketSessionAsync( + string modelId, + CancellationToken cancellationToken = default); } diff --git a/src/Cnblogs.DashScope.Core/Internals/Assembly.cs b/src/Cnblogs.DashScope.Core/Internals/Assembly.cs new file mode 100644 index 0000000..e628d36 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/Internals/Assembly.cs @@ -0,0 +1,5 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Cnblogs.DashScope.Sdk.UnitTests")] +[assembly: InternalsVisibleTo("Cnblogs.DashScope.Tests.Shared")] +[assembly: InternalsVisibleTo("DynamicProxyGenAssembly2")] diff --git a/src/Cnblogs.DashScope.Core/Internals/ClientWebSocketWrapper.cs b/src/Cnblogs.DashScope.Core/Internals/ClientWebSocketWrapper.cs new file mode 100644 index 0000000..71ebddb --- /dev/null +++ b/src/Cnblogs.DashScope.Core/Internals/ClientWebSocketWrapper.cs @@ -0,0 +1,47 @@ +using System.Net.WebSockets; + +namespace Cnblogs.DashScope.Core.Internals; + +internal sealed class ClientWebSocketWrapper : IClientWebSocket +{ + private readonly ClientWebSocket _socket; + + public ClientWebSocketWrapper(ClientWebSocket socket) + { + _socket = socket; + } + + /// + public void Dispose() + { + _socket.Dispose(); + } + + /// + public ClientWebSocketOptions Options => _socket.Options; + + /// + public WebSocketCloseStatus? CloseStatus => _socket.CloseStatus; + + /// + public Task ConnectAsync(Uri uri, CancellationToken cancellation) => _socket.ConnectAsync(uri, cancellation); + + /// + public Task SendAsync( + ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken) + => _socket.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + + /// + public Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + => _socket.ReceiveAsync(buffer, cancellationToken); + + /// + public Task CloseAsync( + WebSocketCloseStatus closeStatus, + string? statusDescription, + CancellationToken cancellationToken) + => _socket.CloseAsync(closeStatus, statusDescription, cancellationToken); +} diff --git a/src/Cnblogs.DashScope.Core/Internals/DashScopeDefaults.cs b/src/Cnblogs.DashScope.Core/Internals/DashScopeDefaults.cs index 5c1b537..7623c1b 100644 --- a/src/Cnblogs.DashScope.Core/Internals/DashScopeDefaults.cs +++ b/src/Cnblogs.DashScope.Core/Internals/DashScopeDefaults.cs @@ -1,6 +1,30 @@ -namespace Cnblogs.DashScope.Core.Internals; +using System.Text.Json; +using System.Text.Json.Serialization; -internal static class DashScopeDefaults +namespace Cnblogs.DashScope.Core.Internals; + +/// +/// Default values for DashScope client. +/// +public static class DashScopeDefaults { - public const string DashScopeApiBaseAddress = "https://dashscope.aliyuncs.com/api/v1/"; + /// + /// Base address for HTTP API. + /// + public const string HttpApiBaseAddress = "https://dashscope.aliyuncs.com/api/v1/"; + + /// + /// Base address for websocket API. + /// + public const string WebsocketApiBaseAddress = "wss://dashscope.aliyuncs.com/api-ws/v1/inference/"; + + /// + /// Default json serializer options. + /// + public static readonly JsonSerializerOptions SerializationOptions = + new() + { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower, + }; } diff --git a/src/Cnblogs.DashScope.Core/Internals/IClientWebSocket.cs b/src/Cnblogs.DashScope.Core/Internals/IClientWebSocket.cs new file mode 100644 index 0000000..b4a3cd7 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/Internals/IClientWebSocket.cs @@ -0,0 +1,25 @@ +using System.Net.WebSockets; + +namespace Cnblogs.DashScope.Core.Internals; + +/// +/// Extract for testing purpose. +/// +internal interface IClientWebSocket : IDisposable +{ + public ClientWebSocketOptions Options { get; } + + public WebSocketCloseStatus? CloseStatus { get; } + + public Task ConnectAsync(Uri uri, CancellationToken cancellation); + + public Task SendAsync( + ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken); + + Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken); + + Task CloseAsync(WebSocketCloseStatus closeStatus, string? statusDescription, CancellationToken cancellationToken); +} diff --git a/src/Cnblogs.DashScope.Core/SpeechSynthesizerInput.cs b/src/Cnblogs.DashScope.Core/SpeechSynthesizerInput.cs new file mode 100644 index 0000000..8421fe2 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/SpeechSynthesizerInput.cs @@ -0,0 +1,12 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Input for TTS task. +/// +public class SpeechSynthesizerInput +{ + /// + /// Input text, can be null for run-task command. + /// + public string? Text { get; set; } +} diff --git a/src/Cnblogs.DashScope.Core/SpeechSynthesizerOutput.cs b/src/Cnblogs.DashScope.Core/SpeechSynthesizerOutput.cs new file mode 100644 index 0000000..3e15197 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/SpeechSynthesizerOutput.cs @@ -0,0 +1,7 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Output for TTS task. +/// +/// The output sentences. +public record SpeechSynthesizerOutput(SpeechSynthesizerOutputSentences? Sentences); diff --git a/src/Cnblogs.DashScope.Core/SpeechSynthesizerOutputSentences.cs b/src/Cnblogs.DashScope.Core/SpeechSynthesizerOutputSentences.cs new file mode 100644 index 0000000..0698aba --- /dev/null +++ b/src/Cnblogs.DashScope.Core/SpeechSynthesizerOutputSentences.cs @@ -0,0 +1,7 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Sentences for TTS output. +/// +/// Output words. +public record SpeechSynthesizerOutputSentences(string[]? Words); diff --git a/src/Cnblogs.DashScope.Core/SpeechSynthesizerParameters.cs b/src/Cnblogs.DashScope.Core/SpeechSynthesizerParameters.cs new file mode 100644 index 0000000..8e59c19 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/SpeechSynthesizerParameters.cs @@ -0,0 +1,47 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Parameters for TTS task. +/// +public class SpeechSynthesizerParameters +{ + /// + /// Fixed to "PlainText" + /// + public string TextType { get; set; } = "PlainText"; + + /// + /// The voice to use. + /// + public string Voice { get; set; } = string.Empty; + + /// + /// Output file format, can be pcm, wav or mp3. + /// + public string? Format { get; set; } + + /// + /// Output audio sample rate. + /// + public int? SampleRate { get; set; } + + /// + /// Output audio volume, range between 0-100, defaults to 50. + /// + public int? Volume { get; set; } + + /// + /// Speech speed, range between 0.5~2.0, defaults to 1.0. + /// + public float? Rate { get; set; } + + /// + /// Pitch of the voice, range between 0.5~2, defaults to 1.0. + /// + public float? Pitch { get; set; } + + /// + /// Enable SSML, you can only send text once if enabled. + /// + public bool? EnableSsml { get; set; } +} diff --git a/src/Cnblogs.DashScope.Core/SpeechSynthesizerSocketSession.cs b/src/Cnblogs.DashScope.Core/SpeechSynthesizerSocketSession.cs new file mode 100644 index 0000000..0a96853 --- /dev/null +++ b/src/Cnblogs.DashScope.Core/SpeechSynthesizerSocketSession.cs @@ -0,0 +1,136 @@ +namespace Cnblogs.DashScope.Core; + +/// +/// Represents a socket-based TTS session. +/// +public sealed class SpeechSynthesizerSocketSession + : IDisposable +{ + private readonly DashScopeClientWebSocketWrapper _socket; + private readonly string _modelId; + + /// + /// Represents a socket-based TTS session. + /// + /// Underlying websocket. + /// Model name to use. + public SpeechSynthesizerSocketSession(DashScopeClientWebSocketWrapper socket, string modelId) + { + _socket = socket; + _modelId = modelId; + } + + /// + /// Send a run-task command, use random GUID as taskId. + /// + /// Input parameters. + /// Input text. + /// The cancellation token to use. + /// The generated taskId. + public Task RunTaskAsync( + SpeechSynthesizerParameters parameters, + string? text = null, + CancellationToken cancellationToken = default) + { + return RunTaskAsync(Guid.NewGuid().ToString(), parameters, text, cancellationToken); + } + + /// + /// Send a run-task command. + /// + /// Unique taskId. + /// Input parameters. + /// Input text. + /// The cancellation token to use. + /// . + public async Task RunTaskAsync( + string taskId, + SpeechSynthesizerParameters parameters, + string? text = null, + CancellationToken cancellationToken = default) + { + var command = new DashScopeWebSocketRequest() + { + Header = new DashScopeWebSocketRequestHeader() + { + Action = "run-task", TaskId = taskId, + }, + Payload = new DashScopeWebSocketRequestPayload() + { + Input = new SpeechSynthesizerInput() { Text = text, }, + TaskGroup = "audio", + Task = "tts", + Function = "SpeechSynthesizer", + Model = _modelId, + Parameters = parameters + } + }; + _socket.ResetTask(); + await _socket.SendMessageAsync(command, cancellationToken); + await _socket.TaskStarted; + return taskId; + } + + /// + /// Append input text to task. + /// + /// TaskId to append. + /// Text to append. + /// Cancellation token to use. + public async Task ContinueTaskAsync(string taskId, string input, CancellationToken cancellationToken = default) + { + var command = new DashScopeWebSocketRequest() + { + Header = new DashScopeWebSocketRequestHeader() + { + Action = "continue-task", TaskId = taskId, + }, + Payload = new DashScopeWebSocketRequestPayload() + { + Input = new SpeechSynthesizerInput() { Text = input } + } + }; + await _socket.SendMessageAsync(command, cancellationToken); + } + + /// + /// Send finish-task command. + /// + /// Unique id of the task. + /// The cancellation token to use. + public async Task FinishTaskAsync(string taskId, CancellationToken cancellationToken = default) + { + var command = new DashScopeWebSocketRequest() + { + Header = new DashScopeWebSocketRequestHeader() { TaskId = taskId, Action = "finish-task" }, + Payload = new DashScopeWebSocketRequestPayload() + { + Input = new SpeechSynthesizerInput() + } + }; + await _socket.SendMessageAsync(command, cancellationToken); + } + + /// + /// Get the audio stream. + /// + /// + public IAsyncEnumerable GetAudioAsync() + { + return _socket.BinaryOutput; + } + + private void Dispose(bool disposing) + { + if (disposing) + { + _socket.Dispose(); + } + } + + /// + public void Dispose() + { + Dispose(true); + } +} diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/BaiChuanApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/BaiChuanApiTests.cs index 0d99cac..579789e 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/BaiChuanApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/BaiChuanApiTests.cs @@ -22,6 +22,19 @@ public async Task BaiChuanTextGeneration_UseEnum_SuccessAsync() s => s.Model == "baichuan-7b-v1" && s.Input.Prompt == Cases.Prompt && s.Parameters == null)); } + [Fact] + public async Task BaiChuanTextGeneration_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + + // Act + var act = async () => await client.GetBaiChuanTextCompletionAsync((BaiChuanLlm)(-1), Cases.Prompt); + + // Assert + await Assert.ThrowsAsync(act); + } + [Fact] public async Task BaiChuanTextGeneration_CustomModel_SuccessAsync() { @@ -29,12 +42,12 @@ public async Task BaiChuanTextGeneration_CustomModel_SuccessAsync() var client = Substitute.For(); // Act - _ = await client.GetBaiChuanTextCompletionAsync(BaiChuanLlm.BaiChuan7B, Cases.Prompt); + _ = await client.GetBaiChuanTextCompletionAsync(Cases.CustomModelName, Cases.Prompt); // Assert _ = await client.Received().GetTextCompletionAsync( Arg.Is>( - s => s.Model == "baichuan-7b-v1" && s.Input.Prompt == Cases.Prompt && s.Parameters == null)); + s => s.Model == Cases.CustomModelName && s.Input.Prompt == Cases.Prompt && s.Parameters == null)); } [Fact] @@ -43,6 +56,22 @@ public async Task BaiChuan2TextGeneration_UseEnum_SuccessAsync() // Arrange var client = Substitute.For(); + // Act + var act = async () => await client.GetBaiChuanTextCompletionAsync( + (BaiChuan2Llm)(-1), + Cases.TextMessages, + ResultFormats.Message); + + // Assert + await Assert.ThrowsAsync(act); + } + + [Fact] + public async Task BaiChuan2TextGeneration_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + // Act _ = await client.GetBaiChuanTextCompletionAsync( BaiChuan2Llm.BaiChuan2_13BChatV1, diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/Cnblogs.DashScope.Sdk.UnitTests.csproj b/test/Cnblogs.DashScope.Sdk.UnitTests/Cnblogs.DashScope.Sdk.UnitTests.csproj index 94a3a5e..3fbb8f0 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/Cnblogs.DashScope.Sdk.UnitTests.csproj +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/Cnblogs.DashScope.Sdk.UnitTests.csproj @@ -11,7 +11,6 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientTests.cs index 93f7d4d..7b3aab5 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientTests.cs @@ -81,7 +81,7 @@ public void DashScopeClient_Constructor_WithWorkspaceId() // Arrange const string apiKey = "key"; const string workspaceId = "workspaceId"; - var client = new DashScopeClient(apiKey, null, null, workspaceId); + var client = new DashScopeClient(apiKey, workspaceId: workspaceId); // Act var value = HttpClientAccessor.GetValue(client) as HttpClient; diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientWebSocketTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientWebSocketTests.cs new file mode 100644 index 0000000..6398af6 --- /dev/null +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/DashScopeClientWebSocketTests.cs @@ -0,0 +1,173 @@ +using System.Net; +using System.Net.WebSockets; +using System.Reflection; +using Cnblogs.DashScope.Core; +using Cnblogs.DashScope.Core.Internals; +using Cnblogs.DashScope.Tests.Shared.Utils; +using NSubstitute; + +namespace Cnblogs.DashScope.Sdk.UnitTests; + +public class DashScopeClientWebSocketTests +{ + private static readonly FieldInfo InnerSocketInfo = + typeof(DashScopeClientWebSocket).GetField("_socket", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new InvalidOperationException( + $"Can not found {nameof(DashScopeClientWebSocket)}._client, please update this test after refactoring"); + + private static readonly PropertyInfo InnerRequestHeaderInfo = + typeof(ClientWebSocketOptions).GetProperty("RequestHeaders", BindingFlags.NonPublic | BindingFlags.Instance) + ?? throw new InvalidOperationException( + $"Can not found {nameof(ClientWebSocketOptions)}.RequestHeaders property, please update this test after framework change"); + + [Fact] + public void Constructor_UseApiKeyAndWorkspaceId_EnsureConfigured() + { + // Arrange + const string apiKey = "apiKey"; + const string workspaceId = "workspaceId"; + + // Act + var client = new DashScopeClientWebSocket(apiKey, workspaceId); + var headers = ExtractHeaders(client); + + // Assert + Assert.Equal($"bearer {apiKey}", headers.GetValues("Authorization")?.First()); + Assert.Equal("enable", headers.GetValues("X-DashScope-DataInspection")?.First()); + Assert.Equal(workspaceId, headers.GetValues("X-DashScope-WorkspaceId")?.First()); + } + + [Fact] + public void Constructor_UseApiKeyWithoutWorkspaceId_EnsureConfigured() + { + // Arrange + const string apiKey = "apiKey"; + + // Act + var client = new DashScopeClientWebSocket(apiKey); + var headers = ExtractHeaders(client); + + // Assert + Assert.Equal($"bearer {apiKey}", headers.GetValues("Authorization")?.First()); + Assert.Equal("enable", headers.GetValues("X-DashScope-DataInspection")?.First()); + Assert.Null(headers.GetValues("X-DashScope-WorkspaceId")); + } + + [Fact] + public void Constructor_UsePreconfiguredSocket_EnsureConfigured() + { + // Arrange + using var socket = new ClientWebSocketWrapper(new ClientWebSocket()); + + // Act + var client = new DashScopeClientWebSocket(socket); + + // Assert + Assert.StrictEqual(socket, InnerSocketInfo.GetValue(client)); + } + + [Fact] + public async Task ConnectAsync_InitialConnect_ChangeStateAsync() + { + // Arrange + var socket = Substitute.For(); + var client = new DashScopeClientWebSocket(socket); + var apiUri = new Uri("ws://test.com"); + + // Act + await client.ConnectAsync(apiUri); + + // Assert + Assert.Equal(DashScopeWebSocketState.Ready, client.State); + await socket.Received(1).ConnectAsync(Arg.Is(apiUri), Arg.Any()); + await socket.Received().ReceiveAsync(Arg.Any>(), Arg.Any()); + } + + [Fact] + public async Task ResetOutput_WithInitialOutput_CompleteThenCreateNewOutputAsync() + { + // Arrange + var socket = Substitute.For(); + var client = new DashScopeClientWebSocket(socket); + client.ResetOutput(); + var oldOutput = client.BinaryOutput; + var oldSignal = client.TaskStarted; + + // Act + client.ResetOutput(); + + // Assert + Assert.False(await oldSignal); + Assert.True(oldOutput.Completion.IsCompletedSuccessfully); + Assert.NotSame(oldOutput, client.BinaryOutput); + Assert.NotSame(oldSignal, client.TaskStarted); + } + + [Fact] + public async Task SendMessageAsync_SocketClosed_ThrowAsync() + { + // Arrange + var socket = Substitute.For(); + var client = new DashScopeClientWebSocket(socket); + var snapshot = Snapshots.SpeechSynthesizer.RunTask; + await client.CloseAsync(); + + // Act + var act = () => client.SendMessageAsync(snapshot.Message); + + // Assert + await Assert.ThrowsAsync(act); + } + + [Fact] + public async Task SendMessageAsync_Connected_SendAsync() + { + // Arrange + var socket = Substitute.For(); + var client = new DashScopeClientWebSocket(socket); + var snapshot = Snapshots.SpeechSynthesizer.RunTask; + + // Act + await client.ConnectAsync(new Uri(DashScopeDefaults.WebsocketApiBaseAddress)); + await client.SendMessageAsync(snapshot.Message); + + // Assert + await socket.Received().SendAsync( + Arg.Is>(s => Checkers.IsJsonEquivalent(s, snapshot.GetRequestJson())), + WebSocketMessageType.Text, + true, + Arg.Any()); + } + + [Fact] + public async Task ReceiveMessageAsync_ServerClosed_CloseAsync() + { + // Arrange + var (_, dashScopeClientWebSocket, server) = await Sut.GetSocketTestClientAsync(); + + // Act + await server.WriteServerCloseAsync(); + + // Assert + Assert.Equal(DashScopeWebSocketState.Closed, dashScopeClientWebSocket.State); + Assert.Equal(WebSocketCloseStatus.NormalClosure, server.CloseStatus); + } + + private static WebHeaderCollection ExtractHeaders(DashScopeClientWebSocket socket) + { + var obj = InnerSocketInfo.GetValue(socket); + if (obj is not IClientWebSocket clientWebSocket) + { + throw new InvalidOperationException($"Get null when trying to fetch {InnerSocketInfo.Name}"); + } + + obj = InnerRequestHeaderInfo.GetValue(clientWebSocket.Options); + if (obj is not WebHeaderCollection headers) + { + throw new InvalidOperationException( + $"Wrong type or null when trying to fetch {InnerRequestHeaderInfo.Name}"); + } + + return headers; + } +} diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/DeepSeekTextGenerationApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/DeepSeekTextGenerationApiTests.cs index 461b037..87643dd 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/DeepSeekTextGenerationApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/DeepSeekTextGenerationApiTests.cs @@ -19,8 +19,23 @@ await client.GetDeepSeekChatCompletionAsync( // Assert await client.Received().GetTextCompletionAsync( - Arg.Is>( - x => x.Model == "deepseek-r1" && x.Input.Messages!.First().Content == "你好" && x.Parameters == null)); + Arg.Is>(x + => x.Model == "deepseek-r1" && x.Input.Messages!.First().Content == "你好" && x.Parameters == null)); + } + + [Fact] + public async Task TextCompletion_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + + // Act + var act = async () => await client.GetDeepSeekChatCompletionAsync( + (DeepSeekLlm)(-1), + new List { TextChatMessage.User("你好") }.AsReadOnly()); + + // Assert + await Assert.ThrowsAsync(act); } [Fact] @@ -37,8 +52,8 @@ await client.GetDeepSeekChatCompletionAsync( // Assert await client.Received().GetTextCompletionAsync( - Arg.Is>( - x => x.Model == customModel && x.Input.Messages!.First().Content == "你好" && x.Parameters == null)); + Arg.Is>(x + => x.Model == customModel && x.Input.Messages!.First().Content == "你好" && x.Parameters == null)); } [Fact] @@ -54,10 +69,9 @@ public void StreamCompletion_UseEnum_SuccessAsync() // Assert _ = client.Received().GetTextCompletionStreamAsync( - Arg.Is>( - x => x.Model == "deepseek-v3" - && x.Input.Messages!.First().Content == "你好" - && x.Parameters!.IncrementalOutput == true)); + Arg.Is>(x => x.Model == "deepseek-v3" + && x.Input.Messages!.First().Content == "你好" + && x.Parameters!.IncrementalOutput == true)); } [Fact] @@ -74,9 +88,8 @@ public void StreamCompletion_CustomModel_SuccessAsync() // Assert _ = client.Received().GetTextCompletionStreamAsync( - Arg.Is>( - x => x.Model == customModel - && x.Input.Messages!.First().Content == "你好" - && x.Parameters!.IncrementalOutput == true)); + Arg.Is>(x => x.Model == customModel + && x.Input.Messages!.First().Content == "你好" + && x.Parameters!.IncrementalOutput == true)); } } diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/Llama2TextGenerationApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/Llama2TextGenerationApiTests.cs index c4d5679..7daf36c 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/Llama2TextGenerationApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/Llama2TextGenerationApiTests.cs @@ -18,11 +18,27 @@ public async Task Llama2_UseEnum_SuccessAsync() // Assert _ = await client.Received().GetTextCompletionAsync( - Arg.Is>( - s => s.Input.Messages == Cases.TextMessages - && s.Model == "llama2-13b-chat-v2" - && s.Parameters != null - && s.Parameters.ResultFormat == ResultFormats.Message)); + Arg.Is>(s + => s.Input.Messages == Cases.TextMessages + && s.Model == "llama2-13b-chat-v2" + && s.Parameters != null + && s.Parameters.ResultFormat == ResultFormats.Message)); + } + + [Fact] + public async Task Llama2_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + + // Act + var act = async () => await client.GetLlama2TextCompletionAsync( + (Llama2Model)(-1), + Cases.TextMessages, + ResultFormats.Message); + + // Assert + await Assert.ThrowsAsync(act); } [Fact] @@ -36,10 +52,10 @@ public async Task Llama2_CustomModel_SuccessAsync() // Assert _ = await client.Received().GetTextCompletionAsync( - Arg.Is>( - s => s.Input.Messages == Cases.TextMessages - && s.Model == Cases.CustomModelName - && s.Parameters != null - && s.Parameters.ResultFormat == ResultFormats.Message)); + Arg.Is>(s + => s.Input.Messages == Cases.TextMessages + && s.Model == Cases.CustomModelName + && s.Parameters != null + && s.Parameters.ResultFormat == ResultFormats.Message)); } } diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/QWenMultimodalApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/QWenMultimodalApiTests.cs index c2c15ce..8e7a553 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/QWenMultimodalApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/QWenMultimodalApiTests.cs @@ -30,8 +30,23 @@ public async Task Multimodal_UseEnum_SuccessAsync() // Assert _ = client.Received().GetMultimodalGenerationAsync( - Arg.Is>( - s => s.Model == "qwen-vl-max" && s.Input.Messages == Messages && s.Parameters == parameters)); + Arg.Is>(s + => s.Model == "qwen-vl-max" && s.Input.Messages == Messages && s.Parameters == parameters)); + } + + [Fact] + public async Task Multimodal_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + var parameters = new MultimodalParameters { Seed = 6666 }; + + // Act + var act = async () + => await client.GetQWenMultimodalCompletionAsync((QWenMultimodalModel)(-1), Messages, parameters); + + // Assert + await Assert.ThrowsAsync(act); } [Fact] @@ -46,8 +61,8 @@ public async Task Multimodal_CustomModel_SuccessAsync() // Assert _ = client.Received().GetMultimodalGenerationAsync( - Arg.Is>( - s => s.Model == Cases.CustomModelName && s.Input.Messages == Messages && s.Parameters == parameters)); + Arg.Is>(s + => s.Model == Cases.CustomModelName && s.Input.Messages == Messages && s.Parameters == parameters)); } [Fact] @@ -62,8 +77,8 @@ public void MultimodalStream_UseEnum_Success() // Assert _ = client.Received().GetMultimodalGenerationStreamAsync( - Arg.Is>( - s => s.Model == "qwen-vl-plus" && s.Input.Messages == Messages && s.Parameters == parameters)); + Arg.Is>(s + => s.Model == "qwen-vl-plus" && s.Input.Messages == Messages && s.Parameters == parameters)); } [Fact] @@ -78,7 +93,7 @@ public void Multimodal_CustomModel_Success() // Assert _ = client.Received().GetMultimodalGenerationStreamAsync( - Arg.Is>( - s => s.Model == Cases.CustomModelName && s.Input.Messages == Messages && s.Parameters == parameters)); + Arg.Is>(s + => s.Model == Cases.CustomModelName && s.Input.Messages == Messages && s.Parameters == parameters)); } } diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/QWenTextGenerationApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/QWenTextGenerationApiTests.cs index ddd6640..5fb1f6f 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/QWenTextGenerationApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/QWenTextGenerationApiTests.cs @@ -102,6 +102,20 @@ await client.Received().GetTextCompletionAsync( s => s.Input.Messages == Cases.TextMessages && s.Parameters == parameters && s.Model == "qwen-max-1201")); } + [Fact] + public async Task QWenChatCompletion_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + var parameters = new TextGenerationParameters { EnableSearch = true, ResultFormat = ResultFormats.Message }; + + // Act + var act = async () => await client.GetQWenChatCompletionAsync((QWenLlm)(-1), Cases.TextMessages, parameters); + + // Assert + await Assert.ThrowsAsync(act); + } + [Fact] public async Task QWenChatCompletion_CustomModel_SuccessAsync() { diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/ServiceCollectionInjectorTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/ServiceCollectionInjectorTests.cs index 02230cc..2d3ff04 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/ServiceCollectionInjectorTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/ServiceCollectionInjectorTests.cs @@ -1,4 +1,5 @@ using System.Net.Http.Headers; +using Cnblogs.DashScope.AspNetCore; using Cnblogs.DashScope.Core; using FluentAssertions; using Microsoft.Extensions.Configuration; @@ -20,11 +21,12 @@ public void Parameter_Normal_Inject() // Act services.AddDashScopeClient(ApiKey); var provider = services.BuildServiceProvider(); - var httpClient = provider.GetRequiredService().CreateClient(nameof(IDashScopeClient)); + var httpClient = provider.GetRequiredService() + .CreateClient(DashScopeAspNetCoreDefaults.DefaultHttpClientName); // Assert provider.GetRequiredService().Should().NotBeNull().And - .BeOfType(); + .BeOfType(); httpClient.Should().NotBeNull(); httpClient.DefaultRequestHeaders.Authorization.Should() .BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey)); @@ -37,13 +39,14 @@ public void Parameter_HasProxy_Inject() var services = new ServiceCollection(); // Act - services.AddDashScopeClient(ApiKey, ProxyApi); + services.AddDashScopeClient(ApiKey, baseAddress: ProxyApi); var provider = services.BuildServiceProvider(); - var httpClient = provider.GetRequiredService().CreateClient(nameof(IDashScopeClient)); + var httpClient = provider.GetRequiredService() + .CreateClient(DashScopeAspNetCoreDefaults.DefaultHttpClientName); // Assert provider.GetRequiredService().Should().NotBeNull().And - .BeOfType(); + .BeOfType(); httpClient.Should().NotBeNull(); httpClient.DefaultRequestHeaders.Authorization.Should() .BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey)); @@ -66,11 +69,12 @@ public void Configuration_Normal_Inject() // Act services.AddDashScopeClient(configuration); var provider = services.BuildServiceProvider(); - var httpClient = provider.GetRequiredService().CreateClient(nameof(IDashScopeClient)); + var httpClient = provider.GetRequiredService() + .CreateClient(DashScopeAspNetCoreDefaults.DefaultHttpClientName); // Assert provider.GetRequiredService().Should().NotBeNull().And - .BeOfType(); + .BeOfType(); httpClient.Should().NotBeNull(); httpClient.DefaultRequestHeaders.Authorization.Should() .BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey)); @@ -93,11 +97,12 @@ public void Configuration_CustomSectionName_Inject() // Act services.AddDashScopeClient(configuration, "dashScopeCustom"); var provider = services.BuildServiceProvider(); - var httpClient = provider.GetRequiredService().CreateClient(nameof(IDashScopeClient)); + var httpClient = provider.GetRequiredService() + .CreateClient(DashScopeAspNetCoreDefaults.DefaultHttpClientName); // Assert provider.GetRequiredService().Should().NotBeNull().And - .BeOfType(); + .BeOfType(); httpClient.Should().NotBeNull(); httpClient.DefaultRequestHeaders.Authorization.Should() .BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey)); @@ -111,14 +116,15 @@ public void Configuration_AddMultipleTime_Replace() var services = new ServiceCollection(); // Act - services.AddDashScopeClient(ApiKey, ProxyApi); - services.AddDashScopeClient(ApiKey, ProxyApi); + services.AddDashScopeClient(ApiKey, baseAddress: ProxyApi); + services.AddDashScopeClient(ApiKey, baseAddress: ProxyApi); var provider = services.BuildServiceProvider(); - var httpClient = provider.GetRequiredService().CreateClient(nameof(IDashScopeClient)); + var httpClient = provider.GetRequiredService() + .CreateClient(DashScopeAspNetCoreDefaults.DefaultHttpClientName); // Assert provider.GetRequiredService().Should().NotBeNull().And - .BeOfType(); + .BeOfType(); httpClient.Should().NotBeNull(); httpClient.DefaultRequestHeaders.Authorization.Should() .BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey)); @@ -130,11 +136,7 @@ public void Configuration_NoApiKey_Throw() { // Arrange var services = new ServiceCollection(); - var config = new Dictionary - { - { "irrelevant", "irr" }, - { "dashScope:baseAddress", ProxyApi } - }; + var config = new Dictionary { { "irrelevant", "irr" }, { "dashScope:baseAddress", ProxyApi } }; var configuration = new ConfigurationBuilder().AddInMemoryCollection(config).Build(); // Act diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/TextEmbeddingApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/TextEmbeddingApiTests.cs index 5ad66d3..d82d43a 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/TextEmbeddingApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/TextEmbeddingApiTests.cs @@ -20,8 +20,23 @@ public async Task GetEmbeddings_UseEnum_SuccessAsync() // Assert await client.Received().GetEmbeddingsAsync( - Arg.Is>( - s => s.Input.Texts == texts && s.Model == "text-embedding-v2" && s.Parameters == parameters)); + Arg.Is>(s + => s.Input.Texts == texts && s.Model == "text-embedding-v2" && s.Parameters == parameters)); + } + + [Fact] + public async Task GetEmbeddings_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + var texts = new[] { "hello" }; + var parameters = new TextEmbeddingParameters { TextType = TextTypes.Query }; + + // Act + var act = async () => await client.GetTextEmbeddingsAsync((TextEmbeddingModel)(-1), texts, parameters); + + // Assert + await Assert.ThrowsAsync(act); } [Fact] @@ -37,7 +52,7 @@ public async Task GetEmbeddings_CustomModel_SuccessAsync() // Assert await client.Received().GetEmbeddingsAsync( - Arg.Is>( - s => s.Input.Texts == texts && s.Model == Cases.CustomModelName && s.Parameters == parameters)); + Arg.Is>(s + => s.Input.Texts == texts && s.Model == Cases.CustomModelName && s.Parameters == parameters)); } } diff --git a/test/Cnblogs.DashScope.Sdk.UnitTests/WanxApiTests.cs b/test/Cnblogs.DashScope.Sdk.UnitTests/WanxApiTests.cs index ff05d7c..ad0dad0 100644 --- a/test/Cnblogs.DashScope.Sdk.UnitTests/WanxApiTests.cs +++ b/test/Cnblogs.DashScope.Sdk.UnitTests/WanxApiTests.cs @@ -42,6 +42,27 @@ public async Task WanxImageSynthesis_UseEnum_SuccessAsync() && s.Parameters == Parameters)); } + [Fact] + public async Task WanxImageSynthesis_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + client.Configure().CreateImageSynthesisTaskAsync( + Arg.Any>(), + Arg.Any()) + .Returns(Snapshots.ImageSynthesis.CreateTask.ResponseModel); + + // Act + var act = async () => await client.CreateWanxImageSynthesisTaskAsync( + (WanxModel)(-1), + Cases.Prompt, + Cases.PromptAlter, + Parameters); + + // Assert + await Assert.ThrowsAsync(act); + } + [Fact] public async Task WanxImageSynthesis_CustomModel_SuccessAsync() { @@ -104,6 +125,25 @@ public async Task WanxImageGeneration_UseEnum_SuccessAsync() && s.Input.StyleIndex == 3)); } + [Fact] + public async Task WanxImageGeneration_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + client.Configure().CreateImageGenerationTaskAsync( + Arg.Any>(), + Arg.Any()) + .Returns(Snapshots.ImageGeneration.CreateTaskNoSse.ResponseModel); + + // Act + var act = async () => await client.CreateWanxImageGenerationTaskAsync( + (WanxStyleRepaintModel)(-1), + new ImageGenerationInput { ImageUrl = Cases.ImageUrl, StyleIndex = 3 }); + + // Assert + await Assert.ThrowsAsync(act); + } + [Fact] public async Task WanxImageGeneration_CustomModel_SuccessAsync() { @@ -162,6 +202,25 @@ public async Task WanxBackgroundImageGeneration_UseEnum_SuccessAsync() && s.Input.BaseImageUrl == Cases.ImageUrl)); } + [Fact] + public async Task WanxBackgroundImageGeneration_UseInvalidEnum_SuccessAsync() + { + // Arrange + var client = Substitute.For(); + client.Configure().CreateBackgroundGenerationTaskAsync( + Arg.Any>(), + Arg.Any()) + .Returns(Snapshots.BackgroundGeneration.CreateTaskNoSse.ResponseModel); + + // Act + var act = async () => await client.CreateWanxBackgroundGenerationTaskAsync( + (WanxBackgroundGenerationModel)(-1), + new BackgroundGenerationInput { BaseImageUrl = Cases.ImageUrl }); + + // Assert + await Assert.ThrowsAsync(act); + } + [Fact] public async Task WanxBackgroundImageGeneration_CustomModel_SuccessAsync() { diff --git a/test/Cnblogs.DashScope.Tests.Shared/Assembly.cs b/test/Cnblogs.DashScope.Tests.Shared/Assembly.cs new file mode 100644 index 0000000..7c8a352 --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/Assembly.cs @@ -0,0 +1,4 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Cnblogs.DashScope.Sdk.UnitTests")] +[assembly: InternalsVisibleTo("Cnblogs.DashScope.AI.UnitTests")] diff --git a/test/Cnblogs.DashScope.Tests.Shared/Cnblogs.DashScope.Tests.Shared.csproj b/test/Cnblogs.DashScope.Tests.Shared/Cnblogs.DashScope.Tests.Shared.csproj index 6c9b008..f2aedfd 100644 --- a/test/Cnblogs.DashScope.Tests.Shared/Cnblogs.DashScope.Tests.Shared.csproj +++ b/test/Cnblogs.DashScope.Tests.Shared/Cnblogs.DashScope.Tests.Shared.csproj @@ -8,7 +8,7 @@ - + diff --git a/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.continue-task.json b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.continue-task.json new file mode 100644 index 0000000..20ed15f --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.continue-task.json @@ -0,0 +1,11 @@ +{ + "header": { + "action": "continue-task", + "task_id": "439e0616-2f5b-44e0-8872-0002a066a49c" + }, + "payload": { + "input": { + "text": "代码改变世界" + } + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.finish-task.json b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.finish-task.json new file mode 100644 index 0000000..2efb69a --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.finish-task.json @@ -0,0 +1,9 @@ +{ + "header": { + "action": "finish-task", + "task_id": "439e0616-2f5b-44e0-8872-0002a066a49c" + }, + "payload": { + "input": {} + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.result-generated.json b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.result-generated.json new file mode 100644 index 0000000..c3111db --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.result-generated.json @@ -0,0 +1,17 @@ +{ + "header": { + "task_id": "439e0616-2f5b-44e0-8872-0002a066a49c", + "event": "result-generated", + "attributes": { + "request_uuid": "c88301b4-3caa-4f15-94e2-246e84d2e648", + "x-ds-batch-queue-length": "0" + } + }, + "payload": { + "output": { + "sentence": { + "words": [] + } + } + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.run-task.json b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.run-task.json new file mode 100644 index 0000000..482a599 --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.run-task.json @@ -0,0 +1,24 @@ +{ + "header": { + "action": "run-task", + "task_id": "439e0616-2f5b-44e0-8872-0002a066a49c", + "streaming": "duplex" + }, + "payload": { + "model": "cosyvoice-v1", + "task_group": "audio", + "task": "tts", + "function": "SpeechSynthesizer", + "input": {}, + "parameters": { + "voice": "longxiaochun", + "volume": 50, + "text_type": "PlainText", + "sample_rate": 0, + "rate": 1.1, + "format": "mp3", + "pitch": 1.2, + "enable_ssml": true + } + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.task-finished.json b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.task-finished.json new file mode 100644 index 0000000..e48f75a --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.task-finished.json @@ -0,0 +1,20 @@ +{ + "header": { + "task_id": "439e0616-2f5b-44e0-8872-0002a066a49c", + "event": "task-finished", + "attributes": { + "request_uuid": "c88301b4-3caa-4f15-94e2-246e84d2e648", + "x-ds-batch-queue-length": "0" + } + }, + "payload": { + "output": { + "sentence": { + "words": [] + } + }, + "usage": { + "characters": 12 + } + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.task-started.json b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.task-started.json new file mode 100644 index 0000000..ebcddee --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/RawHttpData/socket-speech-synthesizer.task-started.json @@ -0,0 +1,8 @@ +{ + "header": { + "task_id": "439e0616-2f5b-44e0-8872-0002a066a49c", + "event": "task-started", + "attributes": {} + }, + "payload": {} +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/Utils/Checkers.cs b/test/Cnblogs.DashScope.Tests.Shared/Utils/Checkers.cs index 0e66fbd..df574e7 100644 --- a/test/Cnblogs.DashScope.Tests.Shared/Utils/Checkers.cs +++ b/test/Cnblogs.DashScope.Tests.Shared/Utils/Checkers.cs @@ -4,6 +4,13 @@ namespace Cnblogs.DashScope.Tests.Shared.Utils; public static class Checkers { + public static bool IsJsonEquivalent(ArraySegment socketBuffer, string requestSnapshot) + { + var actual = JsonNode.Parse(socketBuffer); + var expected = JsonNode.Parse(requestSnapshot); + return JsonNode.DeepEquals(actual, expected); + } + public static bool IsJsonEquivalent(HttpContent content, string requestSnapshot) { #pragma warning disable VSTHRD002 diff --git a/test/Cnblogs.DashScope.Tests.Shared/Utils/FakeClientWebSocket.cs b/test/Cnblogs.DashScope.Tests.Shared/Utils/FakeClientWebSocket.cs new file mode 100644 index 0000000..589dbbe --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/Utils/FakeClientWebSocket.cs @@ -0,0 +1,78 @@ +using System.Net.WebSockets; +using System.Threading.Channels; +using Cnblogs.DashScope.Core.Internals; + +namespace Cnblogs.DashScope.Tests.Shared.Utils; + +public sealed class FakeClientWebSocket : IClientWebSocket +{ + public List> ReceivedMessages { get; } = new(); + + public Channel Server { get; } = + Channel.CreateUnbounded(); + + public async Task WriteServerCloseAsync() + { + var close = new WebSocketReceiveResult(1, WebSocketMessageType.Close, true); + await Server.Writer.WriteAsync(close); + Server.Writer.Complete(); + } + + private void Dispose(bool disposing) + { + // nothing to release. + if (disposing) + { + Server.Writer.Complete(); + } + } + + /// + public void Dispose() + { + Dispose(true); + } + + /// + public ClientWebSocketOptions Options { get; set; } = null!; + + /// + public WebSocketCloseStatus? CloseStatus { get; set; } + + /// + public Task ConnectAsync(Uri uri, CancellationToken cancellation) + { + // do nothing. + return Task.CompletedTask; + } + + /// + public Task SendAsync( + ArraySegment buffer, + WebSocketMessageType messageType, + bool endOfMessage, + CancellationToken cancellationToken) + { + ReceivedMessages.Add(buffer); + return Task.CompletedTask; + } + + /// + public async Task ReceiveAsync( + ArraySegment buffer, + CancellationToken cancellationToken) + { + await Server.Reader.WaitToReadAsync(cancellationToken); + return await Server.Reader.ReadAsync(cancellationToken); + } + + /// + public Task CloseAsync( + WebSocketCloseStatus closeStatus, + string? statusDescription, + CancellationToken cancellationToken) + { + CloseStatus = WebSocketCloseStatus.NormalClosure; + return Task.CompletedTask; + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/Utils/Snapshots.SocketRequests.cs b/test/Cnblogs.DashScope.Tests.Shared/Utils/Snapshots.SocketRequests.cs new file mode 100644 index 0000000..eb13289 --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/Utils/Snapshots.SocketRequests.cs @@ -0,0 +1,40 @@ +using Cnblogs.DashScope.Core; + +namespace Cnblogs.DashScope.Tests.Shared.Utils; + +public partial class Snapshots +{ + public static class SpeechSynthesizer + { + private const string GroupName = "speech-synthesizer"; + + public static readonly + SocketMessageSnapshot> + RunTask = new(GroupName, "run-task", new DashScopeWebSocketRequest() + { + Header = new DashScopeWebSocketRequestHeader() + { + Action = "run-task", + Streaming = "duplex", + TaskId = "439e0616-2f5b-44e0-8872-0002a066a49c" + }, + Payload = new DashScopeWebSocketRequestPayload() + { + Task = "tts", + TaskGroup = "audio", + Function = "SpeechSynthesizer", + Model = "cosyvoice-v1", + Parameters = new SpeechSynthesizerParameters() + { + EnableSsml = true, + Format = "mp3", + Pitch = 1.2f, + Voice = "longxiaochun", + Volume = 50, + SampleRate = 0, + Rate = 1.1f, + } + } + }); + } +} diff --git a/test/Cnblogs.DashScope.Tests.Shared/Utils/SocketMessageSnapshot.cs b/test/Cnblogs.DashScope.Tests.Shared/Utils/SocketMessageSnapshot.cs new file mode 100644 index 0000000..612dc8c --- /dev/null +++ b/test/Cnblogs.DashScope.Tests.Shared/Utils/SocketMessageSnapshot.cs @@ -0,0 +1,12 @@ +namespace Cnblogs.DashScope.Tests.Shared.Utils; + +public record SocketMessageSnapshot(string GroupName, string MessageName) +{ + public string GetRequestJson() + { + return File.ReadAllText(Path.Combine("RawHttpData", $"socket-{GroupName}.{MessageName}.json")); + } +} + +public record SocketMessageSnapshot(string GroupName, string MessageName, TMessage Message) + : SocketMessageSnapshot(GroupName, MessageName); diff --git a/test/Cnblogs.DashScope.Tests.Shared/Utils/Sut.cs b/test/Cnblogs.DashScope.Tests.Shared/Utils/Sut.cs index cb42ce2..658a046 100644 --- a/test/Cnblogs.DashScope.Tests.Shared/Utils/Sut.cs +++ b/test/Cnblogs.DashScope.Tests.Shared/Utils/Sut.cs @@ -1,4 +1,5 @@ using Cnblogs.DashScope.Core; +using Cnblogs.DashScope.Core.Internals; using NSubstitute; using NSubstitute.Extensions; @@ -20,7 +21,25 @@ public static class Sut public static (DashScopeClientCore Client, MockHttpMessageHandler Handler) GetTestClient() { var handler = Substitute.ForPartsOf(); - var client = new DashScopeClientCore(new HttpClient(handler) { BaseAddress = new Uri("https://example.com") }); + var client = new DashScopeClientCore( + new HttpClient(handler) { BaseAddress = new Uri("https://example.com") }, + new DashScopeClientWebSocketPool(new DashScopeOptions())); return (client, handler); } + + // IClientWebSocket is internal, use InternalVisibleToAttribute make it visible to Cnblogs.DashScope.Sdk.UnitTests + internal static async + Task<(DashScopeClientCore Client, DashScopeClientWebSocket ClientWebSocket, FakeClientWebSocket Server)> + GetSocketTestClientAsync() + where TOutput : class + { + var socket = new FakeClientWebSocket(); + var dsWebSocket = new DashScopeClientWebSocket(socket); + await dsWebSocket.ConnectAsync( + new Uri(DashScopeDefaults.WebsocketApiBaseAddress), + CancellationToken.None); + var pool = new DashScopeClientWebSocketPool(new List { dsWebSocket }); + var client = new DashScopeClientCore(new HttpClient(), pool); + return (client, dsWebSocket, socket); + } }