| // Copyright 2023 The IREE Authors |
| // |
| // Licensed under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| // Implements logic for lowering StableHLO random number generation to Linalg |
| // dialect. |
| |
| #include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h" |
| #include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h" |
| #include "mlir/Dialect/Linalg/IR/Linalg.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "stablehlo/dialect/StablehloOps.h" |
| |
| namespace mlir::iree_compiler::stablehlo { |
| namespace { |
| class ArithOpBuilder { |
| public: |
| ArithOpBuilder(OpBuilder b, Location l, Value v) |
| : builder(b), loc(l), value(v) {} |
| |
| explicit operator Value() { return value; } |
| Value val() { return value; } |
| |
| ArithOpBuilder constantI(int64_t value, int64_t bits) { |
| Value val = builder.create<arith::ConstantOp>( |
| loc, builder.getIntegerAttr(builder.getIntegerType(bits), value)); |
| return ArithOpBuilder(builder, loc, val); |
| } |
| |
| ArithOpBuilder extendUI(int32_t bits) { |
| Value ext = builder.create<arith::ExtUIOp>( |
| loc, builder.getIntegerType(bits), value); |
| return ArithOpBuilder(builder, loc, ext); |
| } |
| |
| ArithOpBuilder truncI(int64_t bits) { |
| if (value.getType().getIntOrFloatBitWidth() == bits) |
| return *this; |
| Value trunc = builder.create<arith::TruncIOp>( |
| loc, builder.getIntegerType(bits), value); |
| return ArithOpBuilder(builder, loc, trunc); |
| } |
| |
| ArithOpBuilder linalgIndex(int32_t index) { |
| Value val = builder.create<linalg::IndexOp>(loc, index); |
| return ArithOpBuilder(builder, loc, val); |
| } |
| |
| ArithOpBuilder indexCast(int32_t bitwidth) { |
| if (isa<IntegerType>(value.getType())) { |
| Value cast = builder.create<arith::IndexCastOp>( |
| loc, builder.getIndexType(), value); |
| return ArithOpBuilder(builder, loc, cast); |
| } |
| |
| Value cast = builder.create<arith::IndexCastOp>( |
| loc, builder.getIntegerType(bitwidth), value); |
| return ArithOpBuilder(builder, loc, cast); |
| } |
| |
| ArithOpBuilder rotateLeft(int32_t rotation) { |
| int32_t bits = value.getType().getIntOrFloatBitWidth(); |
| ArithOpBuilder cLeft = constantI(rotation, bits); |
| ArithOpBuilder cRight = constantI(bits - rotation, bits); |
| ArithOpBuilder rLeft = (*this << cLeft); |
| ArithOpBuilder rRight = (*this >> cRight); |
| return rLeft | rRight; |
| } |
| |
| ArithOpBuilder operator+(ArithOpBuilder &rhs) { |
| Value res = builder.create<arith::AddIOp>(loc, value, rhs.value); |
| return ArithOpBuilder(builder, loc, res); |
| } |
| |
| ArithOpBuilder operator*(ArithOpBuilder &rhs) { |
| Value res = builder.create<arith::MulIOp>(loc, value, rhs.value); |
| return ArithOpBuilder(builder, loc, res); |
| } |
| |
| ArithOpBuilder operator|(ArithOpBuilder &rhs) { |
| Value res = builder.create<arith::OrIOp>(loc, value, rhs.value); |
| return ArithOpBuilder(builder, loc, res); |
| } |
| |
| ArithOpBuilder operator^(ArithOpBuilder &rhs) { |
| Value res = builder.create<arith::XOrIOp>(loc, value, rhs.value); |
| return ArithOpBuilder(builder, loc, res); |
| } |
| |
| ArithOpBuilder operator<<(ArithOpBuilder &rhs) { |
| Value shl = builder.create<arith::ShLIOp>(loc, value, rhs.value); |
| return ArithOpBuilder(builder, loc, shl); |
| } |
| |
| ArithOpBuilder operator>>(ArithOpBuilder &rhs) { |
| Value shr = builder.create<arith::ShRUIOp>(loc, value, rhs.value); |
| return ArithOpBuilder(builder, loc, shr); |
| } |
| |
| private: |
| OpBuilder builder; |
| Location loc; |
| Value value; |
| }; |
| |
| std::pair<ArithOpBuilder, ArithOpBuilder> splitI64(ArithOpBuilder i64) { |
| auto low = i64.truncI(32); |
| auto c32 = i64.constantI(/*value=*/32, /*bits=*/64); |
| auto high = (i64 >> c32).truncI(32); |
| return {low, high}; |
| } |
| |
| ArithOpBuilder fuseI32s(ArithOpBuilder low, ArithOpBuilder high) { |
| auto c32 = high.constantI(/*value=*/32, /*bits=*/64); |
| high = high.extendUI(64) << c32; |
| low = low.extendUI(64); |
| return low | high; |
| } |
| |
| // Implements the ThreeFry counter-based PRNG algorithm. |
| // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. |
| // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf |
| std::pair<ArithOpBuilder, ArithOpBuilder> |
| runThreeFry2xi32(ArithOpBuilder key0, ArithOpBuilder key1, |
| ArithOpBuilder initialState) { |
| ArithOpBuilder index = initialState.linalgIndex(0); |
| index = index.indexCast(64); |
| index = index + initialState; |
| |
| // Split into the 2xi32 used for threefry. |
| std::pair<ArithOpBuilder, ArithOpBuilder> input = splitI64(index); |
| ArithOpBuilder input0 = input.first; |
| ArithOpBuilder input1 = input.second; |
| |
| // Magic number and rotation distances specified by the Threefry2x32 |
| // algorithm. |
| llvm::SmallVector<int32_t, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24}; |
| ArithOpBuilder magic = key0.constantI(/*value=*/0x1bd11bda, /*bits=*/32); |
| |
| ArithOpBuilder key2 = magic ^ key0 ^ key1; |
| std::array<ArithOpBuilder, 3> ks{key0, key1, key2}; |
| std::array<ArithOpBuilder, 2> x{input0 + key0, input1 + key1}; |
| |
| // Performs a single round of the Threefry2x32 algorithm, with a rotation |
| // amount 'rotation'. |
| for (int i = 0; i < 5; ++i) { |
| int32_t rot = (4 * i) % rotations.size(); |
| int32_t k1 = (i + 1) % ks.size(); |
| int32_t k2 = (i + 2) % ks.size(); |
| |
| for (int j = 0; j < 4; ++j) { |
| x[0] = x[0] + x[1]; |
| x[1] = x[1].rotateLeft(rotations[rot + j]); |
| x[1] = x[0] ^ x[1]; |
| } |
| |
| ArithOpBuilder c = x[0].constantI(/*value=*/i + 1, /*bits=*/32); |
| x[0] = x[0] + ks[k1]; |
| x[1] = x[1] + ks[k2]; |
| x[1] = x[1] + c; |
| } |
| |
| return std::pair<ArithOpBuilder, ArithOpBuilder>(x[0], x[1]); |
| } |
| |
| // Extract and potentially reconstruct the i32 key-pair as necessary. |
| std::pair<Value, Value> extractKey32(OpBuilder &builder, Location loc, |
| Value store) { |
| auto storeTy = cast<ShapedType>(store.getType()); |
| if (storeTy.getRank() != 1) |
| return {nullptr, nullptr}; |
| |
| Type storeETy = storeTy.getElementType(); |
| IntegerType i32Ty = builder.getIntegerType(32); |
| IntegerType i64Ty = builder.getIntegerType(64); |
| |
| if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { |
| Value idx0 = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 1); |
| Value key0 = builder.create<tensor::ExtractOp>(loc, store, idx0); |
| Value key1 = builder.create<tensor::ExtractOp>(loc, store, idx1); |
| key0 = builder.create<arith::BitcastOp>(loc, i32Ty, key0); |
| key1 = builder.create<arith::BitcastOp>(loc, i32Ty, key1); |
| return {key0, key1}; |
| } |
| |
| if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { |
| Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value state = builder.create<tensor::ExtractOp>(loc, store, idx1); |
| Value cast = builder.create<arith::BitcastOp>(loc, i64Ty, state); |
| auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); |
| return std::pair<Value, Value>(pair.first, pair.second); |
| } |
| |
| // TODO(#14859): Properly handle 128-bit storage keys. |
| if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { |
| Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 0); |
| Value state = builder.create<tensor::ExtractOp>(loc, store, idx1); |
| Value cast = builder.create<arith::BitcastOp>(loc, i64Ty, state); |
| auto pair = splitI64(ArithOpBuilder(builder, loc, cast)); |
| return std::pair<Value, Value>(pair.first, pair.second); |
| } |
| |
| return {nullptr, nullptr}; |
| } |
| |
| // Extract and potentially reconstruct the i64 state as necessary. |
| Value extractState64(OpBuilder &builder, Location loc, Value store) { |
| auto storeTy = cast<ShapedType>(store.getType()); |
| if (storeTy.getRank() != 1) |
| return nullptr; |
| |
| Type storeETy = storeTy.getElementType(); |
| IntegerType i64Ty = builder.getIntegerType(64); |
| |
| if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { |
| Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 1); |
| Value state = builder.create<tensor::ExtractOp>(loc, store, idx1); |
| Value cast = builder.create<arith::BitcastOp>(loc, i64Ty, state); |
| return cast; |
| } |
| |
| // TODO(#14859): Properly handle 128-bit storage keys. |
| if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { |
| Value idx1 = builder.create<arith::ConstantIndexOp>(loc, 1); |
| Value state = builder.create<tensor::ExtractOp>(loc, store, idx1); |
| Value cast = builder.create<arith::BitcastOp>(loc, i64Ty, state); |
| return cast; |
| } |
| |
| if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { |
| Value idx2 = builder.create<arith::ConstantIndexOp>(loc, 2); |
| Value idx3 = builder.create<arith::ConstantIndexOp>(loc, 3); |
| |
| Value low = builder.create<tensor::ExtractOp>(loc, store, idx2); |
| Value high = builder.create<tensor::ExtractOp>(loc, store, idx3); |
| |
| ArithOpBuilder i64 = fuseI32s(ArithOpBuilder(builder, loc, high), |
| ArithOpBuilder(builder, loc, low)); |
| return builder.create<arith::BitcastOp>(loc, i64Ty, i64.val()); |
| } |
| |
| return nullptr; |
| } |
| |
| Value setState64(OpBuilder &b, Location loc, Value store, Value state) { |
| auto storeTy = cast<ShapedType>(store.getType()); |
| if (storeTy.getRank() != 1) |
| return nullptr; |
| |
| Type storeETy = storeTy.getElementType(); |
| |
| if (storeTy.getDimSize(0) == 2 && storeETy.isInteger(64)) { |
| state = b.create<arith::BitcastOp>(loc, storeETy, state); |
| Value idx1 = b.create<arith::ConstantIndexOp>(loc, 1); |
| return b.create<tensor::InsertOp>(loc, storeTy, state, store, |
| ValueRange{idx1}); |
| } |
| |
| // TODO(#14859): Properly handle 128-bit storage keys. |
| if (storeTy.getDimSize(0) == 3 && storeETy.isInteger(64)) { |
| state = b.create<arith::BitcastOp>(loc, storeETy, state); |
| Value idx1 = b.create<arith::ConstantIndexOp>(loc, 1); |
| return b.create<tensor::InsertOp>(loc, storeTy, state, store, |
| ValueRange{idx1}); |
| } |
| |
| if (storeTy.getDimSize(0) == 4 && storeETy.isInteger(32)) { |
| Value idx2 = b.create<arith::ConstantIndexOp>(loc, 2); |
| Value idx3 = b.create<arith::ConstantIndexOp>(loc, 3); |
| std::pair<ArithOpBuilder, ArithOpBuilder> states = |
| splitI64(ArithOpBuilder(b, loc, state)); |
| Value state0 = |
| b.create<arith::BitcastOp>(loc, storeETy, states.first.val()); |
| Value state1 = |
| b.create<arith::BitcastOp>(loc, storeETy, states.second.val()); |
| Value insert0 = b.create<tensor::InsertOp>(loc, storeTy, state0, store, |
| ValueRange{idx2}); |
| Value insert1 = b.create<tensor::InsertOp>(loc, storeTy, state1, insert0, |
| ValueRange{idx3}); |
| return insert1; |
| } |
| |
| return nullptr; |
| } |
| |
| Value reshapeToTarget(OpBuilder &builder, Location loc, ShapedType destTy, |
| Value src) { |
| auto srcTy = cast<ShapedType>(src.getType()); |
| // Expand out to the target shape. |
| |
| auto reassociationIndices = |
| getReassociationIndicesForCollapse(destTy.getShape(), srcTy.getShape()); |
| if (reassociationIndices.has_value()) { |
| src = builder.create<tensor::ExpandShapeOp>(loc, destTy, src, |
| reassociationIndices.value()); |
| } |
| |
| // It is also possible our target is Rank-0, then we would |
| // need to collapse. |
| reassociationIndices = |
| getReassociationIndicesForCollapse(srcTy.getShape(), destTy.getShape()); |
| if (reassociationIndices.has_value()) { |
| src = builder.create<tensor::CollapseShapeOp>(loc, destTy, src, |
| reassociationIndices.value()); |
| } |
| |
| return src; |
| } |
| |
| // Compute the shape for computing three fry. |
| std::pair<ShapedType, int64_t> threeFry32Shape(ShapedType resultTy) { |
| if (resultTy.getRank() == 0) { |
| return {resultTy, 0}; |
| } |
| |
| ArrayRef<int64_t> shape = resultTy.getShape(); |
| uint64_t halfDim = |
| std::max_element(shape.begin(), shape.end()) - shape.begin(); |
| |
| for (int i = 0, s = shape.size(); i < s; i++) { |
| if (shape[i] & 0x1) |
| continue; |
| halfDim = i; |
| break; |
| } |
| |
| llvm::SmallVector<int64_t> newShape(shape); |
| newShape[halfDim] = (newShape[halfDim] + 1) / 2; |
| if (halfDim == (newShape.size() - 1)) { |
| newShape.push_back(1); |
| } |
| |
| return {RankedTensorType::get(newShape, resultTy.getElementType()), halfDim}; |
| } |
| |
| /// This implementation generates a 32-bit tensor of ThreeFry random numbers. |
| /// It matches the XLA implementation bit-exact and includes an inefficient |
| /// method of concatenating / slicing the pairs of generated numbers. |
| /// |
| /// We should consider dropping the complex slicing and simply generating |
| /// 2x the values, then downcast to a 32-bit. It substantially simplifies |
| /// the computation and avoids the concat / slice behavior. |
| LogicalResult generateLinalgThreeFry32(OpBuilder &builder, Location loc, |
| ShapedType resultTy, Value &store, |
| Value &result) { |
| Type resultETy = resultTy.getElementType(); |
| |
| // Extract the stateful values as an i64 and increment the state ahead. |
| Value initialState = extractState64(builder, loc, store); |
| if (!initialState) |
| return failure(); |
| |
| std::pair<Value, Value> keys = extractKey32(builder, loc, store); |
| if (!keys.first || !keys.second) |
| return failure(); |
| |
| ArithOpBuilder key0(builder, loc, keys.first); |
| ArithOpBuilder key1(builder, loc, keys.second); |
| |
| // Compute the intermediate type we use to compute three fry values, including |
| // the dimension that was halved. |
| auto pair = threeFry32Shape(resultTy); |
| ShapedType intermediateType = pair.first; |
| int64_t halfDim = pair.second; |
| int64_t count = intermediateType.getNumElements(); |
| |
| // Compute the number of random i64s generated and increment state. |
| Value countVal = |
| builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count)); |
| Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal); |
| |
| // Generate a 1D tensor with for the random values. |
| Value destLeft = builder.create<tensor::EmptyOp>( |
| loc, ArrayRef<int64_t>({count}), resultETy); |
| Value destRight = builder.create<tensor::EmptyOp>( |
| loc, ArrayRef<int64_t>({count}), resultETy); |
| |
| ShapedType destTy = llvm::cast<ShapedType>(destLeft.getType()); |
| |
| SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(1)); |
| SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel); |
| |
| linalg::GenericOp generic = builder.create<linalg::GenericOp>( |
| loc, TypeRange{destTy, destTy}, |
| /*inputs=*/ValueRange(), |
| /*outputs=*/ValueRange{destLeft, destRight}, |
| /*indexingMaps=*/indexingMaps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange) { |
| // Grab three fry results and write to each array. |
| auto split = runThreeFry2xi32( |
| key0, key1, ArithOpBuilder(b, nestedLoc, initialState)); |
| auto first = split.first.truncI(resultETy.getIntOrFloatBitWidth()); |
| auto second = split.second.truncI(resultETy.getIntOrFloatBitWidth()); |
| b.create<linalg::YieldOp>(loc, ValueRange{first.val(), second.val()}); |
| }); |
| |
| if (resultTy.getNumElements() == 1) { |
| result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); |
| store = setState64(builder, loc, store, newState); |
| return success(); |
| } |
| |
| // Reshape to the target size and concatenate on the dimension following the |
| // half dimension. |
| Value random0 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); |
| Value random1 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); |
| Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>( |
| loc, ValueRange{random0, random1}, |
| builder.getI64IntegerAttr(halfDim + 1)); |
| |
| // Collapse the concat dimension back into the parent. |
| llvm::SmallVector<int64_t> collapseShape(resultTy.getShape()); |
| collapseShape[halfDim] = |
| collapseShape[halfDim] + (collapseShape[halfDim] & 1); |
| Value reshape = builder.create<mlir::stablehlo::ReshapeOp>( |
| loc, resultTy.clone(collapseShape), concatenate); |
| |
| // Slice to only the required results. |
| llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0); |
| llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1); |
| Value slice = builder.create<mlir::stablehlo::SliceOp>( |
| loc, resultTy, reshape, builder.getDenseI64ArrayAttr(offset), |
| builder.getDenseI64ArrayAttr(resultTy.getShape()), |
| builder.getDenseI64ArrayAttr(stride)); |
| |
| // Set the new tensor values. |
| store = setState64(builder, loc, store, newState); |
| result = slice; |
| |
| return success(); |
| } |
| |
| LogicalResult generateLinalgThreeFry64(OpBuilder &builder, Location loc, |
| ShapedType resultTy, Value &store, |
| Value &result) { |
| Type resultETy = resultTy.getElementType(); |
| int64_t count = resultTy.getNumElements(); |
| |
| // Extract the stateful values as an i64 and increment the state ahead. |
| Value initialState = extractState64(builder, loc, store); |
| if (!initialState) |
| return failure(); |
| |
| std::pair<Value, Value> keys = extractKey32(builder, loc, store); |
| if (!keys.first || !keys.second) |
| return failure(); |
| |
| ArithOpBuilder key0(builder, loc, keys.first); |
| ArithOpBuilder key1(builder, loc, keys.second); |
| |
| // Compute the number of random i64s generated and increment state. |
| Value countVal = |
| builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count)); |
| Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal); |
| |
| // Generate a 1D tensor with for the random values. |
| Value dest = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| ShapedType destTy = llvm::cast<ShapedType>(dest.getType()); |
| |
| SmallVector<AffineMap> indexingMaps(1, builder.getMultiDimIdentityMap(1)); |
| SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel); |
| |
| auto random = builder.create<linalg::GenericOp>( |
| loc, destTy, /*inputs=*/ValueRange(), |
| /*outputs=*/ValueRange{dest}, |
| /*indexingMaps=*/indexingMaps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange) { |
| // Generate three fry results, fuse, and return an |
| // i64. |
| auto split = runThreeFry2xi32( |
| key0, key1, ArithOpBuilder(b, nestedLoc, initialState)); |
| Value result = fuseI32s(split.first, split.second).val(); |
| b.create<linalg::YieldOp>(nestedLoc, result); |
| }); |
| |
| store = setState64(builder, loc, store, newState); |
| result = reshapeToTarget(builder, loc, resultTy, random.getResult(0)); |
| return success(); |
| } |
| |
// Philox round key: two 32-bit words {key[0], key[1]}.
using PhiloxKey = std::pair<ArithOpBuilder, ArithOpBuilder>;
// Philox counter/state: four 32-bit words.
using PhiloxState = std::array<ArithOpBuilder, 4>;
| |
| // Computes high and low words from multiplying 32 bit integers. |
| // Per the paper, mulhi and mullo of the same arguments can be computed |
| // Simultaneously in a single instruction on x86 architectures. |
| std::pair<ArithOpBuilder, ArithOpBuilder> multiplyHilo(ArithOpBuilder counter, |
| ArithOpBuilder key) { |
| counter = counter.extendUI(64); |
| key = key.extendUI(64); |
| ArithOpBuilder product = counter * key; |
| ArithOpBuilder ci64 = counter.constantI(/*value=*/32, /*bits=*/64); |
| ArithOpBuilder hi = product >> ci64; |
| hi = hi.truncI(32); |
| product = product.truncI(32); |
| return std::pair<ArithOpBuilder, ArithOpBuilder>{hi, product}; |
| } |
| |
| PhiloxState philoxRound(PhiloxState x, PhiloxKey key) { |
| // These are philox specific constants. |
| ArithOpBuilder m0 = x[0].constantI(0xD2511F53, 32); |
| ArithOpBuilder m1 = x[2].constantI(0xCD9E8D57, 32); |
| std::pair<ArithOpBuilder, ArithOpBuilder> p0 = multiplyHilo(x[0], m0); |
| std::pair<ArithOpBuilder, ArithOpBuilder> p1 = multiplyHilo(x[2], m1); |
| |
| PhiloxState state = {p1.first ^ x[1] ^ key.first, p1.second, |
| p0.first ^ x[3] ^ key.second, p0.second}; |
| return state; |
| } |
| |
| PhiloxKey raiseKey(PhiloxKey key) { |
| // These are philox specific constants. |
| ArithOpBuilder w0 = key.first.constantI(0x9E3779B9, 32); |
| ArithOpBuilder w1 = key.first.constantI(0xBB67AE85, 32); |
| return PhiloxKey{key.first + w0, key.second + w1}; |
| } |
| |
| // Implements the Philox 4x32 counter-based PRNG algorithm. |
| // The Philox PRNG has been proposed in: |
| // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. |
| // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf |
| std::array<ArithOpBuilder, 4> runPhilox4x32(PhiloxKey key, |
| ArithOpBuilder state) { |
| ArithOpBuilder index = state.linalgIndex(0); |
| index = index.indexCast(64); |
| index = index + state; |
| |
| // Split into the 2xi32 used for threefry. |
| std::pair<ArithOpBuilder, ArithOpBuilder> input = splitI64(index); |
| ArithOpBuilder input0 = input.first; |
| ArithOpBuilder input1 = input.second; |
| |
| // We initialize the state as such to match the XLA implementation. |
| PhiloxState state4 = {input0, input1, key.first, key.second}; |
| |
| // We perform 10 rounds to match the XLA implementation. |
| constexpr int kNumRounds = 10; |
| for (int round = 0; round < kNumRounds; ++round, key = raiseKey(key)) { |
| state4 = philoxRound(state4, key); |
| } |
| return state4; |
| } |
| |
| // Generates an array of primitive type U32 with the given shape containing |
| // random bits generated by the Philox algorithm. Returns the array and the new |
| // state of the random number generator. |
| LogicalResult generateLinalgPhilox32(OpBuilder &builder, Location loc, |
| ShapedType resultTy, Value &store, |
| Value &result) { |
| Type resultETy = resultTy.getElementType(); |
| |
| Value initialState = extractState64(builder, loc, store); |
| if (!initialState) |
| return failure(); |
| |
| std::pair<Value, Value> keys = extractKey32(builder, loc, store); |
| if (!keys.first || !keys.second) |
| return failure(); |
| |
| int64_t numElements = resultTy.getNumElements(); |
| int64_t count = (numElements + 3) / 4; |
| ShapedType intermediateType = |
| RankedTensorType::get({count, 1}, resultTy.getElementType()); |
| int64_t concatDim = 1; |
| |
| // Compute the number of random i64s generated and increment state. |
| Value countVal = |
| builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count)); |
| Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal); |
| |
| // set up four outputs |
| Value dest0 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| Value dest1 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| Value dest2 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| Value dest3 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| |
| ShapedType destTy = cast<ShapedType>(dest0.getType()); |
| |
| SmallVector<AffineMap> indexingMaps(4, builder.getMultiDimIdentityMap(1)); |
| SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel); |
| |
| linalg::GenericOp generic = builder.create<linalg::GenericOp>( |
| loc, TypeRange{destTy, destTy, destTy, destTy}, |
| /*inputs=*/ValueRange(), |
| /*outputs=*/ValueRange{dest0, dest1, dest2, dest3}, |
| /*indexingMaps=*/indexingMaps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange) { |
| auto output = |
| runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first), |
| ArithOpBuilder(b, nestedLoc, keys.second)}, |
| ArithOpBuilder(b, nestedLoc, initialState)); |
| auto out0 = output[0].truncI(resultETy.getIntOrFloatBitWidth()); |
| auto out1 = output[1].truncI(resultETy.getIntOrFloatBitWidth()); |
| auto out2 = output[2].truncI(resultETy.getIntOrFloatBitWidth()); |
| auto out3 = output[3].truncI(resultETy.getIntOrFloatBitWidth()); |
| b.create<linalg::YieldOp>( |
| loc, ValueRange{out0.val(), out1.val(), out2.val(), out3.val()}); |
| }); |
| |
| if (resultTy.getNumElements() == 1) { |
| result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); |
| store = setState64(builder, loc, store, newState); |
| return success(); |
| } |
| |
| Value r0 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); |
| Value r1 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); |
| Value r2 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(2)); |
| Value r3 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(3)); |
| |
| Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>( |
| loc, ValueRange{r0, r1, r2, r3}, builder.getI64IntegerAttr(concatDim)); |
| |
| // Collapse the concat dimension back into the parent. |
| llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape()); |
| collapseShape[0] = collapseShape[0] * 4; |
| Value reshapeIntermediate = builder.create<mlir::stablehlo::ReshapeOp>( |
| loc, resultTy.clone(collapseShape), concatenate); |
| |
| // Slice to only the required results. |
| collapseShape[0] = resultTy.getNumElements(); |
| |
| auto sliceResultTy = intermediateType.clone(collapseShape); |
| llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0); |
| llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1); |
| Value slice = builder.create<mlir::stablehlo::SliceOp>( |
| loc, sliceResultTy, reshapeIntermediate, |
| builder.getDenseI64ArrayAttr(offset), |
| builder.getDenseI64ArrayAttr(collapseShape), |
| builder.getDenseI64ArrayAttr(stride)); |
| Value reshapeResult = |
| builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice); |
| |
| // Set the new tensor values. |
| store = setState64(builder, loc, store, newState); |
| result = reshapeResult; |
| |
| return success(); |
| } |
| |
| LogicalResult generateLinalgPhilox64(OpBuilder &builder, Location loc, |
| ShapedType resultTy, Value &store, |
| Value &result) { |
| Type resultETy = resultTy.getElementType(); |
| |
| Value initialState = extractState64(builder, loc, store); |
| if (!initialState) |
| return failure(); |
| |
| std::pair<Value, Value> keys = extractKey32(builder, loc, store); |
| if (!keys.first || !keys.second) |
| return failure(); |
| |
| int64_t numElements = resultTy.getNumElements(); |
| int64_t count = (numElements + 1) / 2; |
| ShapedType intermediateType = |
| RankedTensorType::get({count, 1}, resultTy.getElementType()); |
| int64_t concatDim = 1; |
| |
| // Compute the number of random i64s generated and increment state. |
| Value countVal = |
| builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count)); |
| Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal); |
| |
| // set up four outputs |
| Value dest0 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| Value dest1 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}), |
| resultETy); |
| ShapedType destTy = cast<ShapedType>(dest0.getType()); |
| |
| SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(1)); |
| SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel); |
| |
| linalg::GenericOp generic = builder.create<linalg::GenericOp>( |
| loc, TypeRange{destTy, destTy}, |
| /*inputs=*/ValueRange(), |
| /*outputs=*/ValueRange{dest0, dest1}, |
| /*indexingMaps=*/indexingMaps, iterators, |
| [&](OpBuilder &b, Location nestedLoc, ValueRange) { |
| auto output = |
| runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first), |
| ArithOpBuilder(b, nestedLoc, keys.second)}, |
| ArithOpBuilder(b, nestedLoc, initialState)); |
| auto out0 = output[0]; |
| auto out1 = output[1]; |
| auto out2 = output[2]; |
| auto out3 = output[3]; |
| Value result1 = fuseI32s(out0, out1).val(); |
| Value result2 = fuseI32s(out2, out3).val(); |
| b.create<linalg::YieldOp>(loc, ValueRange{result1, result2}); |
| }); |
| |
| if (resultTy.getNumElements() == 1) { |
| result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0)); |
| store = setState64(builder, loc, store, newState); |
| return success(); |
| } |
| |
| Value r0 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(0)); |
| Value r1 = |
| reshapeToTarget(builder, loc, intermediateType, generic.getResult(1)); |
| Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>( |
| loc, ValueRange{r0, r1}, builder.getI64IntegerAttr(concatDim)); |
| |
| // Collapse the concat dimension back into the parent. |
| llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape()); |
| collapseShape[0] = collapseShape[0] * 2; |
| Value reshapeIntermediate = builder.create<mlir::stablehlo::ReshapeOp>( |
| loc, resultTy.clone(collapseShape), concatenate); |
| |
| // Slice to only the required results. |
| collapseShape[0] = resultTy.getNumElements(); |
| |
| auto sliceResultTy = intermediateType.clone(collapseShape); |
| llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0); |
| llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1); |
| Value slice = builder.create<mlir::stablehlo::SliceOp>( |
| loc, sliceResultTy, reshapeIntermediate, |
| builder.getDenseI64ArrayAttr(offset), |
| builder.getDenseI64ArrayAttr(collapseShape), |
| builder.getDenseI64ArrayAttr(stride)); |
| Value reshapeResult = |
| builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice); |
| |
| // Set the new tensor values. |
| store = setState64(builder, loc, store, newState); |
| result = reshapeResult; |
| |
| return success(); |
| } |
| |
| LogicalResult generateLinalgThreeFry(OpBuilder &builder, Location loc, |
| ShapedType resultTy, Value &state, |
| Value &result) { |
| Type eTy = resultTy.getElementType(); |
| unsigned bitwidth = eTy.getIntOrFloatBitWidth(); |
| |
| if (bitwidth == 64) { |
| return generateLinalgThreeFry64(builder, loc, resultTy, state, result); |
| } |
| if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) { |
| return generateLinalgThreeFry32(builder, loc, resultTy, state, result); |
| } |
| |
| return failure(); |
| } |
| |
| LogicalResult generateLinalgPhilox(OpBuilder &builder, Location loc, |
| ShapedType resultTy, Value &state, |
| Value &result) { |
| Type eTy = resultTy.getElementType(); |
| unsigned bitwidth = eTy.getIntOrFloatBitWidth(); |
| if (bitwidth == 64) { |
| return generateLinalgPhilox64(builder, loc, resultTy, state, result); |
| } |
| |
| // The 32 bit implementation trancates to result eTy. |
| if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) { |
| return generateLinalgPhilox32(builder, loc, resultTy, state, result); |
| } |
| |
| return failure(); |
| } |
| |
| struct RngBitGeneratorConverter final |
| : OpConversionPattern<mlir::stablehlo::RngBitGeneratorOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::stablehlo::RngBitGeneratorOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| Location loc = op.getLoc(); |
| Value state = adaptor.getInitialState(); |
| auto resultTy = dyn_cast_or_null<ShapedType>( |
| getTypeConverter()->convertType(op.getResult(1).getType())); |
| if (!resultTy) { |
| return rewriter.notifyMatchFailure(op, "type conversion failed"); |
| } |
| |
| if (op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::THREE_FRY) { |
| Value random; |
| if (failed( |
| generateLinalgThreeFry(rewriter, loc, resultTy, state, random))) { |
| return failure(); |
| } |
| rewriter.replaceOp(op, {state, random}); |
| return success(); |
| } |
| |
| if (op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::PHILOX || |
| op.getRngAlgorithm() == mlir::stablehlo::RngAlgorithm::DEFAULT) { |
| Value random; |
| if (failed( |
| generateLinalgPhilox(rewriter, loc, resultTy, state, random))) { |
| return failure(); |
| } |
| rewriter.replaceOp(op, {state, random}); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| }; |
| |
| struct RngUniformConversion final |
| : OpConversionPattern<mlir::stablehlo::RngOp> { |
| using OpConversionPattern::OpConversionPattern; |
| |
| LogicalResult |
| matchAndRewrite(mlir::stablehlo::RngOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // We only handle uniform distributions. |
| if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::UNIFORM) { |
| return failure(); |
| } |
| // TODO(raikonenfnu): Handle other element types as well. |
| auto minTy = dyn_cast<ShapedType>(adaptor.getA().getType()); |
| auto maxTy = dyn_cast<ShapedType>(adaptor.getB().getType()); |
| if (!isa<FloatType>(minTy.getElementType()) || |
| !isa<FloatType>(maxTy.getElementType())) { |
| return rewriter.notifyMatchFailure( |
| op, "expected min/max for rng op to be FloatType"); |
| } |
| auto targetTy = dyn_cast_or_null<ShapedType>( |
| getTypeConverter()->convertType(op.getResult().getType())); |
| if (!targetTy) { |
| return rewriter.notifyMatchFailure( |
| op, "expected target shape of rng op to be ShapedType"); |
| } |
| auto loc = op.getLoc(); |
| Value emptyTensor = |
| getEmptyTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands()); |
| // Creates index map using target matrix's rank. |
| auto targetRank = targetTy.getRank(); |
| SmallVector<AffineMap, 3> indexingMaps( |
| 2, AffineMap::get(targetRank, /*symbolCount=*/0, |
| SmallVector<AffineExpr>({}), rewriter.getContext())); |
| indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank)); |
| const int kInitialSeed = 0; |
| |
| // Generic region with LCG Algorithm that make use of element index from: |
| // https://reviews.llvm.org/D101364 |
| auto linalgOp = rewriter.create<linalg::GenericOp>( |
| loc, /*resultTensors=*/targetTy, |
| /*inputs=*/ |
| ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]}, |
| /*outputs=*/emptyTensor, indexingMaps, |
| getParallelAndReductionIterators(/*nLoops=*/targetRank, |
| /*nReduction=*/0), |
| [&](OpBuilder &b, Location loc, ValueRange args) { |
| llvm::SmallVector<Value> updateVec = {b.create<arith::ConstantOp>( |
| loc, b.getI32IntegerAttr(kInitialSeed))}; |
| Value multiplier = |
| b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1103515245)); |
| Value incrementStep = |
| b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(12345)); |
| // For output matrix with rank N: |
| // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr |
| // ... |
| // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr |
| for (int i = 0; i < targetRank; i++) { |
| Value update = updateVec.back(); |
| Value ind = b.create<linalg::IndexOp>(loc, i); |
| Value castInd = |
| b.create<arith::IndexCastOp>(loc, b.getI32Type(), ind); |
| Value addRes = b.create<arith::AddIOp>(loc, castInd, update); |
| Value multRes = b.create<arith::MulIOp>(loc, addRes, multiplier); |
| Value incRes = b.create<arith::AddIOp>(loc, multRes, incrementStep); |
| updateVec.push_back(incRes); |
| } |
| // Scaling = (max - min) * const(F64, 2.3283064E-10) |
| // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)). |
| Value epsilon = b.create<arith::ConstantOp>( |
| loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10)); |
| Value range = b.create<arith::SubFOp>(loc, args[1], args[0]); |
| Value scale = b.create<arith::MulFOp>(loc, range, epsilon); |
| // Res = cast(T, cast(F64, tempN) * scaling + min) |
| Value updateCast = b.create<arith::UIToFPOp>( |
| loc, targetTy.getElementType(), updateVec.back()); |
| Value scaleUpdate = b.create<arith::MulFOp>(loc, updateCast, scale); |
| Value res = b.create<arith::AddFOp>(loc, scaleUpdate, args[0]); |
| b.create<linalg::YieldOp>(loc, res); |
| }, |
| linalg::getPrunedAttributeList(op)); |
| rewriter.replaceOp(op, linalgOp.getResults()); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| namespace detail { |
| void populateStableHloRandomToLinalgConversionPatterns( |
| MLIRContext *context, TypeConverter &typeConverter, |
| RewritePatternSet *patterns) { |
| patterns->add<RngBitGeneratorConverter, RngUniformConversion>(typeConverter, |
| context); |
| } |
| } // namespace detail |
| } // namespace mlir::iree_compiler::stablehlo |