8000 Updates to examples for v2.2.1 and Apple Silicon · dotnet/TorchSharpExamples@1ed13f9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1ed13f9

Browse files
Updates to examples for v2.2.1 and Apple Silicon
1 parent f0a1deb commit 1ed13f9

File tree

16 files changed

+57
-22
lines changed

16 files changed

+57
-22
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,13 @@ healthchecksdb
349349
# Backup folder for Package Reference Convert tool in Visual Studio 2017
350350
MigrationBackup/
351351

352+
Downloads/
353+
runs/
354+
352355
# Ionide (cross platform F# VS Code tools) working folder
353356
.ionide/
354357

355358
*.dat.x
356359
*.dat.y
360+
361+
nuget.config

src/CSharp/CSharpExamples/CIFAR10.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Linq;
55
using System.Collections.Generic;
66
using System.Diagnostics;
7+
using System.Runtime.InteropServices;
78

89
using TorchSharp;
910
using static TorchSharp.torchvision;
@@ -47,9 +48,10 @@ internal static void Run(int epochs, int timeout, string logdir, string modelNam
4748
// This worked on a GeForce RTX 2080 SUPER with 8GB, for all the available network architectures.
4849
// It may not fit with less memory than that, but it's worth modifying the batch size to fit in memory.
4950
torch.cuda.is_available() ? torch.CUDA :
51+
torch.mps_is_available() ? torch.MPS :
5052
torch.CPU;
5153

52-
if (device.type == DeviceType.CUDA)
54+
if (device.type != DeviceType.CPU)
5355
{
5456
_trainBatchSize *= 8;
5557
_testBatchSize *= 8;

src/CSharp/CSharpExamples/CSharpExamples.csproj

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
<OutputType>Exe</OutputType>
55
<TargetFramework>net6.0</TargetFramework>
66
<StartupObject>CSharpExamples.Program</StartupObject>
7-
<PlatformTarget>x64</PlatformTarget>
87
</PropertyGroup>
98

109
<ItemGroup>
@@ -18,8 +17,8 @@
1817
</ItemGroup>
1918

2019
<ItemGroup>
21-
<PackageReference Include="TorchSharp-cpu" Version="0.100.5" />
22-
<PackageReference Include="TorchVision" Version="0.100.5" />
20+
<PackageReference Include="TorchSharp-cpu" Version="0.102.0" />
21+
<PackageReference Include="TorchVision" Version="0.102.0" />
2322
</ItemGroup>
2423

2524
<ItemGroup>

src/CSharp/CSharpExamples/MNIST.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ internal static void Run(int epochs, int timeout, string logdir, string dataset)
5252
dataset = "mnist";
5353
}
5454

55-
var device = cuda.is_available() ? CUDA : CPU;
55+
var device =
56+
torch.cuda.is_available() ? torch.CUDA :
57+
torch.mps_is_available() ? torch.MPS :
58+
torch.CPU;
5659

5760
Console.WriteLine();
5861
Console.WriteLine($"\tRunning MNIST with {dataset} on {device.type.ToString()} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");

src/CSharp/CSharpExamples/SequenceToSequence.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ internal static void Run(int epochs, int timeout, string logdir)
5151

5252
var cwd = Environment.CurrentDirectory;
5353

54-
var device = torch.cuda.is_available() ? torch.CUDA : torch.CPU;
54+
var device =
55+
torch.cuda.is_available() ? torch.CUDA :
56+
torch.mps_is_available() ? torch.MPS :
57+
torch.CPU;
5558

5659
Console.WriteLine();
5760
Console.WriteLine($"\tRunning SequenceToSequence on {device.type.ToString()} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");

src/CSharp/CSharpExamples/TextClassification.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ internal static void Run(int epochs, int timeout, string logdir)
4848

4949
var cwd = Environment.CurrentDirectory;
5050

51-
var device = torch.cuda.is_available() ? torch.CUDA : torch.CPU;
51+
var device =
52+
torch.cuda.is_available() ? torch.CUDA :
53+
torch.mps_is_available() ? torch.MPS :
54+
torch.CPU;
55+
5256
Console.WriteLine();
5357
Console.WriteLine($"\tRunning TextClassification on {device.type.ToString()} for {epochs} epochs, terminating after {TimeSpan.FromSeconds(timeout)}.");
5458
Console.WriteLine();

src/CSharp/Models/AlexNet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public AlexNet(string name, int numClasses, Device device = null) : base(name)
5353

5454
RegisterComponents();
5555

56-
if (device != null && device.type == DeviceType.CUDA)
56+
if (device != null && device.type != DeviceType.CPU)
5757
this.to(device);
5858
}
5959

src/CSharp/Models/MNIST.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public Model(string name, torch.Device device = null) : base(name)
3535
{
3636
RegisterComponents();
3737

38-
if (device != null && device.type == DeviceType.CUDA)
38+
if (device != null && device.type != DeviceType.CPU)
3939
this.to(device);
4040
}
4141

src/CSharp/Models/MobileNet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public MobileNet(string name, int numClasses, Device device = null) : base(name)
4343

4444
RegisterComponents();
4545

46-
if (device != null && device.type == DeviceType.CUDA)
46+
if (device != null && device.type != DeviceType.CPU)
4747
this.to(device);
4848
}
4949

src/CSharp/Models/Models.csproj

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
<PropertyGroup>
44
<TargetFramework>net6.0</TargetFramework>
5-
<PlatformTarget>x64</PlatformTarget>
65
</PropertyGroup>
76

87
<ItemGroup>
9-
<PackageReference Include="TorchSharp-cpu" Version="0.100.5" />
10-
<PackageReference Include="TorchVision" Version="0.100.5" />
8+
<PackageReference Include="TorchSharp-cpu" Version="0.102.0" />
9+
<PackageReference Include="TorchVision" Version="0.102.0" />
1110
</ItemGroup>
1211

1312
</Project>

0 commit comments

Comments
 (0)
0