DeviceState.cs 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Reflection;
  7. using System.Runtime.CompilerServices;
  8. namespace Ryujinx.Graphics.Device
  9. {
  /// <summary>
  /// Exposes an unmanaged state structure as an array of 32-bit registers.
  /// Each register has individual read/write permission bits, derived from the
  /// <see cref="RegisterAttribute"/> on the field that covers it (fields without the
  /// attribute are fully readable and writable). Optional read/write callbacks can be
  /// attached per field by field name.
  /// </summary>
  /// <typeparam name="TState">Unmanaged struct whose public instance fields define the register layout.</typeparam>
  public class DeviceState<TState> : IDeviceState where TState : unmanaged
  {
      /// <summary>Size in bytes of one register.</summary>
      private const int RegisterSize = sizeof(int);

      /// <summary>Backing storage that all register accesses read from and write to.</summary>
      public TState State;

      // One bit per register: set when Read / Write are allowed to access that register.
      private readonly BitArray _readableRegisters;
      private readonly BitArray _writableRegisters;

      // Callbacks keyed by the byte offset of the field they were registered for.
      private readonly Dictionary<int, Func<int>> _readCallbacks;
      private readonly Dictionary<int, Action<int>> _writeCallbacks;

      /// <summary>
      /// Builds the register permission maps from the fields of <typeparamref name="TState"/>
      /// and registers any supplied callbacks.
      /// </summary>
      /// <param name="callbacks">
      /// Optional callbacks keyed by field name of <typeparamref name="TState"/>. A callback is
      /// stored at the byte offset where its field starts.
      /// </param>
      public DeviceState(IReadOnlyDictionary<string, RwCallback> callbacks = null)
      {
          int size = (Unsafe.SizeOf<TState>() + RegisterSize - 1) / RegisterSize;

          _readableRegisters = new BitArray(size);
          _writableRegisters = new BitArray(size);
          _readCallbacks = new Dictionary<int, Func<int>>();
          _writeCallbacks = new Dictionary<int, Action<int>>();

          // Only instance fields occupy space in the struct. Static fields and constants must not
          // advance the running offset, so they are excluded here.
          // NOTE(review): the offset computation assumes GetFields returns fields in declaration
          // order, which reflection does not guarantee; the Debug.Assert below only checks the total.
          FieldInfo[] fields = typeof(TState).GetFields(BindingFlags.Public | BindingFlags.Instance);
          int offset = 0;

          foreach (FieldInfo field in fields)
          {
              var regAttr = field.GetCustomAttributes<RegisterAttribute>(false).FirstOrDefault();
              int sizeOfField = SizeCalculator.SizeOf(field.FieldType);

              // Fields without a RegisterAttribute default to fully readable and writable.
              bool readable = regAttr?.AccessControl.HasFlag(AccessControl.ReadOnly) ?? true;
              bool writable = regAttr?.AccessControl.HasFlag(AccessControl.WriteOnly) ?? true;

              // Apply the permissions to every register the field touches (size rounded up to whole registers).
              int paddedFieldSize = (sizeOfField + RegisterSize - 1) & ~(RegisterSize - 1);
              for (int i = 0; i < paddedFieldSize; i += RegisterSize)
              {
                  int register = (offset + i) / RegisterSize;
                  _readableRegisters[register] = readable;
                  _writableRegisters[register] = writable;
              }

              if (callbacks != null && callbacks.TryGetValue(field.Name, out var cb))
              {
                  if (cb.Read != null)
                  {
                      _readCallbacks.Add(offset, cb.Read);
                  }

                  if (cb.Write != null)
                  {
                      _writeCallbacks.Add(offset, cb.Write);
                  }
              }

              offset += sizeOfField;
          }

          Debug.Assert(offset == Unsafe.SizeOf<TState>());
      }

      /// <summary>
      /// Reads the register that contains <paramref name="offset"/>.
      /// </summary>
      /// <param name="offset">Byte offset into the state; rounded down to a register boundary.</param>
      /// <returns>
      /// The result of the read callback registered at the aligned offset, if there is one.
      /// Otherwise the stored register value. Returns 0 when the offset is out of range or
      /// the register is not readable.
      /// </returns>
      public virtual int Read(int offset)
      {
          if (Check(offset) && _readableRegisters[offset / RegisterSize])
          {
              int alignedOffset = Align(offset);

              if (_readCallbacks.TryGetValue(alignedOffset, out Func<int> read))
              {
                  return read();
              }
              else
              {
                  return GetRef<int>(alignedOffset);
              }
          }

          return 0;
      }

      /// <summary>
      /// Writes <paramref name="data"/> to the register that contains <paramref name="offset"/>.
      /// If a write callback is registered at the aligned offset, it runs after the value is stored.
      /// Writes to out-of-range or non-writable registers are ignored.
      /// </summary>
      /// <param name="offset">Byte offset into the state; rounded down to a register boundary.</param>
      /// <param name="data">Value to store.</param>
      public virtual void Write(int offset, int data)
      {
          if (Check(offset) && _writableRegisters[offset / RegisterSize])
          {
              int alignedOffset = Align(offset);
              GetRef<int>(alignedOffset) = data;

              if (_writeCallbacks.TryGetValue(alignedOffset, out Action<int> write))
              {
                  write(data);
              }
          }
      }

      /// <summary>
      /// Returns true when the register-aligned <paramref name="offset"/> lies inside the state.
      /// The unsigned cast makes negative offsets fail the check.
      /// </summary>
      private bool Check(int offset)
      {
          return (uint)Align(offset) < Unsafe.SizeOf<TState>();
      }

      /// <summary>
      /// Returns a reference to a <typeparamref name="T"/> located <paramref name="offset"/> bytes into <see cref="State"/>.
      /// </summary>
      /// <typeparam name="T">Unmanaged type to reinterpret the bytes as.</typeparam>
      /// <param name="offset">Byte offset into the state.</param>
      /// <returns>A reference aliasing the bytes of <see cref="State"/>.</returns>
      /// <exception cref="ArgumentOutOfRangeException">
      /// The range [<paramref name="offset"/>, <paramref name="offset"/> + sizeof(<typeparamref name="T"/>)) is not fully inside the state.
      /// </exception>
      public ref T GetRef<T>(int offset) where T : unmanaged
      {
          // Negative offsets are rejected explicitly. Checking only (offset + sizeof(T)) would let
          // offsets in [-sizeof(T), 0) through and yield a reference before State.
          // The end-of-range test uses long arithmetic so it cannot overflow.
          if (offset < 0 || (long)offset + Unsafe.SizeOf<T>() > Unsafe.SizeOf<TState>())
          {
              throw new ArgumentOutOfRangeException(nameof(offset));
          }

          return ref Unsafe.As<TState, T>(ref Unsafe.AddByteOffset(ref State, (IntPtr)offset));
      }

      /// <summary>Rounds <paramref name="offset"/> down to a multiple of <see cref="RegisterSize"/>.</summary>
      private static int Align(int offset)
      {
          return offset & ~(RegisterSize - 1);
      }
  }
  97. }