IoMap.cs 5.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. using Ryujinx.Common.Logging;
  2. using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  3. using Ryujinx.Graphics.Shader.Translation;
  4. using System.Globalization;
  5. namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
  6. {
  7. static class IoMap
  8. {
  9. public static (string, AggregateType) GetMslBuiltIn(
  10. ShaderDefinitions definitions,
  11. IoVariable ioVariable,
  12. int location,
  13. int component,
  14. bool isOutput,
  15. bool isPerPatch)
  16. {
  17. (string, AggregateType) returnValue = ioVariable switch
  18. {
  19. IoVariable.BaseInstance => ("base_instance", AggregateType.U32),
  20. IoVariable.BaseVertex => ("base_vertex", AggregateType.U32),
  21. IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
  22. IoVariable.ClipDistance => ("out.clip_distance", AggregateType.Array | AggregateType.FP32),
  23. IoVariable.FragmentOutputColor => ($"out.color{location}", definitions.GetFragmentOutputColorType(location)),
  24. IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32),
  25. IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool),
  26. IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32),
  27. IoVariable.InstanceId => ("instance_id", AggregateType.U32),
  28. IoVariable.InstanceIndex => ("instance_index", AggregateType.U32),
  29. IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32),
  30. IoVariable.PointCoord => ("in.point_coord", AggregateType.Vector2 | AggregateType.FP32),
  31. IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
  32. IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
  33. IoVariable.PrimitiveId => ("in.primitive_id", AggregateType.U32),
  34. IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32),
  35. IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
  36. IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32),
  37. IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32),
  38. IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
  39. IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32),
  40. IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool),
  41. IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),
  42. IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32),
  43. IoVariable.VertexId => ("vertex_id", AggregateType.S32),
  44. // gl_VertexIndex does not have a direct equivalent in MSL
  45. IoVariable.VertexIndex => ("vertex_id", AggregateType.U32),
  46. IoVariable.ViewportIndex => ("viewport_array_index", AggregateType.S32),
  47. IoVariable.FragmentCoord => ("in.position", AggregateType.Vector4 | AggregateType.FP32),
  48. _ => (null, AggregateType.Invalid),
  49. };
  50. if (returnValue.Item2 == AggregateType.Invalid)
  51. {
  52. Logger.Warning?.PrintMsg(LogClass.Gpu, $"Unable to find type for IoVariable {ioVariable}!");
  53. }
  54. return returnValue;
  55. }
  56. private static (string, AggregateType) GetUserDefinedVariableName(ShaderDefinitions definitions, int location, int component, bool isOutput, bool isPerPatch)
  57. {
  58. string name = isPerPatch
  59. ? Defaults.PerPatchAttributePrefix
  60. : (isOutput ? Defaults.OAttributePrefix : Defaults.IAttributePrefix);
  61. if (location < 0)
  62. {
  63. return (name, definitions.GetUserDefinedType(0, isOutput));
  64. }
  65. name += location.ToString(CultureInfo.InvariantCulture);
  66. if (definitions.HasPerLocationInputOrOutputComponent(IoVariable.UserDefined, location, component, isOutput))
  67. {
  68. name += "_" + "xyzw"[component & 3];
  69. }
  70. string prefix = isOutput ? "out" : "in";
  71. return (prefix + "." + name, definitions.GetUserDefinedType(location, isOutput));
  72. }
  73. }
  74. }