diff --git a/src/CredentialManagement/Credential.cs b/src/CredentialManagement/Credential.cs index 964f42c147..170bd479d0 100644 --- a/src/CredentialManagement/Credential.cs +++ b/src/CredentialManagement/Credential.cs @@ -60,6 +60,28 @@ public Credential( _lastWriteTime = DateTime.MinValue; } + public static Credential Load(string key) + { + var result = new Credential(); + result.Target = key; + result.Type = CredentialType.Generic; + return result.Load() ? result : null; + } + + public static void Save(string key, string username, string password) + { + var result = new Credential(username, password, key); + result.Save(); + } + + public static void Delete(string key) + { + var result = new Credential(); + result.Target = key; + result.Type = CredentialType.Generic; + result.Delete(); + } + bool disposed; void Dispose(bool disposing) { diff --git a/src/GitHub.Api/WindowsKeychain.cs b/src/GitHub.Api/WindowsKeychain.cs index 5597831f03..5e44d21f2a 100644 --- a/src/GitHub.Api/WindowsKeychain.cs +++ b/src/GitHub.Api/WindowsKeychain.cs @@ -19,17 +19,37 @@ public Task> Load(HostAddress hostAddress) { Guard.ArgumentNotNull(hostAddress, nameof(hostAddress)); + var key = GetKey(hostAddress.CredentialCacheKeyHost); + var keyGit = GetKeyGit(hostAddress.CredentialCacheKeyHost); var keyHost = GetKeyHost(hostAddress.CredentialCacheKeyHost); - - using (var credential = new Credential()) + Tuple result = null; + + // Visual Studio requires two credentials, keyed as "{hostname}" (e.g. "https://github.com/") and + // "git:{hostname}" (e.g. "git:https://github.com"). We have a problem in that these credentials can + // potentially be overwritten by other applications, so we store an extra "master" key as + // "GitHub for Visual Studio - {hostname}". Whenever we read the credentials we overwrite the other + // two keys with the value from the master key. Older versions of GHfVS did not store this master key + // so if it does not exist, try to get the value from the "{hostname}" key. + using (var credential = Credential.Load(key)) + using (var credentialGit = Credential.Load(keyGit)) + using (var credentialHost = Credential.Load(keyHost)) { - credential.Target = keyHost; - credential.Type = CredentialType.Generic; - if (credential.Load()) - return Task.FromResult(Tuple.Create(credential.Username, credential.Password)); + if (credential != null) + { + result = Tuple.Create(credential.Username, credential.Password); + } + else if (credentialHost != null) + { + result = Tuple.Create(credentialHost.Username, credentialHost.Password); + } + + if (result != null) + { + Save(result.Item1, result.Item2, hostAddress); + } } - return Task.FromResult>(null); + return Task.FromResult(result); } /// @@ -39,18 +59,13 @@ public Task Save(string userName, string password, HostAddress hostAddress) Guard.ArgumentNotEmptyString(password, nameof(password)); Guard.ArgumentNotNull(hostAddress, nameof(hostAddress)); + var key = GetKey(hostAddress.CredentialCacheKeyHost); var keyGit = GetKeyGit(hostAddress.CredentialCacheKeyHost); var keyHost = GetKeyHost(hostAddress.CredentialCacheKeyHost); - using (var credential = new Credential(userName, password, keyGit)) - { - credential.Save(); - } - - using (var credential = new Credential(userName, password, keyHost)) - { - credential.Save(); - } + Credential.Save(key, userName, password); + Credential.Save(keyGit, userName, password); + Credential.Save(keyHost, userName, password); return Task.CompletedTask; } @@ -60,26 +75,27 @@ public Task Delete(HostAddress hostAddress) { Guard.ArgumentNotNull(hostAddress, nameof(hostAddress)); + var key = GetKey(hostAddress.CredentialCacheKeyHost); var keyGit = GetKeyGit(hostAddress.CredentialCacheKeyHost); var keyHost = GetKeyHost(hostAddress.CredentialCacheKeyHost); - using (var credential = new Credential()) - { - credential.Target = keyGit; - credential.Type = CredentialType.Generic; - credential.Delete(); - } - - using (var credential = new Credential()) - { - credential.Target = keyHost; - credential.Type = CredentialType.Generic; - credential.Delete(); - } + Credential.Delete(key); + Credential.Delete(keyGit); + Credential.Delete(keyHost); return Task.CompletedTask; } + static string GetKey(string key) + { + key = FormatKey(key); + if (key.StartsWith("git:", StringComparison.Ordinal)) + key = key.Substring("git:".Length); + if (!key.EndsWith("/", StringComparison.Ordinal)) + key += '/'; + return "GitHub for Visual Studio - " + key; + } + static string GetKeyGit(string key) { key = FormatKey(key);