diff --git a/src/GitHub.Api/Caching/FileCache.cs b/src/GitHub.Api/Caching/FileCache.cs new file mode 100644 index 0000000000..dd5f937e61 --- /dev/null +++ b/src/GitHub.Api/Caching/FileCache.cs @@ -0,0 +1,1295 @@ +/* +Copyright 2012, 2013, 2017 Adam Carter (http://adam-carter.com) + +This file is part of FileCache (http://github.com/acarteas/FileCache). + +FileCache is distributed under the Apache License 2.0. +Consult "LICENSE.txt" included in this package for the Apache License 2.0. +*/ +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.Serialization; +using System.Runtime.Serialization.Formatters.Binary; +using System.Threading; +using System.Threading.Tasks; + +namespace System.Runtime.Caching +{ + public class FileCache : ObjectCache + { + private static int _nameCounter = 1; + private string _name = ""; + private SerializationBinder _binder; + private string _cacheSubFolder = "cache"; + private string _policySubFolder = "policy"; + private TimeSpan _cleanInterval = new TimeSpan(7, 0, 0, 0); // default to 1 week + private const string LastCleanedDateFile = "cache.lcd"; + private const string CacheSizeFile = "cache.size"; + // this is a file used to prevent multiple processes from trying to "clean" at the same time + private const string SemaphoreFile = "cache.sem"; + private long _currentCacheSize = 0; + private PayloadMode _readMode = PayloadMode.Serializable; + public string CacheDir { get; protected set; } + + + /// + /// Used to store the default region when accessing the cache via [] calls + /// + public string DefaultRegion { get; set; } + + /// + /// Used to set the default policy when setting cache values via [] calls + /// + public CacheItemPolicy DefaultPolicy { get; set; } + + /// + /// Specified how the cache payload is to be handled. + /// + public enum PayloadMode + { + /// + /// Treat the payload a a serializable object. + /// + Serializable, + /// + /// Treat the payload as a file name. File content will be copied on add, while get returns the file name. + /// + Filename, + /// + /// Treat the paylad as raw bytes. A byte[] and readable streams are supported on add. + /// + RawBytes + } + + /// + /// Specified whether the payload is deserialized or just the file name. + /// + public PayloadMode PayloadReadMode + { + get => _readMode; + set + { + if (value == PayloadMode.RawBytes) + { + throw new ArgumentException("The read mode cannot be set to RawBytes. Use the file name please."); + } + _readMode = value; + } + } + + /// + /// Specified how the payload is to be handled on add operations. + /// + public PayloadMode PayloadWriteMode { get; set; } = PayloadMode.Serializable; + + /// + /// The amount of time before expiry that a filename will be used as a payoad. I.e. + /// the amount of time the cache's user can safely use the file delivered as a payload. + /// Default 10 minutes. + /// + public TimeSpan FilenameAsPayloadSafetyMargin = TimeSpan.FromMinutes(10); + + /// + /// Used to determine how long the FileCache will wait for a file to become + /// available. Default (00:00:00) is indefinite. Should the timeout be + /// reached, an exception will be thrown. + /// + public TimeSpan AccessTimeout { get; set; } + + /// + /// Used to specify the disk size, in bytes, that can be used by the File Cache + /// + public long MaxCacheSize { get; set; } + + /// + /// Returns the approximate size of the file cache + /// + public long CurrentCacheSize + { + get + { + // if this is the first query, we need to load the cache size from somewhere + if (_currentCacheSize == 0) + { + // Read the system file for cache size + object cacheSizeObj = ReadSysFile(CacheSizeFile); + + // Did we successfully get data from the file? + if (cacheSizeObj != null) + { + _currentCacheSize = (long)cacheSizeObj; + } + } + + return _currentCacheSize; + } + private set + { + // no need to do a pointless re-store of the same value + if (_currentCacheSize != value || value == 0) + { + WriteSysFile(CacheSizeFile, value); + _currentCacheSize = value; + } + } + } + + /// + /// Event that will be called when is reached. + /// + public event EventHandler MaxCacheSizeReached = delegate { }; + + public event EventHandler CacheResized = delegate { }; + + /// + /// The default cache path used by FC. + /// + private string DefaultCachePath + { + get + { + return Directory.GetCurrentDirectory(); + } + } + + #region constructors + + /// + /// Creates a default instance of the file cache. Don't use if you plan to serialize custom objects + /// + /// If true, will calcualte the cache's current size upon new object creation. + /// Turned off by default as directory traversal is somewhat expensive and may not always be necessary based on + /// use case. + /// + /// If supplied, sets the interval of time that must occur between self cleans + public FileCache(bool calculateCacheSize = false, TimeSpan cleanInterval = new TimeSpan()) + { + // CT note: I moved this code to an init method because if the user specified a cache root, that needs to + // be set before checking if we should clean (otherwise it will look for the file in the wrong place) + Init(calculateCacheSize, cleanInterval); + } + + /// + /// Creates an instance of the file cache using the supplied path as the root save path. + /// + /// The cache's root file path + /// If true, will calcualte the cache's current size upon new object creation. + /// Turned off by default as directory traversal is somewhat expensive and may not always be necessary based on + /// use case. + /// + /// If supplied, sets the interval of time that must occur between self cleans + public FileCache(string cacheRoot, bool calculateCacheSize = false, TimeSpan cleanInterval = new TimeSpan()) + { + CacheDir = cacheRoot; + Init(calculateCacheSize, cleanInterval, false); + } + + /// + /// Creates an instance of the file cache. + /// + /// The SerializationBinder used to deserialize cached objects. Needed if you plan + /// to cache custom objects. + /// + /// If true, will calcualte the cache's current size upon new object creation. + /// Turned off by default as directory traversal is somewhat expensive and may not always be necessary based on + /// use case. + /// + /// If supplied, sets the interval of time that must occur between self cleans + public FileCache(SerializationBinder binder, bool calculateCacheSize = false, TimeSpan cleanInterval = new TimeSpan()) + { + _binder = binder; + Init(calculateCacheSize, cleanInterval, true, false); + } + + /// + /// Creates an instance of the file cache. + /// + /// The cache's root file path + /// The SerializationBinder used to deserialize cached objects. Needed if you plan + /// to cache custom objects. + /// If true, will calcualte the cache's current size upon new object creation. + /// Turned off by default as directory traversal is somewhat expensive and may not always be necessary based on + /// use case. + /// + /// If supplied, sets the interval of time that must occur between self cleans + public FileCache(string cacheRoot, SerializationBinder binder, bool calculateCacheSize = false, TimeSpan cleanInterval = new TimeSpan()) + { + _binder = binder; + CacheDir = cacheRoot; + Init(calculateCacheSize, cleanInterval, false, false); + } + + #endregion + + #region custom methods + + private void Init(bool calculateCacheSize = false, TimeSpan cleanInterval = new TimeSpan(), bool setCacheDirToDefault = true, bool setBinderToDefault = true) + { + _name = "FileCache_" + _nameCounter; + _nameCounter++; + + DefaultRegion = null; + DefaultPolicy = new CacheItemPolicy(); + AccessTimeout = new TimeSpan(); + MaxCacheSize = long.MaxValue; + + // set default values if not already set + if (setCacheDirToDefault) + CacheDir = DefaultCachePath; + if (setBinderToDefault) + _binder = new FileCacheBinder(); + + // if it doesn't exist, we need to make it + if (!Directory.Exists(CacheDir)) + Directory.CreateDirectory(CacheDir); + + // only set the clean interval if the user supplied it + if (cleanInterval > new TimeSpan()) + { + _cleanInterval = cleanInterval; + } + + //check to see if cache is in need of immediate cleaning + if (ShouldClean()) + { + CleanCacheAsync(); + } + else if (calculateCacheSize || CurrentCacheSize == 0) + { + // This is in an else if block, because CleanCacheAsync will + // update the cache size, so no need to do it twice. + UpdateCacheSizeAsync(); + } + + MaxCacheSizeReached += FileCache_MaxCacheSizeReached; + } + + private void FileCache_MaxCacheSizeReached(object sender, FileCacheEventArgs e) + { + Task.Factory.StartNew((Action)(() => + { + // Shrink the cache to 75% of the max size + // that way there's room for it to grow a bit + // before we have to do this again. + long newSize = ShrinkCacheToSize((long)(MaxCacheSize * 0.75)); + })); + } + + + // Returns the cleanlock file if it can be opened, otherwise it is being used by another process so return null + private FileStream GetCleaningLock() + { + try + { + return File.Open(Path.Combine(CacheDir, SemaphoreFile), FileMode.OpenOrCreate, FileAccess.ReadWrite, FileShare.None); + } + catch (Exception) + { + return null; + } + } + + // Determines whether or not enough time has passed that the cache should clean itself + private bool ShouldClean() + { + try + { + // if the file can't be found, or is corrupt this will throw an exception + DateTime? lastClean = ReadSysFile(LastCleanedDateFile) as DateTime?; + + //AC: rewrote to be safer in null cases + if (lastClean == null) + { + return true; + } + + // return true if the amount of time between now and the last clean is greater than or equal to the + // clean interval, otherwise return false. + return DateTime.Now - lastClean >= _cleanInterval; + } + catch (Exception) + { + return true; + } + } + + /// + /// Shrinks the cache until the cache size is less than + /// or equal to the size specified (in bytes). This is a + /// rather expensive operation, so use with discretion. + /// + /// The new size of the cache + public long ShrinkCacheToSize(long newSize, string regionName = null) + { + long originalSize = 0, amount = 0, removed = 0; + + //lock down other treads from trying to shrink or clean + using (FileStream cLock = GetCleaningLock()) + { + if (cLock == null) + return -1; + + // if we're shrinking the whole cache, we can use the stored + // size if it's available. If it's not available we calculate it and store + // it for next time. + if (regionName == null) + { + if (CurrentCacheSize == 0) + { + CurrentCacheSize = GetCacheSize(); + } + + originalSize = CurrentCacheSize; + } + else + { + originalSize = GetCacheSize(regionName); + } + + // Find out how much we need to get rid of + amount = originalSize - newSize; + + // CT note: This will update CurrentCacheSize + removed = DeleteOldestFiles(amount, regionName); + + // unlock the semaphore for others + cLock.Close(); + } + + // trigger the event + CacheResized(this, new FileCacheEventArgs(originalSize - removed, MaxCacheSize)); + + // return the final size of the cache (or region) + return originalSize - removed; + } + + public void CleanCacheAsync() + { + Task.Factory.StartNew((Action)(() => + { + CleanCache(); + })); + } + + /// + /// Loop through the cache and delete all expired files + /// + /// The amount removed (in bytes) + public long CleanCache(string regionName = null) + { + long removed = 0; + + //lock down other treads from trying to shrink or clean + using (FileStream cLock = GetCleaningLock()) + { + if (cLock == null) + return 0; + + foreach (string key in GetKeys(regionName)) + { + CacheItemPolicy policy = GetPolicy(key, regionName); + if (policy.AbsoluteExpiration < DateTime.Now) + { + try + { + string cachePath = GetCachePath(key, regionName); + string policyPath = GetPolicyPath(key, regionName); + CacheItemReference ci = new CacheItemReference(key, cachePath, policyPath); + Remove(key, regionName); // CT note: Remove will update CurrentCacheSize + removed += ci.Length; + } + catch (Exception) // skip if the file cannot be accessed + { } + } + } + + // mark that we've cleaned the cache + WriteSysFile(LastCleanedDateFile, DateTime.Now); + + // unlock + cLock.Close(); + } + + return removed; + } + + public void ClearRegion(string regionName) + { + using (var cLock = GetCleaningLock()) + { + if (cLock == null) + return; + + foreach (var key in GetKeys(regionName)) + { + Remove(key, regionName); + } + + cLock.Close(); + } + } + + /// + /// Delete the oldest items in the cache to shrink the chache by the + /// specified amount (in bytes). + /// + /// The amount of data that was actually removed + private long DeleteOldestFiles(long amount, string regionName = null) + { + // Verify that we actually need to shrink + if (amount <= 0) + { + return 0; + } + + //Heap of all CacheReferences + PriortyQueue cacheReferences = new PriortyQueue(); + + //build a heap of all files in cache region + foreach (string key in GetKeys(regionName)) + { + try + { + //build item reference + string cachePath = GetCachePath(key, regionName); + string policyPath = GetPolicyPath(key, regionName); + CacheItemReference ci = new CacheItemReference(key, cachePath, policyPath); + cacheReferences.Enqueue(ci); + } + catch (FileNotFoundException) + { + } + } + + //remove cache items until size requirement is met + long removedBytes = 0; + while (removedBytes < amount && cacheReferences.GetSize() > 0) + { + //remove oldest item + CacheItemReference oldest = cacheReferences.Dequeue(); + removedBytes += oldest.Length; + Remove(oldest.Key, regionName); + } + return removedBytes; + } + + /// + /// This method calls GetCacheSize on a separate thread to + /// calculate and then store the size of the cache. + /// + public void UpdateCacheSizeAsync() + { + Task.Factory.StartNew((Action)(() => + { + CurrentCacheSize = GetCacheSize(); + })); + } + + //AC Note: From MSDN / SO (http://stackoverflow.com/questions/468119/whats-the-best-way-to-calculate-the-size-of-a-directory-in-net) + /// + /// Calculates the size, in bytes of the file cache + /// + /// The region to calculate. If NULL, will return total size. + /// + public long GetCacheSize(string regionName = null) + { + long size = 0; + + //AC note: First parameter is unused, so just pass in garbage ("DummyValue") + string policyPath = Path.GetDirectoryName(GetPolicyPath("DummyValue", regionName)); + string cachePath = Path.GetDirectoryName(GetCachePath("DummyValue", regionName)); + size += CacheSizeHelper(new DirectoryInfo(policyPath)); + size += CacheSizeHelper(new DirectoryInfo(cachePath)); + return size; + } + + /// + /// Helper method for public . + /// + /// + /// + private long CacheSizeHelper(DirectoryInfo root) + { + long size = 0; + + // Add file sizes. + var fis = root.EnumerateFiles(); + foreach (FileInfo fi in fis) + { + size += fi.Length; + } + // Add subdirectory sizes. + var dis = root.EnumerateDirectories(); + foreach (DirectoryInfo di in dis) + { + size += CacheSizeHelper(di); + } + return size; + } + + /// + /// Flushes the file cache using DateTime.Now as the minimum date + /// + /// + public void Flush(string regionName = null) + { + Flush(DateTime.Now, regionName); + } + + /// + /// Flushes the cache based on last access date, filtered by optional region + /// + /// + /// + public void Flush(DateTime minDate, string regionName = null) + { + // prevent other threads from altering stuff while we delete junk + using (FileStream cLock = GetCleaningLock()) + { + if (cLock == null) + return; + + //AC note: First parameter is unused, so just pass in garbage ("DummyValue") + string policyPath = Path.GetDirectoryName(GetPolicyPath("DummyValue", regionName)); + string cachePath = Path.GetDirectoryName(GetCachePath("DummyValue", regionName)); + FlushHelper(new DirectoryInfo(policyPath), minDate); + FlushHelper(new DirectoryInfo(cachePath), minDate); + + // Update the Cache size + CurrentCacheSize = GetCacheSize(); + + // unlock + cLock.Close(); + } + } + + /// + /// Helper method for public flush + /// + /// + /// + private void FlushHelper(DirectoryInfo root, DateTime minDate) + { + // check files. + foreach (FileInfo fi in root.EnumerateFiles()) + { + //is the file stale? + if (minDate > File.GetLastAccessTime(fi.FullName)) + { + File.Delete(fi.FullName); + } + } + + // check subdirectories + foreach (DirectoryInfo di in root.EnumerateDirectories()) + { + FlushHelper(di, minDate); + } + } + + /// + /// Returns the policy attached to a given cache item. + /// + /// The key of the item + /// The region in which the key exists + /// + public CacheItemPolicy GetPolicy(string key, string regionName = null) + { + CacheItemPolicy policy = new CacheItemPolicy(); + FileCachePayload payload = ReadFile(PayloadMode.Filename, key, regionName) as FileCachePayload; + if (payload != null) + { + try + { + policy.SlidingExpiration = payload.Policy.SlidingExpiration; + policy.AbsoluteExpiration = payload.Policy.AbsoluteExpiration; + } + catch (Exception) + { + } + } + return policy; + } + + /// + /// Returns a list of keys for a given region. + /// + /// + /// + public IEnumerable GetKeys(string regionName = null) + { + string region = ""; + if (string.IsNullOrEmpty(regionName) == false) + { + region = regionName; + } + string directory = Path.Combine(CacheDir, _cacheSubFolder, region); + if (Directory.Exists(directory)) + { + foreach (string file in Directory.EnumerateFiles(directory)) + { + yield return Path.GetFileNameWithoutExtension(file); + } + } + } + + #endregion + + #region helper methods + + /// + /// This function servies to centralize file stream access within this class. + /// + /// + /// + /// + /// + /// + private FileStream GetStream(string path, FileMode mode, FileAccess access, FileShare share) + { + FileStream stream = null; + TimeSpan interval = new TimeSpan(0, 0, 0, 0, 50); + TimeSpan totalTime = new TimeSpan(); + while (stream == null) + { + try + { + stream = File.Open(path, mode, access, share); + } + catch (IOException ex) + { + Thread.Sleep(interval); + totalTime += interval; + + //if we've waited too long, throw the original exception. + if (AccessTimeout.Ticks != 0) + { + if (totalTime > AccessTimeout) + { + throw ex; + } + } + } + } + return stream; + } + + /// + /// This function serves to centralize file reads within this class. + /// + /// the payload reading mode + /// + /// + /// + private FileCachePayload ReadFile(PayloadMode mode, string key, string regionName = null, SerializationBinder objectBinder = null) + { + object data = null; + SerializableCacheItemPolicy policy = new SerializableCacheItemPolicy(); + string cachePath = GetCachePath(key, regionName); + string policyPath = GetPolicyPath(key, regionName); + FileCachePayload payload = new FileCachePayload(null); + + if (File.Exists(cachePath)) + { + switch (mode) + { + default: + case PayloadMode.Filename: + data = cachePath; + break; + case PayloadMode.Serializable: + data = DeserializePayloadData(objectBinder, cachePath); + break; + case PayloadMode.RawBytes: + data = LoadRawPayloadData(cachePath); + break; + } + } + if (File.Exists(policyPath)) + { + using (FileStream stream = GetStream(policyPath, FileMode.Open, FileAccess.Read, FileShare.Read)) + { + BinaryFormatter formatter = new BinaryFormatter(); + formatter.Binder = new LocalCacheBinder(); + try + { + policy = formatter.Deserialize(stream) as SerializableCacheItemPolicy; + } + catch (SerializationException) + { + policy = new SerializableCacheItemPolicy(); + } + } + } + payload.Payload = data; + payload.Policy = policy; + return payload; + } + + private object LoadRawPayloadData(string cachePath) + { + throw new NotSupportedException("Reading raw payload is not currently supported."); + } + + private object DeserializePayloadData(SerializationBinder objectBinder, string cachePath) + { + object data; + using (FileStream stream = GetStream(cachePath, FileMode.Open, FileAccess.Read, FileShare.Read)) + { + BinaryFormatter formatter = new BinaryFormatter(); + + //AC: From http://spazzarama.com//2009/06/25/binary-deserialize-unable-to-find-assembly/ + // Needed to deserialize custom objects + if (objectBinder != null) + { + //take supplied binder over default binder + formatter.Binder = objectBinder; + } + else if (_binder != null) + { + formatter.Binder = _binder; + } + + try + { + data = formatter.Deserialize(stream); + } + catch (SerializationException) + { + data = null; + } + } + + return data; + } + + /// + /// This function serves to centralize file writes within this class + /// + private void WriteFile(PayloadMode mode, string key, FileCachePayload data, string regionName = null, bool policyUpdateOnly = false) + { + string cachedPolicy = GetPolicyPath(key, regionName); + string cachedItemPath = GetCachePath(key, regionName); + + + if (!policyUpdateOnly) + { + long oldBlobSize = 0; + if (File.Exists(cachedItemPath)) + { + oldBlobSize = new FileInfo(cachedItemPath).Length; + } + + switch (mode) + { + case PayloadMode.Serializable: + using (FileStream stream = GetStream(cachedItemPath, FileMode.Create, FileAccess.Write, FileShare.None)) + { + + BinaryFormatter formatter = new BinaryFormatter(); + formatter.Serialize(stream, data.Payload); + } + break; + case PayloadMode.RawBytes: + using (FileStream stream = GetStream(cachedItemPath, FileMode.Create, FileAccess.Write, FileShare.None)) + { + + if (data.Payload is byte[]) + { + byte[] dataPayload = (byte[])data.Payload; + stream.Write(dataPayload, 0, dataPayload.Length); + } + else if (data.Payload is Stream) + { + Stream dataPayload = (Stream)data.Payload; + dataPayload.CopyTo(stream); + // no close or the like, we are not the owner + } + } + break; + + case PayloadMode.Filename: + File.Copy((string)data.Payload, cachedItemPath, true); + break; + } + + //adjust cache size (while we have the file to ourselves) + CurrentCacheSize += new FileInfo(cachedItemPath).Length - oldBlobSize; + } + + //remove current policy file from cache size calculations + if (File.Exists(cachedPolicy)) + { + CurrentCacheSize -= new FileInfo(cachedPolicy).Length; + } + + //write the cache policy + using (FileStream stream = GetStream(cachedPolicy, FileMode.Create, FileAccess.Write, FileShare.None)) + { + BinaryFormatter formatter = new BinaryFormatter(); + formatter.Serialize(stream, data.Policy); + + // adjust cache size + CurrentCacheSize += new FileInfo(cachedPolicy).Length; + + stream.Close(); + } + + //check to see if limit was reached + if (CurrentCacheSize > MaxCacheSize) + { + MaxCacheSizeReached(this, new FileCacheEventArgs(CurrentCacheSize, MaxCacheSize)); + } + } + + /// + /// Reads data in from a system file. System files are not part of the + /// cache itself, but serve as a way for the cache to store data it + /// needs to operate. + /// + /// The name of the sysfile (without directory) + /// The data from the file + private object ReadSysFile(string filename) + { + // sys files go in the root directory + string path = Path.Combine(CacheDir, filename); + object data = null; + + if (File.Exists(path)) + { + for (int i = 5; i > 0; i--) // try 5 times to read the file, if we can't, give up + { + try + { + using (FileStream stream = GetStream(path, FileMode.Open, FileAccess.Read, FileShare.Read)) + { + BinaryFormatter formatter = new BinaryFormatter(); + try + { + data = formatter.Deserialize(stream); + } + catch (Exception) + { + data = null; + } + finally + { + stream.Close(); + } + } + break; + } + catch (IOException) + { + // we timed out... so try again + } + } + } + + return data; + } + + /// + /// Writes data to a system file that is not part of the cache itself, + /// but is used to help it function. + /// + /// The name of the sysfile (without directory) + /// The data to write to the file + private void WriteSysFile(string filename, object data) + { + // sys files go in the root directory + string path = Path.Combine(CacheDir, filename); + + // write the data to the file + using (FileStream stream = GetStream(path, FileMode.Create, FileAccess.Write, FileShare.Write)) + { + BinaryFormatter formatter = new BinaryFormatter(); + formatter.Serialize(stream, data); + stream.Close(); + } + } + + /// + /// Builds a string that will place the specified file name within the appropriate + /// cache and workspace folder. + /// + /// + /// + /// + private string GetCachePath(string FileName, string regionName = null) + { + if (regionName == null) + { + regionName = ""; + } + string directory = Path.Combine(CacheDir, _cacheSubFolder, regionName); + string filePath = Path.Combine(directory, Path.GetFileNameWithoutExtension(FileName) + ".dat"); + if (!Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + } + return filePath; + } + + /// + /// Builds a string that will get the path to the supplied file's policy file + /// + /// + /// + /// + private string GetPolicyPath(string FileName, string regionName = null) + { + if (regionName == null) + { + regionName = ""; + } + string directory = Path.Combine(CacheDir, _policySubFolder, regionName); + string filePath = Path.Combine(directory, Path.GetFileNameWithoutExtension(FileName) + ".policy"); + if (!Directory.Exists(directory)) + { + Directory.CreateDirectory(directory); + } + return filePath; + } + + #endregion + + #region ObjectCache overrides + + public override object AddOrGetExisting(string key, object value, CacheItemPolicy policy, string regionName = null) + { + string path = GetCachePath(key, regionName); + object oldData = null; + + //pull old value if it exists + if (File.Exists(path)) + { + try + { + oldData = Get(key, regionName); + } + catch (Exception) + { + oldData = null; + } + } + SerializableCacheItemPolicy cachePolicy = new SerializableCacheItemPolicy(policy); + FileCachePayload newPayload = new FileCachePayload(value, cachePolicy); + WriteFile(PayloadWriteMode, key, newPayload, regionName); + + //As documented in the spec (http://msdn.microsoft.com/en-us/library/dd780602.aspx), return the old + //cached value or null + return oldData; + } + + public override CacheItem AddOrGetExisting(CacheItem value, CacheItemPolicy policy) + { + object oldData = AddOrGetExisting(value.Key, value.Value, policy, value.RegionName); + CacheItem returnItem = null; + if (oldData != null) + { + returnItem = new CacheItem(value.Key) + { + Value = oldData, + RegionName = value.RegionName + }; + } + return returnItem; + } + + public override object AddOrGetExisting(string key, object value, DateTimeOffset absoluteExpiration, string regionName = null) + { + CacheItemPolicy policy = new CacheItemPolicy(); + policy.AbsoluteExpiration = absoluteExpiration; + return AddOrGetExisting(key, value, policy, regionName); + } + + public override bool Contains(string key, string regionName = null) + { + string path = GetCachePath(key, regionName); + return File.Exists(path); + } + + public override CacheEntryChangeMonitor CreateCacheEntryChangeMonitor(IEnumerable keys, string regionName = null) + { + throw new NotImplementedException(); + } + + public override DefaultCacheCapabilities DefaultCacheCapabilities + { + get + { + //AC note: can use boolean OR "|" to set multiple flags. + return DefaultCacheCapabilities.CacheRegions + | + DefaultCacheCapabilities.AbsoluteExpirations + | + DefaultCacheCapabilities.SlidingExpirations + ; + } + } + + public override object Get(string key, string regionName = null) + { + FileCachePayload payload = ReadFile(PayloadReadMode, key, regionName) as FileCachePayload; + string cachedItemPath = GetCachePath(key, regionName); + + DateTime cutoff = DateTime.Now; + if (PayloadReadMode == PayloadMode.Filename) + { + cutoff += FilenameAsPayloadSafetyMargin; + } + + //null payload? + if (payload != null) + { + //did the item expire? + if (payload.Policy.AbsoluteExpiration < cutoff) + { + //set the payload to null + payload.Payload = null; + + //delete the file from the cache + try + { + // CT Note: I changed this to Remove from File.Delete so that the coresponding + // policy file will be deleted as well, and CurrentCacheSize will be updated. + Remove(key, regionName); + } + catch (Exception) + { + } + } + else + { + //does the item have a sliding expiration? + if (payload.Policy.SlidingExpiration > new TimeSpan()) + { + payload.Policy.AbsoluteExpiration = DateTime.Now.Add(payload.Policy.SlidingExpiration); + WriteFile(PayloadWriteMode, cachedItemPath, payload, regionName, true); + } + + } + } + else + { + //remove null payload + Remove(key, regionName); + + //create dummy one for return + payload = new FileCachePayload(null); + } + return payload.Payload; + } + + public override CacheItem GetCacheItem(string key, string regionName = null) + { + object value = Get(key, regionName); + CacheItem item = new CacheItem(key); + item.Value = value; + item.RegionName = regionName; + return item; + } + + public override long GetCount(string regionName = null) + { + if (regionName == null) + { + regionName = ""; + } + string path = Path.Combine(CacheDir, _cacheSubFolder, regionName); + if (Directory.Exists(path)) + return Directory.GetFiles(path).Count(); + else + return 0; + } + + /// + /// Returns an enumerator for the specified region (defaults to base-level cache directory). + /// This function *WILL NOT* recursively locate files in subdirectories. + /// + /// + /// + public IEnumerator> GetEnumerator(string regionName = null) + { + string region = ""; + if (string.IsNullOrEmpty(regionName) == false) + { + region = regionName; + } + List> enumerator = new List>(); + + string directory = Path.Combine(CacheDir, _cacheSubFolder, region); + foreach (string filePath in Directory.EnumerateFiles(directory)) + { + string key = Path.GetFileNameWithoutExtension(filePath); + enumerator.Add(new KeyValuePair(key, this.Get(key, regionName))); + } + return enumerator.GetEnumerator(); + } + + /// + /// Will return an enumerator with all cache items listed in the root file path ONLY. Use the other + /// if you want to specify a region + /// + /// + protected override IEnumerator> GetEnumerator() + { + return GetEnumerator(null); + } + + public override IDictionary GetValues(IEnumerable keys, string regionName = null) + { + Dictionary values = new Dictionary(); + foreach (string key in keys) + { + values[key] = Get(key, regionName); + } + return values; + } + + public override string Name + { + get { return _name; } + } + + public override object Remove(string key, string regionName = null) + { + object valueToDelete = null; + if (Contains(key, regionName)) + { + // Because of the possibility of multiple threads accessing this, it's possible that + // while we're trying to remove something, another thread has already removed it. + try + { + //remove cache entry + // CT note: calling Get from remove leads to an infinite loop and stack overflow, + // so I replaced it with a simple ReadFile call. None of the code here actually + // uses this object returned, but just in case someone else's outside code does. + FileCachePayload fcp = ReadFile(PayloadMode.Filename, key, regionName); + valueToDelete = fcp.Payload; + string path = GetCachePath(key, regionName); + CurrentCacheSize -= new FileInfo(path).Length; + File.Delete(path); + + //remove policy file + string cachedPolicy = GetPolicyPath(key, regionName); + CurrentCacheSize -= new FileInfo(cachedPolicy).Length; + File.Delete(cachedPolicy); + } + catch (IOException) + { + } + + } + return valueToDelete; + } + + public override void Set(string key, object value, CacheItemPolicy policy, string regionName = null) + { + Add(key, value, policy, regionName); + } + + public override void Set(CacheItem item, CacheItemPolicy policy) + { + Add(item, policy); + } + + public override void Set(string key, object value, DateTimeOffset absoluteExpiration, string regionName = null) + { + Add(key, value, absoluteExpiration, regionName); + } + + public override object this[string key] + { + get + { + return this.Get(key, DefaultRegion); + } + set + { + this.Set(key, value, DefaultPolicy, DefaultRegion); + } + } + + #endregion + + private class LocalCacheBinder : System.Runtime.Serialization.SerializationBinder + { + public override Type BindToType(string assemblyName, string typeName) + { + Type typeToDeserialize = null; + + String currentAssembly = Assembly.GetAssembly(typeof(LocalCacheBinder)).FullName; + assemblyName = currentAssembly; + + // Get the type using the typeName and assemblyName + typeToDeserialize = Type.GetType(String.Format("{0}, {1}", + typeName, assemblyName)); + + return typeToDeserialize; + } + } + + // CT: This private class is used to help shrink the cache. + // It computes the total size of an entry including it's policy file. + // It also implements IComparable functionality to allow for sorting based on access time + private class CacheItemReference : IComparable + { + public readonly DateTime LastAccessTime; + public readonly long Length; + public readonly string Key; + + public CacheItemReference(string key, string cachePath, string policyPath) + { + Key = key; + FileInfo cfi = new FileInfo(cachePath); + FileInfo pfi = new FileInfo(policyPath); + cfi.Refresh(); + LastAccessTime = cfi.LastAccessTime; + Length = cfi.Length + pfi.Length; + } + + public int CompareTo(CacheItemReference other) + { + int i = LastAccessTime.CompareTo(other.LastAccessTime); + + // It's possible, although rare, that two different items will have + // the same LastAccessTime. So in that case, we need to check to see + // if they're actually the same. + if (i == 0) + { + // second order should be length (but from smallest to largest, + // that way we delete smaller files first) + i = -1 * Length.CompareTo(other.Length); + if (i == 0) + { + i = Key.CompareTo(other.Key); + } + } + + return i; + } + + public static bool operator >(CacheItemReference lhs, CacheItemReference rhs) + { + if (lhs.CompareTo(rhs) > 0) + { + return true; + } + return false; + } + + public static bool operator <(CacheItemReference lhs, CacheItemReference rhs) + { + if (lhs.CompareTo(rhs) < 0) + { + return true; + } + return false; + } + } + } +} \ No newline at end of file diff --git a/src/GitHub.Api/Caching/FileCacheBinder.cs b/src/GitHub.Api/Caching/FileCacheBinder.cs new file mode 100644 index 0000000000..dd03649603 --- /dev/null +++ b/src/GitHub.Api/Caching/FileCacheBinder.cs @@ -0,0 +1,34 @@ +/* +Copyright 2012, 2013, 2017 Adam Carter (http://adam-carter.com) + +This file is part of FileCache (http://github.com/acarteas/FileCache). + +FileCache is distributed under the Apache License 2.0. +Consult "LICENSE.txt" included in this package for the Apache License 2.0. +*/ +using System.Reflection; + +namespace System.Runtime.Caching +{ + /// + /// You should be able to copy & paste this code into your local project to enable caching custom objects. + /// + public sealed class FileCacheBinder : System.Runtime.Serialization.SerializationBinder + { + public override Type BindToType(string assemblyName, string typeName) + { + Type typeToDeserialize = null; + + String currentAssembly = Assembly.GetExecutingAssembly().FullName; + + // In this case we are always using the current assembly + assemblyName = currentAssembly; + + // Get the type using the typeName and assemblyName + typeToDeserialize = Type.GetType(String.Format("{0}, {1}", + typeName, assemblyName)); + + return typeToDeserialize; + } + } +} \ No newline at end of file diff --git a/src/GitHub.Api/Caching/FileCacheEventArgs.cs b/src/GitHub.Api/Caching/FileCacheEventArgs.cs new file mode 100644 index 0000000000..917ff89e95 --- /dev/null +++ b/src/GitHub.Api/Caching/FileCacheEventArgs.cs @@ -0,0 +1,22 @@ +/* +Copyright 2012, 2013, 2017 Adam Carter (http://adam-carter.com) + +This file is part of FileCache (http://github.com/acarteas/FileCache). + +FileCache is distributed under the Apache License 2.0. +Consult "LICENSE.txt" included in this package for the Apache License 2.0. +*/ + +namespace System.Runtime.Caching +{ + public class FileCacheEventArgs : EventArgs + { + public long CurrentCacheSize { get; private set; } + public long MaxCacheSize { get; private set; } + public FileCacheEventArgs(long currentSize, long maxSize) + { + CurrentCacheSize = currentSize; + MaxCacheSize = maxSize; + } + } +} \ No newline at end of file diff --git a/src/GitHub.Api/Caching/FileCachePayload.cs b/src/GitHub.Api/Caching/FileCachePayload.cs new file mode 100644 index 0000000000..1361e6f663 --- /dev/null +++ b/src/GitHub.Api/Caching/FileCachePayload.cs @@ -0,0 +1,33 @@ +/* +Copyright 2012, 2013, 2017 Adam Carter (http://adam-carter.com) + +This file is part of FileCache (http://github.com/acarteas/FileCache). + +FileCache is distributed under the Apache License 2.0. +Consult "LICENSE.txt" included in this package for the Apache License 2.0. +*/ + +namespace System.Runtime.Caching +{ + [Serializable] + public class FileCachePayload + { + public object Payload { get; set; } + public SerializableCacheItemPolicy Policy { get; set; } + + public FileCachePayload(object payload) + { + Payload = payload; + Policy = new SerializableCacheItemPolicy() + { + AbsoluteExpiration = DateTime.Now.AddYears(10) + }; + } + + public FileCachePayload(object payload, SerializableCacheItemPolicy policy) + { + Payload = payload; + Policy = policy; + } + } +} \ No newline at end of file diff --git a/src/GitHub.Api/Caching/PriortyQueue.cs b/src/GitHub.Api/Caching/PriortyQueue.cs new file mode 100644 index 0000000000..cda897b9d5 --- /dev/null +++ b/src/GitHub.Api/Caching/PriortyQueue.cs @@ -0,0 +1,207 @@ +/* +Copyright 2012, 2013, 2017 Adam Carter (http://adam-carter.com) + +This file is part of FileCache (http://github.com/acarteas/FileCache). + +FileCache is distributed under the Apache License 2.0. +Consult "LICENSE.txt" included in this package for the Apache License 2.0. +*/ +using System.Collections.Generic; + +namespace System.Runtime.Caching +{ + /// + /// A basic min priorty queue (min heap) + /// + /// Data type to store + public class PriortyQueue where T : IComparable + { + + private List _items; + private IComparer _comparer; + + /// + /// Default constructor. + /// + /// The comparer to use. The default comparer will make the smallest item the root of the heap. + /// + /// + public PriortyQueue(IComparer comparer = null) + { + _items = new List(); + if (comparer == null) + { + _comparer = new GenericComparer(); + } + } + + /// + /// Constructor that will convert an existing list into a min heap + /// + /// The unsorted list of items + /// The comparer to use. The default comparer will make the smallest item the root of the heap. + public PriortyQueue(List unsorted, IComparer comparer = null) + : this(comparer) + { + for (int i = 0; i < unsorted.Count; i++) + { + _items.Add(unsorted[i]); + } + BuildHeap(); + } + + private void BuildHeap() + { + for (int i = _items.Count / 2; i >= 0; i--) + { + adjustHeap(i); + } + } + + //Percolates the item specified at by index down into its proper location within a heap. Used + //for dequeue operations and array to heap conversions + private void adjustHeap(int index) + { + //cannot percolate empty list + if (_items.Count == 0) + { + return; + } + + //GOAL: get value at index, make sure this value is less than children + // IF NOT: swap with smaller of two + // (continue to do so until we can't swap) + T item = _items[index]; + + //helps us figure out if a given index has children + int end_location = _items.Count; + + //keeps track of smallest index + int smallest_index = index; + + //while we're not the last thing in the heap + while (index < end_location) + { + //get left child index + int left_child_index = (2 * index) + 1; + int right_child_index = left_child_index + 1; + + //Three cases: + // 1. left index is out of range + // 2. right index is out or range + // 3. both indices are valid + if (left_child_index < end_location) + { + //CASE 1 is FALSE + //remember that left index is the smallest + smallest_index = left_child_index; + + if (right_child_index < end_location) + { + //CASE 2 is FALSE (CASE 3 is true) + //TODO: find value of smallest index + smallest_index = (_comparer.Compare(_items[left_child_index], _items[right_child_index]) < 0) + ? left_child_index + : right_child_index; + } + } + + //we have two things: original index and (potentially) a child index + if (_comparer.Compare(_items[index], _items[smallest_index]) > 0) + { + //move parent down (it was too big) + T temp = _items[index]; + _items[index] = _items[smallest_index]; + _items[smallest_index] = temp; + + //update index + index = smallest_index; + } + else + { + //no swap necessary + break; + } + } + } + + public bool isEmpty() + { + return _items.Count == 0; + } + + public int GetSize() + { + return _items.Count; + } + + + public void Enqueue(T item) + { + //calculate positions + int current_position = _items.Count; + int parent_position = (current_position - 1) / 2; + + //insert element (note: may get erased if we hit the WHILE loop) + _items.Add(item); + + //find parent, but be careful if we are an empty queue + T parent = default(T); + if (parent_position >= 0) + { + //find parent + parent = _items[parent_position]; + + //bubble up until we're done + while (_comparer.Compare(parent, item) > 0 && current_position > 0) + { + //move parent down + _items[current_position] = parent; + + //recalculate position + current_position = parent_position; + parent_position = (current_position - 1) / 2; + + //make sure that we have a valid index + if (parent_position >= 0) + { + //find parent + parent = _items[parent_position]; + } + } + } //end check for nullptr + + //after WHILE loop, current_position will point to the place that + //variable "item" needs to go + _items[current_position] = item; + + } + + public T GetFirst() + { + return _items[0]; + } + + public T Dequeue() + { + int last_position = _items.Count - 1; + T last_item = _items[last_position]; + T top = _items[0]; + _items[0] = last_item; + _items.RemoveAt(_items.Count - 1); + + //percolate down + adjustHeap(0); + return top; + } + + + private class GenericComparer : IComparer where TInner : IComparable + { + public int Compare(TInner x, TInner y) + { + return x.CompareTo(y); + } + } + } +} \ No newline at end of file diff --git a/src/GitHub.Api/Caching/SerializableCacheItemPolicy.cs b/src/GitHub.Api/Caching/SerializableCacheItemPolicy.cs new file mode 100644 index 0000000000..a3a22f5c54 --- /dev/null +++ b/src/GitHub.Api/Caching/SerializableCacheItemPolicy.cs @@ -0,0 +1,44 @@ +/* +Copyright 2012, 2013, 2017 Adam Carter (http://adam-carter.com) + +This file is part of FileCache (http://github.com/acarteas/FileCache). + +FileCache is distributed under the Apache License 2.0. +Consult "LICENSE.txt" included in this package for the Apache License 2.0. +*/ + +namespace System.Runtime.Caching +{ + [Serializable] + public class SerializableCacheItemPolicy + { + public DateTimeOffset AbsoluteExpiration { get; set; } + + private TimeSpan _slidingExpiration; + public TimeSpan SlidingExpiration + { + get + { + return _slidingExpiration; + } + set + { + _slidingExpiration = value; + if (_slidingExpiration > new TimeSpan()) + { + AbsoluteExpiration = DateTimeOffset.Now.Add(_slidingExpiration); + } + } + } + public SerializableCacheItemPolicy(CacheItemPolicy policy) + { + AbsoluteExpiration = policy.AbsoluteExpiration; + SlidingExpiration = policy.SlidingExpiration; + } + + public SerializableCacheItemPolicy() + { + SlidingExpiration = new TimeSpan(); + } + } +} \ No newline at end of file diff --git a/src/GitHub.Api/GitHub.Api.csproj b/src/GitHub.Api/GitHub.Api.csproj index d481115cdc..2d4a991d64 100644 --- a/src/GitHub.Api/GitHub.Api.csproj +++ b/src/GitHub.Api/GitHub.Api.csproj @@ -17,6 +17,7 @@ + diff --git a/src/GitHub.Api/GraphQLClient.cs b/src/GitHub.Api/GraphQLClient.cs new file mode 100644 index 0000000000..155bd2bd07 --- /dev/null +++ b/src/GitHub.Api/GraphQLClient.cs @@ -0,0 +1,157 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Runtime.Caching; +using System.Security.Cryptography; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using GitHub.Extensions; +using Octokit.GraphQL; +using Octokit.GraphQL.Core; + +namespace GitHub.Api +{ + public class GraphQLClient : IGraphQLClient + { + public static readonly TimeSpan DefaultCacheDuration = TimeSpan.FromHours(8); + readonly IConnection connection; + readonly FileCache cache; + + public GraphQLClient( + IConnection connection, + FileCache cache) + { + this.connection = connection; + this.cache = cache; + } + + public Task ClearCache(string regionName) + { + // Switch to background thread because FileCache does not provide an async API. + return Task.Run(() => cache.ClearRegion(GetFullRegionName(regionName))); + } + + public Task Run( + IQueryableValue query, + Dictionary variables = null, + bool refresh = false, + TimeSpan? cacheDuration = null, + string regionName = null, + CancellationToken cancellationToken = default) + { + return Run(query.Compile(), variables, refresh, cacheDuration, regionName, cancellationToken); + } + + public Task> Run( + IQueryableList query, + Dictionary variables = null, + bool refresh = false, + TimeSpan? cacheDuration = null, + string regionName = null, + CancellationToken cancellationToken = default) + { + return Run(query.Compile(), variables, refresh, cacheDuration, regionName, cancellationToken); + } + + public async Task Run( + ICompiledQuery query, + Dictionary variables = null, + bool refresh = false, + TimeSpan? cacheDuration = null, + string regionName = null, + CancellationToken cancellationToken = default) + { + if (!query.IsMutation) + { + var wrapper = new CachingWrapper( + this, + refresh, + cacheDuration ?? DefaultCacheDuration, + GetFullRegionName(regionName)); + return await wrapper.Run(query, variables, cancellationToken); + } + else + { + return await connection.Run(query, variables, cancellationToken); + } + } + + string GetFullRegionName(string regionName) + { + var result = connection.Uri.Host; + + if (!string.IsNullOrWhiteSpace(regionName)) + { + result += Path.DirectorySeparatorChar + regionName; + } + + return result.EnsureValidPath(); + } + + static string GetHash(string input) + { + var sb = new StringBuilder(); + + using (var hash = SHA256.Create()) + { + var result = hash.ComputeHash(Encoding.UTF8.GetBytes(input)); + + foreach (var b in result) + { + sb.Append(b.ToString("x2", CultureInfo.InvariantCulture)); + } + } + + return sb.ToString(); + } + + class CachingWrapper : IConnection + { + readonly GraphQLClient owner; + readonly bool refresh; + readonly TimeSpan cacheDuration; + readonly string regionName; + + public CachingWrapper( + GraphQLClient owner, + bool refresh, + TimeSpan cacheDuration, + string regionName) + { + this.owner = owner; + this.refresh = refresh; + this.cacheDuration = cacheDuration; + this.regionName = regionName; + } + + public Uri Uri => owner.connection.Uri; + + public Task Run(string query, CancellationToken cancellationToken = default) + { + // Switch to background thread because FileCache does not provide an async API. + return Task.Run(async () => + { + var hash = GetHash(query); + + if (refresh) + { + owner.cache.Remove(hash, regionName); + } + + var data = (string) owner.cache.Get(hash, regionName); + + if (data != null) + { + return data; + } + + var result = await owner.connection.Run(query, cancellationToken); + owner.cache.Add(hash, result, DateTimeOffset.Now + cacheDuration, regionName); + return result; + }, cancellationToken); + } + } + } +} diff --git a/src/GitHub.Api/GraphQLClientFactory.cs b/src/GitHub.Api/GraphQLClientFactory.cs index cd91295935..635467a845 100644 --- a/src/GitHub.Api/GraphQLClientFactory.cs +++ b/src/GitHub.Api/GraphQLClientFactory.cs @@ -1,6 +1,9 @@ using System; using System.ComponentModel.Composition; +using System.IO; +using System.Runtime.Caching; using System.Threading.Tasks; +using GitHub.Info; using GitHub.Models; using GitHub.Primitives; using Octokit.GraphQL; @@ -17,6 +20,7 @@ public class GraphQLClientFactory : IGraphQLClientFactory { readonly IKeychain keychain; readonly IProgram program; + readonly FileCache cache; /// /// Initializes a new instance of the class. @@ -28,14 +32,21 @@ public GraphQLClientFactory(IKeychain keychain, IProgram program) { this.keychain = keychain; this.program = program; + + var cachePath = Path.Combine( + Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), + ApplicationInfo.ApplicationName, + "GraphQLCache"); + cache = new FileCache(cachePath); } /// - public Task CreateConnection(HostAddress address) + public Task CreateConnection(HostAddress address) { var credentials = new GraphQLKeychainCredentialStore(keychain, address); var header = new ProductHeaderValue(program.ProductHeader.Name, program.ProductHeader.Version); - return Task.FromResult(new Connection(header, address.GraphQLUri, credentials)); + var connection = new Connection(header, address.GraphQLUri, credentials); + return Task.FromResult(new GraphQLClient(connection, cache)); } } } diff --git a/src/GitHub.Api/IGraphQLClient.cs b/src/GitHub.Api/IGraphQLClient.cs new file mode 100644 index 0000000000..d45062c6b4 --- /dev/null +++ b/src/GitHub.Api/IGraphQLClient.cs @@ -0,0 +1,38 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Octokit.GraphQL; +using Octokit.GraphQL.Core; + +namespace GitHub.Api +{ + public interface IGraphQLClient + { + Task ClearCache(string regionName); + + Task Run( + IQueryableValue query, + Dictionary variables = null, + bool refresh = false, + TimeSpan? cacheDuration = null, + string regionName = null, + CancellationToken cancellationToken = default); + + Task> Run( + IQueryableList query, + Dictionary variables = null, + bool refresh = false, + TimeSpan? cacheDuration = null, + string regionName = null, + CancellationToken cancellationToken = default); + + Task Run( + ICompiledQuery query, + Dictionary variables = null, + bool refresh = false, + TimeSpan? cacheDuration = null, + string regionName = null, + CancellationToken cancellationToken = default); + } +} \ No newline at end of file diff --git a/src/GitHub.Api/IGraphQLClientFactory.cs b/src/GitHub.Api/IGraphQLClientFactory.cs index 464fab0de8..f29fba4b7f 100644 --- a/src/GitHub.Api/IGraphQLClientFactory.cs +++ b/src/GitHub.Api/IGraphQLClientFactory.cs @@ -4,16 +4,15 @@ namespace GitHub.Api { /// - /// Creates GraphQL s for querying the - /// GitHub GraphQL API. + /// Creates s for querying the GitHub GraphQL API. /// public interface IGraphQLClientFactory { /// - /// Creates a new . + /// Creates a new . /// /// The address of the server. - /// A task returning the created connection. - Task CreateConnection(HostAddress address); + /// A task returning the created client. + Task CreateConnection(HostAddress address); } } \ No newline at end of file diff --git a/src/GitHub.App/SampleData/Dialog/Clone/SelectPageViewModelDesigner.cs b/src/GitHub.App/SampleData/Dialog/Clone/SelectPageViewModelDesigner.cs index 82df1bfae9..14ad9d514e 100644 --- a/src/GitHub.App/SampleData/Dialog/Clone/SelectPageViewModelDesigner.cs +++ b/src/GitHub.App/SampleData/Dialog/Clone/SelectPageViewModelDesigner.cs @@ -2,11 +2,13 @@ using System.Collections.Generic; using System.ComponentModel; using System.Linq; +using System.Reactive; using System.Threading.Tasks; using System.Windows.Data; using GitHub.Models; using GitHub.ViewModels; using GitHub.ViewModels.Dialog.Clone; +using ReactiveUI; namespace GitHub.SampleData.Dialog.Clone { diff --git a/src/GitHub.App/Services/PullRequestService.cs b/src/GitHub.App/Services/PullRequestService.cs index 2e2eaa0239..7c143409ba 100644 --- a/src/GitHub.App/Services/PullRequestService.cs +++ b/src/GitHub.App/Services/PullRequestService.cs @@ -218,7 +218,8 @@ public async Task> ReadPullRequests( { nameof(states), states.Select(x => (Octokit.GraphQL.Model.PullRequestState)x).ToList() }, }; - var result = await graphql.Run(query, vars); + var region = owner + '/' + name + "/pr-list"; + var result = await graphql.Run(query, vars, regionName: region); foreach (var item in result.Items.Cast()) { @@ -293,6 +294,14 @@ public async Task> ReadPullRequests( return result; } + public async Task ClearPullRequestsCache(HostAddress address, string owner, string name) + { + var region = owner + '/' + name + "/pr-list"; + var graphql = await graphqlFactory.CreateConnection(address); + + await graphql.ClearCache(region); + } + public async Task> ReadAssignableUsers( HostAddress address, string owner, @@ -325,7 +334,7 @@ public async Task> ReadAssignableUsers( { nameof(after), after }, }; - return await graphql.Run(readAssignableUsers, vars); + return await graphql.Run(readAssignableUsers, vars, cacheDuration: TimeSpan.FromHours(1)); } public IObservable CreatePullRequest(IModelService modelService, diff --git a/src/GitHub.App/Services/RepositoryCloneService.cs b/src/GitHub.App/Services/RepositoryCloneService.cs index 9a27bf7dcd..7fc5e6e3e4 100644 --- a/src/GitHub.App/Services/RepositoryCloneService.cs +++ b/src/GitHub.App/Services/RepositoryCloneService.cs @@ -64,7 +64,7 @@ public RepositoryCloneService( } /// - public async Task ReadViewerRepositories(HostAddress address) + public async Task ReadViewerRepositories(HostAddress address, bool refresh = false) { if (readViewerRepositories == null) { @@ -107,7 +107,7 @@ public async Task ReadViewerRepositories(HostAddress ad } var graphql = await graphqlFactory.CreateConnection(address).ConfigureAwait(false); - var result = await graphql.Run(readViewerRepositories).ConfigureAwait(false); + var result = await graphql.Run(readViewerRepositories, cacheDuration: TimeSpan.FromHours(1), refresh: refresh).ConfigureAwait(false); return result; } diff --git a/src/GitHub.App/ViewModels/Dialog/Clone/RepositorySelectViewModel.cs b/src/GitHub.App/ViewModels/Dialog/Clone/RepositorySelectViewModel.cs index 05845cb82e..bf45c4ec78 100644 --- a/src/GitHub.App/ViewModels/Dialog/Clone/RepositorySelectViewModel.cs +++ b/src/GitHub.App/ViewModels/Dialog/Clone/RepositorySelectViewModel.cs @@ -4,6 +4,7 @@ using System.ComponentModel.Composition; using System.Globalization; using System.Linq; +using System.Reactive; using System.Reactive.Linq; using System.Threading.Tasks; using System.Windows.Data; @@ -30,7 +31,6 @@ public class RepositorySelectViewModel : ViewModelBase, IRepositorySelectViewMod string filter; bool isEnabled; bool isLoading; - bool loadingStarted; IReadOnlyList items; ICollectionView itemsView; ObservableAsPropertyHelper repository; @@ -113,16 +113,37 @@ public void Initialize(IConnection connection) public async Task Activate() { - if (connection == null || loadingStarted) return; + await this.LoadItems(true); + } + + static string GroupName(KeyValuePair> group, int max) + { + var name = group.Key; + if (group.Value.Count == max) + { + name += $" ({string.Format(CultureInfo.InvariantCulture, Resources.MostRecentlyPushed, max)})"; + } + + return name; + } + + async Task LoadItems(bool refresh) + { + if (connection == null && !IsLoading) return; Error = null; IsLoading = true; - loadingStarted = true; try { + if (refresh) + { + Items = new List(); + ItemsView = CollectionViewSource.GetDefaultView(Items); + } + var results = await log.TimeAsync(nameof(service.ReadViewerRepositories), - () => service.ReadViewerRepositories(connection.HostAddress)); + () => service.ReadViewerRepositories(connection.HostAddress, refresh)); var yourRepositories = results.Repositories .Where(r => r.Owner == results.Owner) @@ -163,17 +184,6 @@ public async Task Activate() } } - static string GroupName(KeyValuePair> group, int max) - { - var name = group.Key; - if (group.Value.Count == max) - { - name += $" ({string.Format(CultureInfo.InvariantCulture, Resources.MostRecentlyPushed, max)})"; - } - - return name; - } - bool FilterItem(object obj) { if (obj is IRepositoryItemViewModel item && !string.IsNullOrWhiteSpace(Filter)) diff --git a/src/GitHub.App/ViewModels/GitHubPane/IssueListViewModelBase.cs b/src/GitHub.App/ViewModels/GitHubPane/IssueListViewModelBase.cs index 9525991fd0..271b88ab32 100644 --- a/src/GitHub.App/ViewModels/GitHubPane/IssueListViewModelBase.cs +++ b/src/GitHub.App/ViewModels/GitHubPane/IssueListViewModelBase.cs @@ -1,14 +1,12 @@ using System; using System.Collections.Generic; using System.ComponentModel; -using System.Globalization; using System.IO; using System.Linq; using System.Reactive; using System.Reactive.Disposables; using System.Reactive.Linq; using System.Threading.Tasks; -using System.Windows.Threading; using GitHub.Collections; using GitHub.Extensions; using GitHub.Extensions.Reactive; @@ -154,21 +152,21 @@ public async Task InitializeAsync(LocalRepositoryModel repository, IConnection c Forks = new RepositoryModel[] { - RemoteRepository, - repository, + RemoteRepository, + repository, }; } this.WhenAnyValue(x => x.SelectedState, x => x.RemoteRepository) .Skip(1) - .Subscribe(_ => Refresh().Forget()); + .Subscribe(_ => InitializeItemSource(false).Forget()); Observable.Merge( this.WhenAnyValue(x => x.SearchQuery).Skip(1).SelectUnit(), AuthorFilter.WhenAnyValue(x => x.Selected).Skip(1).SelectUnit()) .Subscribe(_ => FilterChanged()); - await Refresh(); + await InitializeItemSource(true); } catch (Exception ex) { @@ -182,18 +180,43 @@ public async Task InitializeAsync(LocalRepositoryModel repository, IConnection c /// Refreshes the view model. /// /// A task tracking the operation. - public override Task Refresh() + public override Task Refresh() => InitializeItemSource(true); + + /// + /// When overridden in a derived class, creates the + /// that will act as the source for . + /// + /// + /// Whether the item source is being created due to being called. + /// + protected abstract Task> CreateItemSource(bool refresh); + + /// + /// When overridden in a derived class, navigates to the specified item. + /// + /// The item. + /// A task tracking the operation. + protected abstract Task DoOpenItem(IIssueListItemViewModelBase item); + + /// + /// Loads a page of authors for the . + /// + /// The GraphQL "after" cursor. + /// A task that returns a page of authors. + protected abstract Task> LoadAuthors(string after); + + async Task InitializeItemSource(bool refresh) { if (RemoteRepository == null) { // If an exception occurred reading the parent repository, do nothing. - return Task.CompletedTask; + return; } subscription?.Dispose(); var dispose = new CompositeDisposable(); - var itemSource = CreateItemSource(); + var itemSource = await CreateItemSource(refresh); var items = new VirtualizingList(itemSource, null); var view = new VirtualizingListCollectionView(items); @@ -219,30 +242,8 @@ public override Task Refresh() x => items.InitializationError -= x) .Subscribe(x => Error = x.EventArgs.GetException())); subscription = dispose; - - return Task.CompletedTask; } - /// - /// When overridden in a derived class, creates the - /// that will act as the source for . - /// - protected abstract IVirtualizingListSource CreateItemSource(); - - /// - /// When overridden in a derived class, navigates to the specified item. - /// - /// The item. - /// A task tracking the operation. - protected abstract Task DoOpenItem(IIssueListItemViewModelBase item); - - /// - /// Loads a page of authors for the . - /// - /// The GraphQL "after" cursor. - /// A task that returns a page of authors. - protected abstract Task> LoadAuthors(string after); - void FilterChanged() { if (!string.IsNullOrWhiteSpace(SearchQuery)) diff --git a/src/GitHub.App/ViewModels/GitHubPane/PullRequestListViewModel.cs b/src/GitHub.App/ViewModels/GitHubPane/PullRequestListViewModel.cs index d81467e938..23890674c8 100644 --- a/src/GitHub.App/ViewModels/GitHubPane/PullRequestListViewModel.cs +++ b/src/GitHub.App/ViewModels/GitHubPane/PullRequestListViewModel.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.ComponentModel.Composition; -using System.Diagnostics; using System.Reactive; using System.Reactive.Linq; using System.Threading.Tasks; @@ -69,8 +68,16 @@ public PullRequestListViewModel( public ReactiveCommand OpenItemInBrowser { get; } /// - protected override IVirtualizingListSource CreateItemSource() + protected override async Task> CreateItemSource(bool refresh) { + if (refresh) + { + await service.ClearPullRequestsCache( + HostAddress.Create(RemoteRepository.CloneUrl), + RemoteRepository.Owner, + RemoteRepository.Name); + } + return new ItemSource(this); } diff --git a/src/GitHub.Exports.Reactive/Services/IPullRequestService.cs b/src/GitHub.Exports.Reactive/Services/IPullRequestService.cs index 9513f0a177..e472ffd47a 100644 --- a/src/GitHub.Exports.Reactive/Services/IPullRequestService.cs +++ b/src/GitHub.Exports.Reactive/Services/IPullRequestService.cs @@ -19,6 +19,7 @@ public interface IPullRequestService : IIssueishService /// The repository name. /// The end cursor of the previous page, or null for the first page. /// The pull request states to filter by + /// Whether the data should be refreshed instead of read from the cache. /// A page of pull request item models. Task> ReadPullRequests( HostAddress address, @@ -27,6 +28,14 @@ Task> ReadPullRequests( string after, PullRequestState[] states); + /// + /// Clears the cache for . + /// + /// The host address. + /// The repository owner. + /// The repository name. + Task ClearPullRequestsCache(HostAddress address, string owner, string name); + /// /// Reads a page of users that can be assigned to pull requests. /// diff --git a/src/GitHub.Exports.Reactive/Services/IRepositoryCloneService.cs b/src/GitHub.Exports.Reactive/Services/IRepositoryCloneService.cs index 365902c98a..152a203d60 100644 --- a/src/GitHub.Exports.Reactive/Services/IRepositoryCloneService.cs +++ b/src/GitHub.Exports.Reactive/Services/IRepositoryCloneService.cs @@ -67,6 +67,6 @@ Task CloneOrOpenRepository( /// bool DestinationFileExists(string path); - Task ReadViewerRepositories(HostAddress address); + Task ReadViewerRepositories(HostAddress address, bool refresh = false); } } diff --git a/src/GitHub.Exports.Reactive/ViewModels/Dialog/Clone/IRepositorySelectViewModel.cs b/src/GitHub.Exports.Reactive/ViewModels/Dialog/Clone/IRepositorySelectViewModel.cs index e59837d972..b8cf181f63 100644 --- a/src/GitHub.Exports.Reactive/ViewModels/Dialog/Clone/IRepositorySelectViewModel.cs +++ b/src/GitHub.Exports.Reactive/ViewModels/Dialog/Clone/IRepositorySelectViewModel.cs @@ -1,8 +1,10 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Reactive; using System.Threading.Tasks; using GitHub.Models; +using ReactiveUI; namespace GitHub.ViewModels.Dialog.Clone { diff --git a/src/GitHub.Extensions/StringExtensions.cs b/src/GitHub.Extensions/StringExtensions.cs index 9e3c649e25..be59463011 100644 --- a/src/GitHub.Extensions/StringExtensions.cs +++ b/src/GitHub.Extensions/StringExtensions.cs @@ -109,6 +109,26 @@ public static string EnsureEndsWith(this string s, char c) return s.TrimEnd(c) + c; } + public static string EnsureValidPath(this string path) + { + if (string.IsNullOrEmpty(path)) return null; + + var components = path.Split(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar); + var result = new StringBuilder(); + + foreach (var component in components) + { + if (result.Length > 0) + { + result.Append(Path.DirectorySeparatorChar); + } + + result.Append(CoerceValidFileName(component)); + } + + return result.ToString(); + } + public static string NormalizePath(this string path) { if (String.IsNullOrEmpty(path)) return null; @@ -243,5 +263,33 @@ public static string GetSha256Hash(this string input) return string.Join("", hash.Select(b => b.ToString("x2", CultureInfo.InvariantCulture))); } } + + /// + /// Strip illegal chars and reserved words from a candidate filename (should not include the directory path) + /// + /// + /// http://stackoverflow.com/questions/309485/c-sharp-sanitize-file-name + /// + static string CoerceValidFileName(string filename) + { + var invalidChars = Regex.Escape(new string(Path.GetInvalidFileNameChars())); + var invalidReStr = string.Format(CultureInfo.InvariantCulture, @"[{0}]+", invalidChars); + + var reservedWords = new[] + { + "CON", "PRN", "AUX", "CLOCK$", "NUL", "COM0", "COM1", "COM2", "COM3", "COM4", + "COM5", "COM6", "COM7", "COM8", "COM9", "LPT0", "LPT1", "LPT2", "LPT3", "LPT4", + "LPT5", "LPT6", "LPT7", "LPT8", "LPT9" + }; + + var sanitisedNamePart = Regex.Replace(filename, invalidReStr, "_"); + foreach (var reservedWord in reservedWords) + { + var reservedWordPattern = string.Format(CultureInfo.InvariantCulture, "^{0}\\.", reservedWord); + sanitisedNamePart = Regex.Replace(sanitisedNamePart, reservedWordPattern, "_reservedWord_.", RegexOptions.IgnoreCase); + } + + return sanitisedNamePart; + } } } diff --git a/src/GitHub.InlineReviews/Services/IPullRequestSessionService.cs b/src/GitHub.InlineReviews/Services/IPullRequestSessionService.cs index 88d11c3866..608ccf0932 100644 --- a/src/GitHub.InlineReviews/Services/IPullRequestSessionService.cs +++ b/src/GitHub.InlineReviews/Services/IPullRequestSessionService.cs @@ -156,8 +156,14 @@ Task ExtractFileFromGit( /// The repository owner. /// The repository name. /// The pull request number. + /// Whether the data should be refreshed instead of read from the cache. /// A task returning the pull request model. - Task ReadPullRequestDetail(HostAddress address, string owner, string name, int number); + Task ReadPullRequestDetail( + HostAddress address, + string owner, + string name, + int number, + bool refresh = false); /// /// Reads the current viewer for the specified address.. diff --git a/src/GitHub.InlineReviews/Services/PullRequestSession.cs b/src/GitHub.InlineReviews/Services/PullRequestSession.cs index e868b95cbe..783254237e 100644 --- a/src/GitHub.InlineReviews/Services/PullRequestSession.cs +++ b/src/GitHub.InlineReviews/Services/PullRequestSession.cs @@ -268,7 +268,8 @@ public async Task Refresh() address, RepositoryOwner, LocalRepository.Name, - PullRequest.Number); + PullRequest.Number, + true); await Update(model); } diff --git a/src/GitHub.InlineReviews/Services/PullRequestSessionService.cs b/src/GitHub.InlineReviews/Services/PullRequestSessionService.cs index cafb0194ef..368832be71 100644 --- a/src/GitHub.InlineReviews/Services/PullRequestSessionService.cs +++ b/src/GitHub.InlineReviews/Services/PullRequestSessionService.cs @@ -290,21 +290,21 @@ public async Task ReadFileAsync(string path) return null; } - public virtual Task ReadPullRequestDetail(HostAddress address, string owner, string name, int number) + public virtual Task ReadPullRequestDetail(HostAddress address, string owner, string name, int number, bool refresh = false) { // The reviewThreads/isResolved field is only guaranteed to be available on github.com if (address.IsGitHubDotCom()) { - return ReadPullRequestDetailWithResolved(address, owner, name, number); + return ReadPullRequestDetailWithResolved(address, owner, name, number, refresh); } else { - return ReadPullRequestDetailWithoutResolved(address, owner, name, number); + return ReadPullRequestDetailWithoutResolved(address, owner, name, number, refresh); } } - async Task ReadPullRequestDetailWithResolved( - HostAddress address, string owner, string name, int number) + async Task ReadPullRequestDetailWithResolved(HostAddress address, string owner, + string name, int number, bool refresh) { if (readPullRequestWithResolved == null) @@ -424,7 +424,7 @@ async Task ReadPullRequestDetailWithResolved( }; var connection = await graphqlFactory.CreateConnection(address); - var result = await connection.Run(readPullRequestWithResolved, vars); + var result = await connection.Run(readPullRequestWithResolved, vars, refresh); var apiClient = await apiClientFactory.Create(address); @@ -432,7 +432,7 @@ async Task ReadPullRequestDetailWithResolved( async () => await apiClient.GetPullRequestFiles(owner, name, number).ToList()); var lastCommitModel = await log.TimeAsync(nameof(GetPullRequestLastCommitAdapter), - () => GetPullRequestLastCommitAdapter(address, owner, name, number)); + () => GetPullRequestLastCommitAdapter(address, owner, name, number, refresh)); result.Statuses = (IReadOnlyList)lastCommitModel.Statuses ?? Array.Empty(); @@ -488,8 +488,8 @@ async Task ReadPullRequestDetailWithResolved( return result; } - async Task ReadPullRequestDetailWithoutResolved( - HostAddress address, string owner, string name, int number) + async Task ReadPullRequestDetailWithoutResolved(HostAddress address, string owner, + string name, int number, bool refresh) { if (readPullRequestWithoutResolved == null) { @@ -602,7 +602,7 @@ async Task ReadPullRequestDetailWithoutResolved( }; var connection = await graphqlFactory.CreateConnection(address); - var result = await connection.Run(readPullRequestWithoutResolved, vars); + var result = await connection.Run(readPullRequestWithoutResolved, vars, refresh); var apiClient = await apiClientFactory.Create(address); @@ -610,7 +610,7 @@ async Task ReadPullRequestDetailWithoutResolved( async () => await apiClient.GetPullRequestFiles(owner, name, number).ToList()); var lastCommitModel = await log.TimeAsync(nameof(GetPullRequestLastCommitAdapter), - () => GetPullRequestLastCommitAdapter(address, owner, name, number)); + () => GetPullRequestLastCommitAdapter(address, owner, name, number, refresh)); result.Statuses = (IReadOnlyList)lastCommitModel.Statuses ?? Array.Empty(); @@ -708,7 +708,7 @@ public virtual async Task ReadViewer(HostAddress address) } var connection = await graphqlFactory.CreateConnection(address); - return await connection.Run(readViewer); + return await connection.Run(readViewer, cacheDuration: TimeSpan.FromMinutes(10)); } public async Task GetGraphQLPullRequestId( @@ -778,7 +778,7 @@ public async Task CreatePendingReview( var address = HostAddress.Create(localRepository.CloneUrl); var graphql = await graphqlFactory.CreateConnection(address); var (_, owner, number) = await CreatePendingReviewCore(localRepository, pullRequestId); - var detail = await ReadPullRequestDetail(address, owner, localRepository.Name, number); + var detail = await ReadPullRequestDetail(address, owner, localRepository.Name, number, true); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentStartReview); @@ -807,7 +807,7 @@ public async Task CancelPendingReview( }); var result = await graphql.Run(mutation); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } /// @@ -839,7 +839,7 @@ public async Task PostReview( var result = await graphql.Run(mutation); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewPosts); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } public async Task SubmitPendingReview( @@ -868,7 +868,7 @@ public async Task SubmitPendingReview( var result = await graphql.Run(mutation); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewPosts); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } /// @@ -902,7 +902,7 @@ public async Task PostPendingReviewComment( var result = await graphql.Run(addComment); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentPost); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } /// @@ -932,7 +932,7 @@ public async Task PostPendingReviewCommentReply( var result = await graphql.Run(addComment); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentPost); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } /// @@ -973,7 +973,7 @@ public async Task PostStandaloneReviewComment( var result = await graphql.Run(mutation); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentPost); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } /// @@ -1004,7 +1004,7 @@ await apiClient.DeletePullRequestReviewComment( commentDatabaseId); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentDelete); - return await ReadPullRequestDetail(address, remoteRepositoryOwner, localRepository.Name, pullRequestId); + return await ReadPullRequestDetail(address, remoteRepositoryOwner, localRepository.Name, pullRequestId, true); } /// @@ -1031,7 +1031,7 @@ public async Task EditComment(LocalRepositoryModel local var result = await graphql.Run(editComment); await usageTracker.IncrementCounter(x => x.NumberOfPRReviewDiffViewInlineCommentPost); - return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number); + return await ReadPullRequestDetail(address, result.Login, localRepository.Name, result.Number, true); } async Task<(string id, string owner, int number)> CreatePendingReviewCore(LocalRepositoryModel localRepository, string pullRequestId) @@ -1076,7 +1076,7 @@ Task GetRepository(LocalRepositoryModel repository) return Task.Run(() => gitService.GetRepository(repository.LocalPath)); } - async Task GetPullRequestLastCommitAdapter(HostAddress address, string owner, string name, int number) + async Task GetPullRequestLastCommitAdapter(HostAddress address, string owner, string name, int number, bool refresh) { ICompiledQuery> query; if (address.IsGitHubDotCom()) @@ -1167,7 +1167,7 @@ async Task GetPullRequestLastCommitAdapter(HostAddress addres }; var connection = await graphqlFactory.CreateConnection(address); - var result = await connection.Run(query, vars); + var result = await connection.Run(query, vars, refresh); return result.First(); } diff --git a/test/GitHub.App.UnitTests/ViewModels/Dialog/Clone/RepositorySelectViewModelTests.cs b/test/GitHub.App.UnitTests/ViewModels/Dialog/Clone/RepositorySelectViewModelTests.cs index 110a9a29bb..5bcb1aed3b 100644 --- a/test/GitHub.App.UnitTests/ViewModels/Dialog/Clone/RepositorySelectViewModelTests.cs +++ b/test/GitHub.App.UnitTests/ViewModels/Dialog/Clone/RepositorySelectViewModelTests.cs @@ -113,7 +113,7 @@ static IRepositoryCloneService CreateRepositoryCloneService( var viewRepositoriesModel = CreateViewerRepositoriesModel(contributedToRepositories: contributedToRepositories); var repositoryCloneService = Substitute.For(); - repositoryCloneService.ReadViewerRepositories(hostAddress).Returns(viewRepositoriesModel); + repositoryCloneService.ReadViewerRepositories(hostAddress, Arg.Any()).Returns(viewRepositoriesModel); return repositoryCloneService; } diff --git a/test/GitHub.App.UnitTests/ViewModels/GitHubPane/IssueListViewModelBaseTests.cs b/test/GitHub.App.UnitTests/ViewModels/GitHubPane/IssueListViewModelBaseTests.cs index 2b274342c4..4d036c5e4e 100644 --- a/test/GitHub.App.UnitTests/ViewModels/GitHubPane/IssueListViewModelBaseTests.cs +++ b/test/GitHub.App.UnitTests/ViewModels/GitHubPane/IssueListViewModelBaseTests.cs @@ -1,8 +1,13 @@ using System; using System.Collections.Generic; +using System.Globalization; using System.Linq; +using System.Reactive; using System.Reactive.Linq; +using System.Reactive.Subjects; +using System.Reactive.Threading.Tasks; using System.Threading.Tasks; +using GitHub; using GitHub.Collections; using GitHub.Models; using GitHub.Primitives; @@ -172,7 +177,8 @@ public Target(IRepositoryService repositoryService, int itemCount) public override IReadOnlyList States { get; } = new[] { "Open", "Closed" }; - protected override IVirtualizingListSource CreateItemSource() => ItemSource; + protected override Task> CreateItemSource(bool refresh) + => Task.FromResult(ItemSource); protected override Task DoOpenItem(IIssueListItemViewModelBase item) { diff --git a/test/GitHub.InlineReviews.UnitTests/Services/PullRequestSessionTests.cs b/test/GitHub.InlineReviews.UnitTests/Services/PullRequestSessionTests.cs index a3189bc39f..1cac6eea6a 100644 --- a/test/GitHub.InlineReviews.UnitTests/Services/PullRequestSessionTests.cs +++ b/test/GitHub.InlineReviews.UnitTests/Services/PullRequestSessionTests.cs @@ -813,7 +813,8 @@ static void UpdateReadPullRequest(IPullRequestSessionService service, PullReques Arg.Any(), Arg.Any(), Arg.Any(), - Arg.Any()).Returns(pullRequest); + Arg.Any(), + Arg.Any()).Returns(pullRequest); } } }