IntervalTree.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. namespace Ryujinx.Common.Collections
  5. {
  /// <summary>
  /// An Augmented Interval Tree based off of the "TreeDictionary"'s Red-Black Tree. Allows fast overlap checking of ranges.
  /// </summary>
  /// <remarks>
  /// Ranges are compared as half-open intervals [start, end): a stored range overlaps a query when it starts before
  /// the query ends and ends after the query starts. All ranges sharing a start key live on one node, and every node
  /// caches the largest end key of its subtree so that range queries can skip subtrees that end too early.
  /// <c>Count</c> counts stored values (one per <see cref="Add"/>), not tree nodes.
  /// </remarks>
  /// <typeparam name="K">Key</typeparam>
  /// <typeparam name="V">Value</typeparam>
  public class IntervalTree<K, V> : IntrusiveRedBlackTreeImpl<IntervalTreeNode<K, V>> where K : IComparable<K>
  {
      /// <summary>
      /// Number of extra slots added whenever a range query runs out of room in the caller's result array.
      /// </summary>
      private const int ArrayGrowthSize = 32;

      #region Public Methods

      /// <summary>
      /// Gets the values of the interval whose key is <paramref name="key"/>.
      /// </summary>
      /// <param name="key">Key of the node value to get</param>
      /// <param name="overlaps">Overlaps array to place results in, written from index 0. It is grown when too small and allocated when null</param>
      /// <returns>Number of values found</returns>
      /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
      public int Get(K key, ref V[] overlaps)
      {
          if (key == null)
          {
              throw new ArgumentNullException(nameof(key));
          }

          IntervalTreeNode<K, V> node = GetNode(key);

          if (node == null)
          {
              return 0;
          }

          // Array.Resize allocates a new array when overlaps is null.
          if (overlaps == null || node.Values.Count > overlaps.Length)
          {
              Array.Resize(ref overlaps, node.Values.Count);
          }

          int overlapsCount = 0;
          foreach (RangeNode<K, V> value in node.Values)
          {
              overlaps[overlapsCount++] = value.Value;
          }

          return overlapsCount;
      }

      /// <summary>
      /// Returns the values of the intervals whose start and end keys overlap the given range.
      /// </summary>
      /// <param name="start">Start of the range</param>
      /// <param name="end">End of the range (exclusive)</param>
      /// <param name="overlaps">Overlaps array to place results in. It is grown as needed and allocated when null</param>
      /// <param name="overlapCount">Index to start writing results into the array. Defaults to 0</param>
      /// <returns>
      /// The write index after appending, i.e. <paramref name="overlapCount"/> plus the number of values found.
      /// This is the number of valid entries at the front of <paramref name="overlaps"/>.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="start"/> or <paramref name="end"/> is null</exception>
      public int Get(K start, K end, ref V[] overlaps, int overlapCount = 0)
      {
          if (start == null)
          {
              throw new ArgumentNullException(nameof(start));
          }

          if (end == null)
          {
              throw new ArgumentNullException(nameof(end));
          }

          GetValues(Root, start, end, ref overlaps, ref overlapCount);

          return overlapCount;
      }

      /// <summary>
      /// Adds a new interval into the tree whose start is <paramref name="start"/>, end is <paramref name="end"/> and value is <paramref name="value"/>.
      /// </summary>
      /// <param name="start">Start of the range to add</param>
      /// <param name="end">End of the range to insert</param>
      /// <param name="value">Value to add</param>
      /// <exception cref="ArgumentNullException"><paramref name="start"/>, <paramref name="end"/> or <paramref name="value"/> are null</exception>
      public void Add(K start, K end, V value)
      {
          if (start == null)
          {
              throw new ArgumentNullException(nameof(start));
          }

          if (end == null)
          {
              throw new ArgumentNullException(nameof(end));
          }

          if (value == null)
          {
              throw new ArgumentNullException(nameof(value));
          }

          Insert(start, end, value);
      }

      /// <summary>
      /// Removes every value equal to <paramref name="value"/> from the node whose start key is <paramref name="key"/>.
      /// </summary>
      /// <param name="key">Key of the node to remove</param>
      /// <param name="value">Value to remove</param>
      /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
      /// <returns>Number of deleted values</returns>
      public int Remove(K key, V value)
      {
          if (key == null)
          {
              throw new ArgumentNullException(nameof(key));
          }

          int removed = Delete(key, value);

          Count -= removed;

          return removed;
      }

      /// <summary>
      /// Collects every stored range into a new list.
      /// </summary>
      /// <returns>A list of all RangeNodes sorted by Key Order</returns>
      public List<RangeNode<K, V>> AsList()
      {
          List<RangeNode<K, V>> list = new List<RangeNode<K, V>>();

          AddToList(Root, list);

          return list;
      }

      #endregion

      #region Private Methods (BST)

      /// <summary>
      /// Adds all RangeNodes that are children of or contained within <paramref name="node"/> into <paramref name="list"/>, in Key Order.
      /// </summary>
      /// <param name="node">The node to search for RangeNodes within</param>
      /// <param name="list">The list to add RangeNodes to</param>
      private void AddToList(IntervalTreeNode<K, V> node, List<RangeNode<K, V>> list)
      {
          if (node == null)
          {
              return;
          }

          AddToList(node.Left, list);

          list.AddRange(node.Values);

          AddToList(node.Right, list);
      }

      /// <summary>
      /// Retrieve the node reference whose key is <paramref name="key"/>, or null if no such node exists.
      /// </summary>
      /// <param name="key">Key of the node to get</param>
      /// <returns>Node reference in the tree</returns>
      /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
      private IntervalTreeNode<K, V> GetNode(K key)
      {
          if (key == null)
          {
              throw new ArgumentNullException(nameof(key));
          }

          IntervalTreeNode<K, V> node = Root;
          while (node != null)
          {
              int cmp = key.CompareTo(node.Start);
              if (cmp < 0)
              {
                  node = node.Left;
              }
              else if (cmp > 0)
              {
                  node = node.Right;
              }
              else
              {
                  return node;
              }
          }
          return null;
      }

      /// <summary>
      /// Retrieve all values that overlap the given start and end keys.
      /// </summary>
      /// <param name="node">Root of the subtree to search</param>
      /// <param name="start">Start of the range</param>
      /// <param name="end">End of the range (exclusive)</param>
      /// <param name="overlaps">Overlaps array to place results in. It is grown (or allocated when null) as needed</param>
      /// <param name="overlapCount">Overlaps count to update</param>
      private void GetValues(IntervalTreeNode<K, V> node, K start, K end, ref V[] overlaps, ref int overlapCount)
      {
          // Nothing in this subtree ends after the query starts, so nothing here can overlap it.
          if (node == null || start.CompareTo(node.Max) >= 0)
          {
              return;
          }

          GetValues(node.Left, start, end, ref overlaps, ref overlapCount);

          // This node and its right subtree start at or after node.Start. If the query ends at or before it, none of them overlap.
          bool endsOnRight = end.CompareTo(node.Start) > 0;
          if (endsOnRight)
          {
              if (start.CompareTo(node.End) < 0)
              {
                  // Contains this node. Add overlaps to list.
                  foreach (RangeNode<K, V> overlap in node.Values)
                  {
                      if (start.CompareTo(overlap.End) < 0)
                      {
                          // Grow only when the next write would fall outside the array.
                          if (overlaps == null || overlapCount >= overlaps.Length)
                          {
                              Array.Resize(ref overlaps, overlapCount + ArrayGrowthSize);
                          }

                          overlaps[overlapCount++] = overlap.Value;
                      }
                  }
              }

              GetValues(node.Right, start, end, ref overlaps, ref overlapCount);
          }
      }

      /// <summary>
      /// Inserts a new node into the tree with a given <paramref name="start"/>, <paramref name="end"/> and <paramref name="value"/>.
      /// </summary>
      /// <param name="start">Start of the range to insert</param>
      /// <param name="end">End of the range to insert</param>
      /// <param name="value">Value to insert</param>
      private void Insert(K start, K end, V value)
      {
          IntervalTreeNode<K, V> node = BSTInsert(start, end, value, out bool isNewNode);

          // Only a freshly linked leaf needs the insertion fix-up. Appending to an existing node leaves
          // the tree's shape and colors untouched, so rebalancing from it would be incorrect.
          if (isNewNode)
          {
              RestoreBalanceAfterInsertion(node);
          }
      }

      /// <summary>
      /// Propagate an increase in max value starting at the given node, heading up the tree.
      /// This should only be called if the max increases - not for rebalancing or removals.
      /// </summary>
      /// <param name="node">The node to start propagating from</param>
      private void PropagateIncrease(IntervalTreeNode<K, V> node)
      {
          K max = node.Max;
          IntervalTreeNode<K, V> ptr = node;

          while ((ptr = ptr.Parent) != null)
          {
              if (max.CompareTo(ptr.Max) > 0)
              {
                  ptr.Max = max;
              }
              else
              {
                  // Ancestors already cover this max; stop early.
                  break;
              }
          }
      }

      /// <summary>
      /// Propagate recalculating max value starting at the given node, heading up the tree.
      /// This fully recalculates the max value from all children when there is potential for it to decrease.
      /// </summary>
      /// <param name="node">The node to start propagating from</param>
      private void PropagateFull(IntervalTreeNode<K, V> node)
      {
          IntervalTreeNode<K, V> ptr = node;

          do
          {
              K max = ptr.End;

              if (ptr.Left != null && ptr.Left.Max.CompareTo(max) > 0)
              {
                  max = ptr.Left.Max;
              }

              if (ptr.Right != null && ptr.Right.Max.CompareTo(max) > 0)
              {
                  max = ptr.Right.Max;
              }

              ptr.Max = max;
          } while ((ptr = ptr.Parent) != null);
      }

      /// <summary>
      /// Insertion Mechanism for the interval tree. Similar to a BST insert, with the start of the range as the key.
      /// Iterates the tree starting from the root and inserts a new node where all children in the left subtree are less than <paramref name="start"/>, and all children in the right subtree are greater than <paramref name="start"/>.
      /// Each node can contain multiple values, and has an end address which is the maximum of all those values.
      /// Post insertion, the "max" value of the node and all parents are updated.
      /// </summary>
      /// <param name="start">Start of the range to insert</param>
      /// <param name="end">End of the range to insert</param>
      /// <param name="value">Value to insert</param>
      /// <param name="isNewNode">
      /// True when a new leaf node was linked into the tree. False when the value was appended to an existing node with the same start.
      /// </param>
      /// <returns>The node now holding the value</returns>
      private IntervalTreeNode<K, V> BSTInsert(K start, K end, V value, out bool isNewNode)
      {
          IntervalTreeNode<K, V> parent = null;
          IntervalTreeNode<K, V> node = Root;

          while (node != null)
          {
              parent = node;
              int cmp = start.CompareTo(node.Start);
              if (cmp < 0)
              {
                  node = node.Left;
              }
              else if (cmp > 0)
              {
                  node = node.Right;
              }
              else
              {
                  // Same start key: store the value on the existing node. No new node is linked.
                  node.Values.Add(new RangeNode<K, V>(start, end, value));

                  if (end.CompareTo(node.End) > 0)
                  {
                      node.End = end;
                      if (end.CompareTo(node.Max) > 0)
                      {
                          node.Max = end;
                          PropagateIncrease(node);
                      }
                  }

                  Count++;
                  isNewNode = false;
                  return node;
              }
          }

          IntervalTreeNode<K, V> newNode = new IntervalTreeNode<K, V>(start, end, value, parent);
          if (newNode.Parent == null)
          {
              Root = newNode;
          }
          else if (start.CompareTo(parent.Start) < 0)
          {
              parent.Left = newNode;
          }
          else
          {
              parent.Right = newNode;
          }

          PropagateIncrease(newNode);
          Count++;
          isNewNode = true;
          return newNode;
      }

      /// <summary>
      /// Removes instances of <paramref name="value"/> from the dictionary after searching for it with <paramref name="key"/>.
      /// </summary>
      /// <param name="key">Key to search for</param>
      /// <param name="value">Value to delete</param>
      /// <returns>Number of deleted values</returns>
      private int Delete(K key, V value)
      {
          IntervalTreeNode<K, V> nodeToDelete = GetNode(key);

          if (nodeToDelete == null)
          {
              return 0;
          }

          int removed = nodeToDelete.Values.RemoveAll(range => range.Value.Equals(value));

          if (nodeToDelete.Values.Count > 0)
          {
              // Other ranges still share this start key, so the node stays. Only its End (and the cached Max values) may shrink.
              if (removed > 0)
              {
                  nodeToDelete.End = nodeToDelete.Values.Max(range => range.End);

                  // Recalculate max from children and new end.
                  PropagateFull(nodeToDelete);
              }

              return removed;
          }

          // The node is empty and must be unlinked. When it has two children, its predecessor
          // (which has at most one child) is unlinked instead and its contents move into nodeToDelete.
          IntervalTreeNode<K, V> replacementNode;

          if (LeftOf(nodeToDelete) == null || RightOf(nodeToDelete) == null)
          {
              replacementNode = nodeToDelete;
          }
          else
          {
              replacementNode = PredecessorOf(nodeToDelete);
          }

          IntervalTreeNode<K, V> tmp = LeftOf(replacementNode) ?? RightOf(replacementNode);

          if (tmp != null)
          {
              tmp.Parent = ParentOf(replacementNode);
          }

          if (ParentOf(replacementNode) == null)
          {
              Root = tmp;
          }
          else if (replacementNode == LeftOf(ParentOf(replacementNode)))
          {
              ParentOf(replacementNode).Left = tmp;
          }
          else
          {
              ParentOf(replacementNode).Right = tmp;
          }

          if (replacementNode != nodeToDelete)
          {
              nodeToDelete.Start = replacementNode.Start;
              nodeToDelete.Values = replacementNode.Values;
              nodeToDelete.End = replacementNode.End;
              nodeToDelete.Max = replacementNode.Max;
          }

          // replacementNode still points at its former parent, so this walks the affected path up to the root.
          PropagateFull(replacementNode);

          // NOTE(review): when tmp is null, no rebalancing runs even if a black node was unlinked. That can leave
          // black heights uneven; key order and Max stay correct. Confirm whether the base class expects a
          // phantom-node fix-up for this case.
          if (tmp != null && ColorOf(replacementNode) == Black)
          {
              RestoreBalanceAfterRemoval(tmp);
          }

          return removed;
      }

      #endregion

      /// <summary>
      /// Rotates left via the base tree, then recomputes cached max values from <paramref name="node"/> up to the root.
      /// </summary>
      /// <param name="node">Node to rotate around; ignored when null</param>
      protected override void RotateLeft(IntervalTreeNode<K, V> node)
      {
          if (node != null)
          {
              base.RotateLeft(node);

              PropagateFull(node);
          }
      }

      /// <summary>
      /// Rotates right via the base tree, then recomputes cached max values from <paramref name="node"/> up to the root.
      /// </summary>
      /// <param name="node">Node to rotate around; ignored when null</param>
      protected override void RotateRight(IntervalTreeNode<K, V> node)
      {
          if (node != null)
          {
              base.RotateRight(node);

              PropagateFull(node);
          }
      }

      /// <summary>
      /// Checks whether any range starting at <paramref name="key"/> is stored.
      /// </summary>
      /// <param name="key">Start key to look up</param>
      /// <returns>True if a node with that start key exists</returns>
      /// <exception cref="ArgumentNullException"><paramref name="key"/> is null</exception>
      public bool ContainsKey(K key)
      {
          if (key == null)
          {
              throw new ArgumentNullException(nameof(key));
          }

          return GetNode(key) != null;
      }
  }
  /// <summary>
  /// A single stored interval: the key where it starts, the key where it ends, and the value attached to it.
  /// All three fields are readonly and fixed at construction.
  /// </summary>
  /// <typeparam name="K">Type of the start and end keys</typeparam>
  /// <typeparam name="V">Type of the attached value</typeparam>
  public readonly struct RangeNode<K, V>
  {
      public readonly K Start;
      public readonly K End;
      public readonly V Value;

      /// <summary>
      /// Bundles an interval's bounds with its value.
      /// </summary>
      /// <param name="start">Key the interval starts at</param>
      /// <param name="end">Key the interval ends at</param>
      /// <param name="value">Value attached to the interval</param>
      public RangeNode(K start, K end, V value) => (Start, End, Value) = (start, end, value);
  }
  /// <summary>
  /// Tree node of the interval tree. It groups every range that begins at the same start key and caches the
  /// largest end key found in its subtree.
  /// </summary>
  /// <typeparam name="K">Type of the start and end keys</typeparam>
  /// <typeparam name="V">Type of the values attached to each range</typeparam>
  public class IntervalTreeNode<K, V> : IntrusiveRedBlackTreeNode<IntervalTreeNode<K, V>>
  {
      /// <summary>
      /// Start key shared by every range in <see cref="Values"/>.
      /// </summary>
      internal K Start;

      /// <summary>
      /// Largest end key among the ranges held directly by this node.
      /// </summary>
      internal K End;

      /// <summary>
      /// Largest end key of this node together with all of its descendants.
      /// </summary>
      internal K Max;

      /// <summary>
      /// The ranges held by this node, all starting at <see cref="Start"/>.
      /// </summary>
      internal List<RangeNode<K, V>> Values;

      /// <summary>
      /// Creates a node holding exactly one range and records <paramref name="parent"/> as its parent.
      /// The caller is responsible for attaching the node as the parent's left or right child.
      /// </summary>
      /// <param name="start">Start key of the first range</param>
      /// <param name="end">End key of the first range; also the initial End and Max</param>
      /// <param name="value">Value attached to the first range</param>
      /// <param name="parent">Parent node, or null for a root</param>
      internal IntervalTreeNode(K start, K end, V value, IntervalTreeNode<K, V> parent)
      {
          Start = start;
          Max = End = end;

          RangeNode<K, V> firstRange = new RangeNode<K, V>(start, end, value);
          Values = new List<RangeNode<K, V>> { firstRange };

          Parent = parent;
      }
  }
  451. }