Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/spirv-tools/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class Pass;
struct DescriptorSetAndBinding;
} // namespace opt

enum class SSARewriteMode {
None,
All,
OpaqueOnly,
SpecialTypes,
};

// C++ interface for SPIR-V optimization functionalities. It wraps the context
// (including target environment and the corresponding SPIR-V grammar) and
// provides methods for registering optimization passes and optimizing.
Expand Down Expand Up @@ -102,6 +109,8 @@ class SPIRV_TOOLS_EXPORT Optimizer {
// interface are considered live and are not eliminated.
Optimizer& RegisterPerformancePasses();
Optimizer& RegisterPerformancePasses(bool preserve_interface);
Optimizer& RegisterPerformancePassesFastCompile();
Optimizer& RegisterPerformancePassesFastCompile(bool preserve_interface);

// Registers passes that attempt to improve the size of generated code.
// This sequence of passes is subject to constant review and will change
Expand All @@ -125,6 +134,10 @@ class SPIRV_TOOLS_EXPORT Optimizer {
// interface are considered live and are not eliminated.
Optimizer& RegisterLegalizationPasses();
Optimizer& RegisterLegalizationPasses(bool preserve_interface);
Optimizer& RegisterLegalizationPassesFastCompile();
Optimizer& RegisterLegalizationPassesFastCompile(
bool preserve_interface, bool include_loop_unroll,
SSARewriteMode ssa_rewrite_mode);

// Register passes specified in the list of |flags|. Each flag must be a
// string of a form accepted by Optimizer::FlagHasValidForm().
Expand Down Expand Up @@ -710,6 +723,7 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
// Only variables that are local to the function and of supported types are
// processed (see IsSSATargetVar for details).
Optimizer::PassToken CreateSSARewritePass();
Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode);

// Create pass to convert relaxed precision instructions to half precision.
// This pass converts as many relaxed float32 arithmetic operations to half as
Expand Down
49 changes: 49 additions & 0 deletions source/opt/dead_variable_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "source/opt/dead_variable_elimination.h"

#include <unordered_set>
#include <vector>

#include "source/opt/ir_context.h"
Expand Down Expand Up @@ -77,9 +78,43 @@ Pass::Status DeadVariableElimination::Process() {
DeleteVariable(result_id);
}
}

ids_to_remove.clear();
for (auto& function : *get_module()) {
if (function.IsDeclaration()) continue;

auto& entry = *function.begin();
for (auto inst = entry.begin(); inst != entry.end(); ++inst) {
if (inst->opcode() != spv::Op::OpVariable) break;
if (!IsFunctionLocalVariable(&*inst)) continue;
if (IsLiveVar(inst->result_id())) continue;
ids_to_remove.push_back(inst->result_id());
}
}

if (!ids_to_remove.empty()) {
modified = true;
for (auto result_id : ids_to_remove) {
DeleteLocalVariable(result_id);
}
}

return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}

bool DeadVariableElimination::IsFunctionLocalVariable(
const Instruction* inst) const {
if (inst->opcode() != spv::Op::OpVariable) return false;

const Instruction* type_inst = get_def_use_mgr()->GetDef(inst->type_id());
if (type_inst == nullptr || type_inst->opcode() != spv::Op::OpTypePointer) {
return false;
}

return spv::StorageClass(type_inst->GetSingleWordInOperand(0)) ==
spv::StorageClass::Function;
}

