diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index e2c0266..1118767 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -21,6 +21,8 @@ static const std::uint32_t Undefined = UINT32_MAX; class Op; class Operand; +typedef const Op* Ref; + class Module { public: explicit Module(); @@ -47,110 +49,109 @@ public: void SetMemoryModel(spv::AddressingModel addressing_model, spv::MemoryModel memory_model); /// Adds an entry point. - void AddEntryPoint(spv::ExecutionModel execution_model, const Op* entry_point, - const std::string& name, const std::vector& interfaces = {}); + void AddEntryPoint(spv::ExecutionModel execution_model, Ref entry_point, + const std::string& name, const std::vector& interfaces = {}); /** * Adds an instruction to module's code * @param op Instruction to insert into code. Types and constants must not be emitted. * @return Returns op. */ - const Op* Emit(const Op* op); + Ref Emit(Ref op); // Types /// Returns type void. - const Op* TypeVoid(); + Ref TypeVoid(); /// Returns type bool. - const Op* TypeBool(); + Ref TypeBool(); /// Returns type integer. - const Op* TypeInt(int width, bool is_signed); + Ref TypeInt(int width, bool is_signed); /// Returns type float. - const Op* TypeFloat(int width); + Ref TypeFloat(int width); /// Returns type vector. - const Op* TypeVector(const Op* component_type, int component_count); + Ref TypeVector(Ref component_type, int component_count); /// Returns type matrix. - const Op* TypeMatrix(const Op* column_type, int column_count); + Ref TypeMatrix(Ref column_type, int column_count); /// Returns type image. - const Op* TypeImage(const Op* sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, + Ref TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, int sampled, spv::ImageFormat image_format, spv::AccessQualifier access_qualifier = static_cast(Undefined)); /// Returns type sampler. - const Op* TypeSampler(); + Ref TypeSampler(); /// Returns type sampled image. - const Op* TypeSampledImage(const Op* image_type); + Ref TypeSampledImage(Ref image_type); /// Returns type array. - const Op* TypeArray(const Op* element_type, const Op* length); + Ref TypeArray(Ref element_type, Ref length); /// Returns type runtime array. - const Op* TypeRuntimeArray(const Op* element_type); + Ref TypeRuntimeArray(Ref element_type); /// Returns type struct. - const Op* TypeStruct(const std::vector& members = {}); + Ref TypeStruct(const std::vector& members = {}); /// Returns type opaque. - const Op* TypeOpaque(const std::string& name); + Ref TypeOpaque(const std::string& name); /// Returns type pointer. - const Op* TypePointer(spv::StorageClass storage_class, const Op* type); + Ref TypePointer(spv::StorageClass storage_class, Ref type); /// Returns type function. - const Op* TypeFunction(const Op* return_type, const std::vector& arguments = {}); + Ref TypeFunction(Ref return_type, const std::vector& arguments = {}); /// Returns type event. - const Op* TypeEvent(); + Ref TypeEvent(); /// Returns type device event. - const Op* TypeDeviceEvent(); + Ref TypeDeviceEvent(); /// Returns type reserve id. - const Op* TypeReserveId(); + Ref TypeReserveId(); /// Returns type queue. - const Op* TypeQueue(); + Ref TypeQueue(); /// Returns type pipe. - const Op* TypePipe(spv::AccessQualifier access_qualifier); + Ref TypePipe(spv::AccessQualifier access_qualifier); // Constant /// Returns a true scalar constant. - const Op* ConstantTrue(const Op* result_type); + Ref ConstantTrue(Ref result_type); /// Returns a false scalar constant. - const Op* ConstantFalse(const Op* result_type); + Ref ConstantFalse(Ref result_type); /// Returns a numeric scalar constant. - const Op* Constant(const Op* result_type, Operand* literal); + Ref Constant(Ref result_type, Operand* literal); /// Returns a numeric scalar constant. - const Op* ConstantComposite(const Op* result_type, const std::vector& constituents); + Ref ConstantComposite(Ref result_type, const std::vector& constituents); // Function /// Emits a function. - const Op* Function(const Op* result_type, spv::FunctionControlMask function_control, - const Op* function_type); + Ref Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type); /// Emits a function end. - const Op* FunctionEnd(); + Ref FunctionEnd(); // Flow /// Emits a label. It starts a block. - const Op* Label(); + Ref Label(); /// Emits a return. It ends a block. - const Op* Return(); + Ref Return(); // Literals static Operand* Literal(std::uint32_t value); @@ -161,11 +162,11 @@ public: static Operand* Literal(double value); private: - const Op* AddCode(Op* op); + Ref AddCode(Op* op); - const Op* AddCode(spv::Op opcode, std::uint32_t id = UINT32_MAX); + Ref AddCode(spv::Op opcode, std::uint32_t id = UINT32_MAX); - const Op* AddDeclaration(Op* op); + Ref AddDeclaration(Op* op); std::uint32_t bound{1}; @@ -188,7 +189,7 @@ private: std::vector> declarations; - std::vector code; + std::vector code; std::vector> code_store; }; diff --git a/src/op.cpp b/src/op.cpp index 52f0cc4..a8d5d2a 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -11,7 +11,7 @@ namespace Sirit { -Op::Op(spv::Op opcode_, u32 id_, const Op* result_type_) +Op::Op(spv::Op opcode_, u32 id_, Ref result_type_) : opcode(opcode_), id(id_), result_type(result_type_) { operand_type = OperandType::Op; } @@ -76,8 +76,8 @@ void Op::Add(const std::string& string) { Add(new LiteralString(string)); } -void Op::Add(const std::vector& ids) { - for (const Op* op : ids) { +void Op::Add(const std::vector& ids) { + for (Ref op : ids) { Add(op); } } diff --git a/src/op.h b/src/op.h index 5209eb4..8c884bc 100644 --- a/src/op.h +++ b/src/op.h @@ -15,7 +15,7 @@ namespace Sirit { class Op : public Operand { public: - explicit Op(spv::Op opcode, u32 id = UINT32_MAX, const Op* result_type = nullptr); + explicit Op(spv::Op opcode, u32 id = UINT32_MAX, Ref result_type = nullptr); ~Op(); virtual void Fetch(Stream& stream) const; @@ -33,14 +33,14 @@ public: void Add(const std::string& string); - void Add(const std::vector& ids); + void Add(const std::vector& ids); private: u16 WordCount() const; spv::Op opcode; - const Op* result_type; + Ref result_type; u32 id; diff --git a/src/opcodes/constant.cpp b/src/opcodes/constant.cpp index 764253d..aacd8a9 100644 --- a/src/opcodes/constant.cpp +++ b/src/opcodes/constant.cpp @@ -10,22 +10,21 @@ namespace Sirit { -const Op* Module::ConstantTrue(const Op* result_type) { +Ref Module::ConstantTrue(Ref result_type) { return AddDeclaration(new Op(spv::Op::OpConstantTrue, bound, result_type)); } -const Op* Module::ConstantFalse(const Op* result_type) { +Ref Module::ConstantFalse(Ref result_type) { return AddDeclaration(new Op(spv::Op::OpConstantFalse, bound, result_type)); } -const Op* Module::Constant(const Op* result_type, Operand* literal) { +Ref Module::Constant(Ref result_type, Operand* literal) { Op* op{new Op(spv::Op::OpConstant, bound, result_type)}; op->Add(literal); return AddDeclaration(op); } -const Op* Module::ConstantComposite(const Op* result_type, - const std::vector& constituents) { +Ref Module::ConstantComposite(Ref result_type, const std::vector& constituents) { Op* op{new Op(spv::Op::OpConstantComposite, bound, result_type)}; op->Add(constituents); return AddDeclaration(op); diff --git a/src/opcodes/flow.cpp b/src/opcodes/flow.cpp index 2351cd5..bfe4f17 100644 --- a/src/opcodes/flow.cpp +++ b/src/opcodes/flow.cpp @@ -9,11 +9,11 @@ namespace Sirit { -const Op* Module::Label() { +Ref Module::Label() { return AddCode(spv::Op::OpLabel, bound++); } -const Op* Module::Return() { +Ref Module::Return() { return AddCode(spv::Op::OpReturn); } diff --git a/src/opcodes/function.cpp b/src/opcodes/function.cpp index da2f932..345e6d6 100644 --- a/src/opcodes/function.cpp +++ b/src/opcodes/function.cpp @@ -9,15 +9,14 @@ namespace Sirit { -const Op* Module::Function(const Op* result_type, spv::FunctionControlMask function_control, - const Op* function_type) { +Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) { Op* op{new Op{spv::Op::OpFunction, bound++, result_type}}; op->Add(static_cast(function_control)); op->Add(function_type); return AddCode(op); } -const Op* Module::FunctionEnd() { +Ref Module::FunctionEnd() { return AddCode(spv::Op::OpFunctionEnd); } diff --git a/src/opcodes/type.cpp b/src/opcodes/type.cpp index afae9fd..e7c16fb 100644 --- a/src/opcodes/type.cpp +++ b/src/opcodes/type.cpp @@ -10,15 +10,15 @@ namespace Sirit { -const Op* Module::TypeVoid() { +Ref Module::TypeVoid() { return AddDeclaration(new Op(spv::Op::OpTypeVoid, bound)); } -const Op* Module::TypeBool() { +Ref Module::TypeBool() { return AddDeclaration(new Op(spv::Op::OpTypeBool, bound)); } -const Op* Module::TypeInt(int width, bool is_signed) { +Ref Module::TypeInt(int width, bool is_signed) { if (width == 8) { AddCapability(spv::Capability::Int8); } else if (width == 16) { @@ -32,7 +32,7 @@ const Op* Module::TypeInt(int width, bool is_signed) { return AddDeclaration(op); } -const Op* Module::TypeFloat(int width) { +Ref Module::TypeFloat(int width) { if (width == 16) { AddCapability(spv::Capability::Float16); } else if (width == 64) { @@ -43,7 +43,7 @@ const Op* Module::TypeFloat(int width) { return AddDeclaration(op); } -const Op* Module::TypeVector(const Op* component_type, int component_count) { +Ref Module::TypeVector(Ref component_type, int component_count) { assert(component_count >= 2); Op* op{new Op(spv::Op::OpTypeVector, bound)}; op->Add(component_type); @@ -51,7 +51,7 @@ const Op* Module::TypeVector(const Op* component_type, int component_count) { return AddDeclaration(op); } -const Op* Module::TypeMatrix(const Op* column_type, int column_count) { +Ref Module::TypeMatrix(Ref column_type, int column_count) { assert(column_count >= 2); AddCapability(spv::Capability::Matrix); Op* op{new Op(spv::Op::OpTypeMatrix, bound)}; @@ -60,7 +60,7 @@ const Op* Module::TypeMatrix(const Op* column_type, int column_count) { return AddDeclaration(op); } -const Op* Module::TypeImage(const Op* sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, +Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms, int sampled, spv::ImageFormat image_format, spv::AccessQualifier access_qualifier) { switch (dim) { @@ -138,44 +138,44 @@ const Op* Module::TypeImage(const Op* sampled_type, spv::Dim dim, int depth, boo return AddDeclaration(op); } -const Op* Module::TypeSampler() { +Ref Module::TypeSampler() { return AddDeclaration(new Op(spv::Op::OpTypeSampler, bound)); } -const Op* Module::TypeSampledImage(const Op* image_type) { +Ref Module::TypeSampledImage(Ref image_type) { Op* op{new Op(spv::Op::OpTypeSampledImage, bound)}; op->Add(image_type); return AddDeclaration(op); } -const Op* Module::TypeArray(const Op* element_type, const Op* length) { +Ref Module::TypeArray(Ref element_type, Ref length) { Op* op{new Op(spv::Op::OpTypeArray, bound)}; op->Add(element_type); op->Add(length); return AddDeclaration(op); } -const Op* Module::TypeRuntimeArray(const Op* element_type) { +Ref Module::TypeRuntimeArray(Ref element_type) { AddCapability(spv::Capability::Shader); Op* op{new Op(spv::Op::OpTypeRuntimeArray, bound)}; op->Add(element_type); return AddDeclaration(op); } -const Op* Module::TypeStruct(const std::vector& members) { +Ref Module::TypeStruct(const std::vector& members) { Op* op{new Op(spv::Op::OpTypeStruct, bound)}; op->Add(members); return AddDeclaration(op); } -const Op* Module::TypeOpaque(const std::string& name) { +Ref Module::TypeOpaque(const std::string& name) { AddCapability(spv::Capability::Kernel); Op* op{new Op(spv::Op::OpTypeOpaque, bound)}; op->Add(name); return AddDeclaration(op); } -const Op* Module::TypePointer(spv::StorageClass storage_class, const Op* type) { +Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) { switch (storage_class) { case spv::StorageClass::Uniform: case spv::StorageClass::Output: @@ -197,34 +197,34 @@ const Op* Module::TypePointer(spv::StorageClass storage_class, const Op* type) { return AddDeclaration(op); } -const Op* Module::TypeFunction(const Op* return_type, const std::vector& arguments) { +Ref Module::TypeFunction(Ref return_type, const std::vector& arguments) { Op* op{new Op(spv::Op::OpTypeFunction, bound)}; op->Add(return_type); op->Add(arguments); return AddDeclaration(op); } -const Op* Module::TypeEvent() { +Ref Module::TypeEvent() { AddCapability(spv::Capability::Kernel); return AddDeclaration(new Op(spv::Op::OpTypeEvent, bound)); } -const Op* Module::TypeDeviceEvent() { +Ref Module::TypeDeviceEvent() { AddCapability(spv::Capability::DeviceEnqueue); return AddDeclaration(new Op(spv::Op::OpTypeDeviceEvent, bound)); } -const Op* Module::TypeReserveId() { +Ref Module::TypeReserveId() { AddCapability(spv::Capability::Pipes); return AddDeclaration(new Op(spv::Op::OpTypeReserveId, bound)); } -const Op* Module::TypeQueue() { +Ref Module::TypeQueue() { AddCapability(spv::Capability::DeviceEnqueue); return AddDeclaration(new Op(spv::Op::OpTypeQueue, bound)); } -const Op* Module::TypePipe(spv::AccessQualifier access_qualifier) { +Ref Module::TypePipe(spv::AccessQualifier access_qualifier) { AddCapability(spv::Capability::Pipes); Op* op{new Op(spv::Op::OpTypePipe, bound)}; op->Add(static_cast(access_qualifier)); diff --git a/src/sirit.cpp b/src/sirit.cpp index faec175..a9b1ec7 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -73,8 +73,8 @@ void Module::SetMemoryModel(spv::AddressingModel addressing_model, spv::MemoryMo this->memory_model = memory_model; } -void Module::AddEntryPoint(spv::ExecutionModel execution_model, const Op* entry_point, - const std::string& name, const std::vector& interfaces) { +void Module::AddEntryPoint(spv::ExecutionModel execution_model, Ref entry_point, + const std::string& name, const std::vector& interfaces) { Op* op{new Op(spv::Op::OpEntryPoint)}; op->Add(static_cast(execution_model)); op->Add(entry_point); @@ -83,22 +83,22 @@ void Module::AddEntryPoint(spv::ExecutionModel execution_model, const Op* entry_ entry_points.push_back(std::unique_ptr(op)); } -const Op* Module::Emit(const Op* op) { +Ref Module::Emit(Ref op) { assert(op); code.push_back(op); return op; } -const Op* Module::AddCode(Op* op) { +Ref Module::AddCode(Op* op) { code_store.push_back(std::unique_ptr(op)); return op; } -const Op* Module::AddCode(spv::Op opcode, u32 id) { +Ref Module::AddCode(spv::Op opcode, u32 id) { return AddCode(new Op{opcode, id}); } -const Op* Module::AddDeclaration(Op* op) { +Ref Module::AddDeclaration(Op* op) { const auto& found{std::find_if(declarations.begin(), declarations.end(), [=](const auto& other) { return *other == *op; })};