blob: a25ad5a6946d42e2e3e02552e631d0c52921614b [file]
// Copyright 2021 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "compiler/plugins/target/WebGPUSPIRV/SPIRVToWGSL.h"
#include <string>
#include <utility>
#include <vector>
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
#include "src/tint/lang/core/ir/disassembler.h"
#include "src/tint/lang/spirv/reader/reader.h"
#include "src/tint/lang/wgsl/program/program.h"
#include "src/tint/lang/wgsl/writer/writer.h"
namespace mlir::iree_compiler::IREE::HAL {
namespace {
// Value expected in word 0 of a SPIR-V binary; used when summarizing a module
// to flag any other leading word.
constexpr uint32_t kSPIRVMagicNumber{0x07230203u};
// Writes a "Tint failed to <stage>" diagnostic to stderr.
//
// `stage` completes the sentence "Tint failed to ...". `reason` is the
// explanation to print underneath; when it is empty the message is closed
// with a period instead. The emitted text always ends with a newline: one is
// appended only if `reason` does not already end with one.
void printTintFailure(llvm::StringRef stage, llvm::StringRef reason) {
  llvm::raw_ostream &os = llvm::errs();
  os << "Tint failed to " << stage;
  if (!reason.empty()) {
    os << ":\n" << reason;
    const bool hasTrailingNewline = reason.ends_with('\n');
    if (!hasTrailingNewline) {
      os << "\n";
    }
  } else {
    os << ".\n";
  }
}
// Prints a short summary of `spvBinary` to stderr: its size in words and
// bytes, followed by the five leading header words (magic, version,
// generator, bound, schema). If fewer than five words are present only the
// size and an "incomplete header" note are printed.
void printSPIRVModuleSummary(llvm::ArrayRef<uint32_t> spvBinary) {
  llvm::raw_ostream &os = llvm::errs();
  const auto wordCount = spvBinary.size();

  os << "SPIR-V module summary:\n";
  os << " words: " << wordCount << " (" << wordCount * sizeof(uint32_t)
     << " bytes)\n";
  if (wordCount < 5) {
    os << " header: incomplete; expected at least 5 words\n";
    return;
  }

  // Width 10 covers the "0x" prefix plus eight hex digits of a 32-bit word.
  const auto asHex = [](uint32_t word) {
    return llvm::format_hex(word, 10);
  };

  const uint32_t magicWord = spvBinary[0];
  const uint32_t versionWord = spvBinary[1];
  const uint32_t generatorWord = spvBinary[2];
  const uint32_t boundWord = spvBinary[3];
  const uint32_t schemaWord = spvBinary[4];

  // Only mention the expected magic number when the actual one differs.
  os << " magic: " << asHex(magicWord);
  if (magicWord != kSPIRVMagicNumber) {
    os << " (expected " << asHex(kSPIRVMagicNumber) << ")";
  }
  os << "\n";

  // The version word is laid out as 0 | major | minor | 0 (high to low byte)
  // in the SPIR-V header, so bits 16-23 and 8-15 give "major.minor".
  const uint32_t majorVersion = (versionWord >> 16) & 0xFF;
  const uint32_t minorVersion = (versionWord >> 8) & 0xFF;
  os << " version: " << majorVersion << "." << minorVersion << " ("
     << asHex(versionWord) << ")\n";

  os << " generator: " << asHex(generatorWord) << "\n";
  os << " bound: " << boundWord << "\n";
  os << " schema: " << asHex(schemaWord) << "\n";
}
// Prints the plain-text disassembly of `irModule` to stderr under a
// "Tint IR at failure:" banner, followed by a newline.
void printTintIRAtFailure(const tint::core::ir::Module &irModule) {
  llvm::raw_ostream &os = llvm::errs();
  // Emit the banner before disassembling so it is visible even if the
  // disassembler itself produces output or aborts.
  os << "Tint IR at failure:\n";
  const auto disassembly = tint::core::ir::Disassembler(irModule).Plain();
  os << disassembly << "\n";
}
} // namespace
// Translates a SPIR-V binary to WGSL source text using Tint.
//
// The translation runs in three stages: SPIR-V -> Tint IR, Tint IR -> WGSL
// program, WGSL program -> text. If any stage fails, a diagnostic naming the
// stage, a summary of the SPIR-V module, and (once the IR exists) the IR
// disassembly are written to stderr and std::nullopt is returned. On success
// the generated WGSL string is returned.
std::optional<std::string>
compileSPIRVToWGSL(llvm::ArrayRef<uint32_t> spvBinary) {
  // Common failure path. `irModule` is null while no IR has been produced yet.
  const auto reportFailure =
      [&spvBinary](llvm::StringRef stage, llvm::StringRef reason,
                   const tint::core::ir::Module *irModule)
      -> std::optional<std::string> {
    printTintFailure(stage, reason);
    printSPIRVModuleSummary(spvBinary);
    if (irModule != nullptr) {
      printTintIRAtFailure(*irModule);
    }
    return std::nullopt;
  };

  // Stage 1: parse the SPIR-V words into Tint IR. Tint's reader takes a
  // std::vector, so the words are copied out of the ArrayRef first.
  const std::vector<uint32_t> spirvWords = spvBinary.vec();
  auto parsedIR = tint::spirv::reader::ReadIR(spirvWords);
  if (parsedIR != tint::Success) {
    return reportFailure("parse SPIR-V into IR", parsedIR.Failure().reason,
                         /*irModule=*/nullptr);
  }
  tint::core::ir::Module irModule = parsedIR.Move();

  // Stage 2: lower the IR to a WGSL program.
  tint::wgsl::writer::Options writerOptions;
  auto loweredProgram =
      tint::wgsl::writer::ProgramFromIR(irModule, writerOptions);
  if (loweredProgram != tint::Success) {
    return reportFailure("lower Tint IR to a WGSL program",
                         loweredProgram.Failure().reason, &irModule);
  }

  // Stage 3: print the program as WGSL text.
  auto generated =
      tint::wgsl::writer::Generate(loweredProgram.Get(), writerOptions);
  if (generated != tint::Success) {
    return reportFailure("print WGSL", generated.Failure().reason, &irModule);
  }

  // The string is a member of the result object, so an explicit move is
  // needed to avoid copying it into the returned optional.
  return std::move(generated->wgsl);
}
} // namespace mlir::iree_compiler::IREE::HAL