void DeadVariableElimination::DeleteVariable(uint32_t result_id) {
Instruction* inst = get_def_use_mgr()->GetDef(result_id);
assert(inst->opcode() == spv::Op::OpVariable &&
Expand Down Expand Up @@ -108,5 +143,19 @@ void DeadVariableElimination::DeleteVariable(uint32_t result_id) {
}
context()->KillDef(result_id);
}

void DeadVariableElimination::DeleteLocalVariable(uint32_t result_id) {
std::queue<Instruction*> dead_stores;
std::unordered_set<Instruction*> processed;
AddStores(result_id, &dead_stores);
while (!dead_stores.empty()) {
Instruction* inst = dead_stores.front();
dead_stores.pop();
if (!processed.insert(inst).second) continue;
DCEInst(inst, nullptr);
}

context()->KillDef(result_id);
}
} // namespace opt
} // namespace spvtools
2 changes: 2 additions & 0 deletions source/opt/dead_variable_elimination.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DeadVariableElimination : public MemPass {
private:
// Deletes the OpVariable instruction who result id is |result_id|.
void DeleteVariable(uint32_t result_id);
void DeleteLocalVariable(uint32_t result_id);
bool IsFunctionLocalVariable(const Instruction* inst) const;

// Keeps track of the number of references of an id. Once that value is 0, it
// is safe to remove the corresponding instruction.
Expand Down
23 changes: 23 additions & 0 deletions source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,18 @@ bool LocalSingleStoreElimPass::RewriteLoads(
else
stored_id = store_inst->GetSingleWordInOperand(kVariableInitIdInIdx);

const auto get_image_pointer_id = [this](uint32_t value_id) {
Instruction* value_inst = context()->get_def_use_mgr()->GetDef(value_id);
while (value_inst && value_inst->opcode() == spv::Op::OpCopyObject) {
value_id = value_inst->GetSingleWordInOperand(0);
value_inst = context()->get_def_use_mgr()->GetDef(value_id);
}
if (!value_inst || value_inst->opcode() != spv::Op::OpLoad) {
return uint32_t{0};
}
return value_inst->GetSingleWordInOperand(0);
};

*all_rewritten = true;
bool modified = false;
for (Instruction* use : uses) {
Expand All @@ -319,6 +331,17 @@ bool LocalSingleStoreElimPass::RewriteLoads(
context()->KillNamesAndDecorates(use->result_id());
context()->ReplaceAllUsesWith(use->result_id(), stored_id);
context()->KillInst(use);
} else if (use->opcode() == spv::Op::OpImageTexelPointer &&
dominator_analysis->Dominates(store_inst, use)) {
const uint32_t image_ptr_id = get_image_pointer_id(stored_id);
if (image_ptr_id == 0) {
*all_rewritten = false;
continue;
}
modified = true;
context()->ForgetUses(use);
use->SetInOperand(0, {image_ptr_id});
context()->AnalyzeUses(use);
} else {
*all_rewritten = false;
}
Expand Down
27 changes: 25 additions & 2 deletions source/opt/mem_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,27 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
}

bool MemPass::IsTargetType(const Instruction* typeInst) const {
if (IsBaseTargetType(typeInst)) return true;
switch (ssa_rewrite_mode_) {
case SSARewriteMode::None:
return false;
case SSARewriteMode::OpaqueOnly:
if (typeInst->IsOpaqueType()) return true;
break;
case SSARewriteMode::SpecialTypes:
switch (typeInst->opcode()) {
case spv::Op::OpTypePointer:
case spv::Op::OpTypeUntypedPointerKHR:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
break;
}
break;
case SSARewriteMode::All:
if (IsBaseTargetType(typeInst)) return true;
break;
}
if (typeInst->opcode() == spv::Op::OpTypeArray) {
if (!IsTargetType(
get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
Expand Down Expand Up @@ -198,7 +218,7 @@ bool MemPass::IsLiveVar(uint32_t varId) const {
void MemPass::AddStores(uint32_t ptr_id, std::queue<Instruction*>* insts) {
get_def_use_mgr()->ForEachUser(ptr_id, [this, insts](Instruction* user) {
spv::Op op = user->opcode();
if (IsNonPtrAccessChain(op)) {
if (IsNonPtrAccessChain(op) || op == spv::Op::OpCopyObject) {
AddStores(user->result_id(), insts);
} else if (op == spv::Op::OpStore) {
insts->push(user);
Expand Down Expand Up @@ -243,6 +263,9 @@ void MemPass::DCEInst(Instruction* inst,

MemPass::MemPass() {}

MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
: ssa_rewrite_mode_(ssa_rewrite_mode) {}

bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {
auto dbg_op = user->GetCommonDebugOpcode();
Expand Down
6 changes: 5 additions & 1 deletion source/opt/mem_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <unordered_set>
#include <utility>

#include "spirv-tools/optimizer.hpp"
#include "source/opt/basic_block.h"
#include "source/opt/def_use_manager.h"
#include "source/opt/dominator_analysis.h"
Expand Down Expand Up @@ -69,6 +70,7 @@ class MemPass : public Pass {

protected:
MemPass();
explicit MemPass(SSARewriteMode ssa_rewrite_mode);

// Returns true if |typeInst| is a scalar type
// or a vector or matrix
Expand Down Expand Up @@ -133,7 +135,9 @@ class MemPass : public Pass {
// Cache of verified non-target vars
std::unordered_set<uint32_t> seen_non_target_vars_;

private:
private:
SSARewriteMode ssa_rewrite_mode_ = SSARewriteMode::All;

// Return true if all uses of |varId| are only through supported reference
// operations ie. loads and store. Also cache in supported_ref_vars_.
// TODO(dnovillo): This function is replicated in other passes and it's
Expand Down
122 changes: 122 additions & 0 deletions source/opt/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,72 @@ Optimizer& Optimizer::RegisterLegalizationPasses() {
return RegisterLegalizationPasses(false);
}

Optimizer& Optimizer::RegisterLegalizationPassesFastCompile(
bool preserve_interface, bool include_loop_unroll,
SSARewriteMode ssa_rewrite_mode) {
auto& optimizer =
// Wrap OpKill instructions so all other code can be inlined.
RegisterPass(CreateWrapOpKillPass())
// Remove unreachable block so that merge return works.
.RegisterPass(CreateDeadBranchElimPass())
// Merge the returns so we can inline.
.RegisterPass(CreateMergeReturnPass())
// Make sure uses and definitions are in the same function.
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateEliminateDeadFunctionsPass());
optimizer.RegisterPass(CreatePrivateToLocalPass());
// Fix up the storage classes that DXC may have purposely generated
// incorrectly. All functions are inlined, and a lot of dead code has
// been removed.
optimizer.RegisterPass(CreateFixStorageClassPass());
// Propagate the value stored to the loads in very simple cases.
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateSSARewritePass(SSARewriteMode::SpecialTypes));
optimizer
// Split up aggregates so they are easier to deal with.
.RegisterPass(CreateScalarReplacementPass(0));
// Remove loads and stores so everything is in intermediate values.
// Takes care of copy propagation of non-members.
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
if (ssa_rewrite_mode != SSARewriteMode::None) {
optimizer.RegisterPass(CreateSSARewritePass(ssa_rewrite_mode));
}
optimizer
// Propagate constants to get as many constant conditions on branches
// as possible.
.RegisterPass(CreateCCPPass());
if (include_loop_unroll) {
optimizer.RegisterPass(CreateLoopUnrollPass(true));
}
optimizer.RegisterPass(CreateDeadBranchElimPass())
// Copy propagate members. Cleans up code sequences generated by scalar
// replacement. Also important for removing OpPhi nodes.
.RegisterPass(CreateSimplificationPass());
return optimizer
// May need loop unrolling here see
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
// Get rid of unused code that contain traces of illegal code
// or unused references to unbound external objects
.RegisterPass(CreateVectorDCEPass())
.RegisterPass(CreateDeadInsertElimPass())
.RegisterPass(CreateReduceLoadSizePass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateDeadVariableEliminationPass())
.RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
.RegisterPass(CreateInterpolateFixupPass())
.RegisterPass(CreateInvocationInterlockPlacementPass())
.RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
}

Optimizer& Optimizer::RegisterLegalizationPassesFastCompile() {
return RegisterLegalizationPassesFastCompile(false, true,
SSARewriteMode::All);
}

Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
return RegisterPass(CreateWrapOpKillPass())
.RegisterPass(CreateDeadBranchElimPass())
Expand Down Expand Up @@ -231,6 +297,57 @@ Optimizer& Optimizer::RegisterPerformancePasses() {
return RegisterPerformancePasses(false);
}

Optimizer& Optimizer::RegisterPerformancePassesFastCompile(
bool preserve_interface) {
auto& optimizer = RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateDeadVariableEliminationPass())
.RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
.RegisterPass(CreateWrapOpKillPass())
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateMergeReturnPass())
.RegisterPass(CreateInlineExhaustivePass())
.RegisterPass(CreateEliminateDeadFunctionsPass())
.RegisterPass(CreatePrivateToLocalPass())
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateScalarReplacementPass(0))
.RegisterPass(CreateLocalAccessChainConvertPass());
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
optimizer.RegisterPass(CreateCCPPass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
optimizer.RegisterPass(CreateDeadBranchElimPass());
optimizer.RegisterPass(CreateLocalRedundancyEliminationPass());
optimizer.RegisterPass(CreateCombineAccessChainsPass())
.RegisterPass(CreateSimplificationPass())
.RegisterPass(CreateScalarReplacementPass(0))
.RegisterPass(CreateLocalAccessChainConvertPass())
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
.RegisterPass(CreateLocalSingleStoreElimPass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateVectorDCEPass())
.RegisterPass(CreateDeadInsertElimPass())
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateSimplificationPass())
.RegisterPass(CreateIfConversionPass())
.RegisterPass(CreateCopyPropagateArraysPass())
.RegisterPass(CreateReduceLoadSizePass())
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateBlockMergePass());
optimizer.RegisterPass(CreateLocalRedundancyEliminationPass());
return optimizer.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
.RegisterPass(CreateDeadBranchElimPass())
.RegisterPass(CreateBlockMergePass())
.RegisterPass(CreateSimplificationPass());
}

Optimizer& Optimizer::RegisterPerformancePassesFastCompile() {
return RegisterPerformancePassesFastCompile(false);
}

Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) {
return RegisterPass(CreateWrapOpKillPass())
.RegisterPass(CreateDeadBranchElimPass())
Expand Down Expand Up @@ -1024,6 +1141,11 @@ Optimizer::PassToken CreateSSARewritePass() {
MakeUnique<opt::SSARewritePass>());
}

Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode) {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::SSARewritePass>(mode));
}

Optimizer::PassToken CreateCopyPropagateArraysPass() {
return MakeUnique<Optimizer::PassToken::Impl>(
MakeUnique<opt::CopyPropagateArrays>());
Expand Down
Loading
Loading