// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

// Implements logic for lowering StableHLO random number generation to Linalg
// dialect.

#include "compiler/plugins/input/StableHLO/Conversion/LegalizeToLinalgUtils.h"
#include "compiler/plugins/input/StableHLO/Conversion/Rewriters.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace mlir::iree_compiler::stablehlo {
namespace {
/// Small convenience wrapper that bundles an SSA `Value` with the `OpBuilder`
/// and `Location` used to create further ops from it, so integer arithmetic
/// can be spelled with ordinary C++ operators.
///
/// Each method/operator creates one or more `arith`/`linalg` ops through the
/// wrapped builder and returns a *new* wrapper around the produced value; the
/// receiver is never modified. For binary operators the op is created with
/// the left-hand operand's builder and location. The builder is stored by
/// value (copied at construction time).
class ArithOpBuilder {
public:
  ArithOpBuilder(OpBuilder b, Location l, Value v)
      : builder(b), loc(l), value(v) {}

  /// Access to the wrapped SSA value.
  explicit operator Value() const { return value; }
  Value val() const { return value; }

  /// Creates an integer constant `value` of width `bits`. The wrapped value of
  /// the receiver is not used; only its builder and location are.
  ArithOpBuilder constantI(int64_t value, int64_t bits) {
    Value val = builder.create<arith::ConstantOp>(
        loc, builder.getIntegerAttr(builder.getIntegerType(bits), value));
    return ArithOpBuilder(builder, loc, val);
  }

  /// Zero-extends the wrapped value to an integer of width `bits`.
  ArithOpBuilder extendUI(int32_t bits) {
    Value ext = builder.create<arith::ExtUIOp>(
        loc, builder.getIntegerType(bits), value);
    return ArithOpBuilder(builder, loc, ext);
  }

  /// Truncates the wrapped value to an integer of width `bits`. When the value
  /// already has that width no op is created and a copy of `*this` is
  /// returned.
  ArithOpBuilder truncI(int64_t bits) {
    if (value.getType().getIntOrFloatBitWidth() == bits)
      return *this;
    Value trunc = builder.create<arith::TruncIOp>(
        loc, builder.getIntegerType(bits), value);
    return ArithOpBuilder(builder, loc, trunc);
  }

  /// Creates a `linalg.index` op for iteration dimension `index`. The wrapped
  /// value of the receiver is not used.
  ArithOpBuilder linalgIndex(int32_t index) {
    Value val = builder.create<linalg::IndexOp>(loc, index);
    return ArithOpBuilder(builder, loc, val);
  }

  /// Casts between integer and index types: an integer-typed value is cast to
  /// `index` (and `bitwidth` is ignored); any other value is cast to an
  /// integer of width `bitwidth`.
  ArithOpBuilder indexCast(int32_t bitwidth) {
    if (isa<IntegerType>(value.getType())) {
      Value cast = builder.create<arith::IndexCastOp>(
          loc, builder.getIndexType(), value);
      return ArithOpBuilder(builder, loc, cast);
    }

    Value cast = builder.create<arith::IndexCastOp>(
        loc, builder.getIntegerType(bitwidth), value);
    return ArithOpBuilder(builder, loc, cast);
  }

  /// Rotates the wrapped value left by `rotation` bits, built as
  /// `(x << rotation) | (x >> (bits - rotation))` with a logical right shift.
  /// NOTE(review): a rotation of 0 would shift right by the full bit width;
  /// callers in this file only pass values strictly between 0 and the width.
  ArithOpBuilder rotateLeft(int32_t rotation) {
    int32_t bits = value.getType().getIntOrFloatBitWidth();
    ArithOpBuilder cLeft = constantI(rotation, bits);
    ArithOpBuilder cRight = constantI(bits - rotation, bits);
    ArithOpBuilder rLeft = (*this << cLeft);
    ArithOpBuilder rRight = (*this >> cRight);
    return rLeft | rRight;
  }

  // Binary operators. The right-hand side is taken by const reference so that
  // temporaries (e.g. the result of another operator) can be used directly.

  /// Integer addition (`arith.addi`).
  ArithOpBuilder operator+(const ArithOpBuilder &rhs) {
    Value res = builder.create<arith::AddIOp>(loc, value, rhs.value);
    return ArithOpBuilder(builder, loc, res);
  }

  /// Integer multiplication (`arith.muli`).
  ArithOpBuilder operator*(const ArithOpBuilder &rhs) {
    Value res = builder.create<arith::MulIOp>(loc, value, rhs.value);
    return ArithOpBuilder(builder, loc, res);
  }

  /// Bitwise or (`arith.ori`).
  ArithOpBuilder operator|(const ArithOpBuilder &rhs) {
    Value res = builder.create<arith::OrIOp>(loc, value, rhs.value);
    return ArithOpBuilder(builder, loc, res);
  }

  /// Bitwise xor (`arith.xori`).
  ArithOpBuilder operator^(const ArithOpBuilder &rhs) {
    Value res = builder.create<arith::XOrIOp>(loc, value, rhs.value);
    return ArithOpBuilder(builder, loc, res);
  }

  /// Left shift (`arith.shli`).
  ArithOpBuilder operator<<(const ArithOpBuilder &rhs) {
    Value shl = builder.create<arith::ShLIOp>(loc, value, rhs.value);
    return ArithOpBuilder(builder, loc, shl);
  }

  /// Logical (unsigned) right shift (`arith.shrui`).
  ArithOpBuilder operator>>(const ArithOpBuilder &rhs) {
    Value shr = builder.create<arith::ShRUIOp>(loc, value, rhs.value);
    return ArithOpBuilder(builder, loc, shr);
  }

private:
  OpBuilder builder;
  Location loc;
  Value value;
};

// Splits `i64` into two 32-bit halves, returned as {low bits, high bits}.
// The high half is obtained with a logical right shift by a 64-bit constant
// 32, so `i64` is expected to wrap a 64-bit integer.
std::pair<ArithOpBuilder, ArithOpBuilder> splitI64(ArithOpBuilder i64) {
  ArithOpBuilder lowHalf = i64.truncI(32);
  ArithOpBuilder shiftAmount = i64.constantI(/*value=*/32, /*bits=*/64);
  ArithOpBuilder shifted = i64 >> shiftAmount;
  ArithOpBuilder highHalf = shifted.truncI(32);
  return std::pair<ArithOpBuilder, ArithOpBuilder>(lowHalf, highHalf);
}

// Combines two values into one 64-bit integer: both are zero-extended to 64
// bits, `high` is shifted left by 32 and or-ed with `low`.
ArithOpBuilder fuseI32s(ArithOpBuilder low, ArithOpBuilder high) {
  ArithOpBuilder shiftAmount = high.constantI(/*value=*/32, /*bits=*/64);
  ArithOpBuilder upperBits = high.extendUI(64) << shiftAmount;
  ArithOpBuilder lowerBits = low.extendUI(64);
  return lowerBits | upperBits;
}

// Emits the ThreeFry 2x32 counter-based PRNG.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
//
// The counter is `linalg.index 0` (cast to i64) plus `initialState`, so this
// must be called while building the body of a linalg op. Returns the two
// 32-bit output words.
std::pair<ArithOpBuilder, ArithOpBuilder>
runThreeFry2xi32(ArithOpBuilder key0, ArithOpBuilder key1,
                 ArithOpBuilder initialState) {
  // Per-element counter: iteration index offset by the starting state.
  ArithOpBuilder counter = initialState.linalgIndex(0).indexCast(64);
  counter = counter + initialState;

  // The 64-bit counter supplies the two 32-bit input words.
  std::pair<ArithOpBuilder, ArithOpBuilder> counterWords = splitI64(counter);
  ArithOpBuilder word0 = counterWords.first;
  ArithOpBuilder word1 = counterWords.second;

  // Rotation distances and key-schedule parity constant used by Threefry2x32.
  const std::array<int32_t, 8> rotations = {13, 15, 26, 6, 17, 29, 16, 24};
  ArithOpBuilder parity = key0.constantI(/*value=*/0x1bd11bda, /*bits=*/32);

  // Third key-schedule word is derived from the two key words.
  ArithOpBuilder key2 = parity ^ key0 ^ key1;
  std::array<ArithOpBuilder, 3> ks{key0, key1, key2};
  std::array<ArithOpBuilder, 2> x{word0 + key0, word1 + key1};

  // Five groups of four mix rounds; after each group the next key-schedule
  // words and the 1-based group number are injected.
  constexpr int kNumGroups = 5;
  constexpr int kRoundsPerGroup = 4;
  for (int group = 0; group < kNumGroups; ++group) {
    const int32_t rotBase = (kRoundsPerGroup * group) % rotations.size();
    const int32_t firstKey = (group + 1) % ks.size();
    const int32_t secondKey = (group + 2) % ks.size();

    for (int round = 0; round < kRoundsPerGroup; ++round) {
      x[0] = x[0] + x[1];
      x[1] = x[1].rotateLeft(rotations[rotBase + round]);
      x[1] = x[0] ^ x[1];
    }

    ArithOpBuilder groupNumber =
        x[0].constantI(/*value=*/group + 1, /*bits=*/32);
    x[0] = x[0] + ks[firstKey];
    x[1] = x[1] + ks[secondKey];
    x[1] = x[1] + groupNumber;
  }

  return {x[0], x[1]};
}

// Reads the two 32-bit key words out of the RNG state tensor `store`.
//
// Supported layouts (rank-1 only):
//   * 4 x i32: elements 0 and 1 are the key words.
//   * 2 x i64 / 3 x i64: element 0 is split into {low, high} 32-bit words.
// Returns {nullptr, nullptr} for any other layout.
std::pair<Value, Value> extractKey32(OpBuilder &builder, Location loc,
                                     Value store) {
  auto storeTy = cast<ShapedType>(store.getType());
  if (storeTy.getRank() != 1)
    return {nullptr, nullptr};

  const int64_t numWords = storeTy.getDimSize(0);
  Type elementTy = storeTy.getElementType();

  if (numWords == 4 && elementTy.isInteger(32)) {
    IntegerType i32Ty = builder.getIntegerType(32);
    Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
    Value firstWord = builder.create<tensor::ExtractOp>(loc, store, zero);
    Value secondWord = builder.create<tensor::ExtractOp>(loc, store, one);
    firstWord = builder.create<arith::BitcastOp>(loc, i32Ty, firstWord);
    secondWord = builder.create<arith::BitcastOp>(loc, i32Ty, secondWord);
    return {firstWord, secondWord};
  }

  // TODO(#14859): Properly handle 128-bit storage keys. For now the 3 x i64
  // layout is treated exactly like the 2 x i64 one.
  const bool hasI64Key =
      (numWords == 2 || numWords == 3) && elementTy.isInteger(64);
  if (!hasI64Key)
    return {nullptr, nullptr};

  IntegerType i64Ty = builder.getIntegerType(64);
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value packedKey = builder.create<tensor::ExtractOp>(loc, store, zero);
  Value packedKeyI64 = builder.create<arith::BitcastOp>(loc, i64Ty, packedKey);
  auto halves = splitI64(ArithOpBuilder(builder, loc, packedKeyI64));
  return {halves.first.val(), halves.second.val()};
}

// Reads the 64-bit counter/state out of the RNG state tensor `store`.
//
// Supported layouts (rank-1 only):
//   * 2 x i64 / 3 x i64: element 1 is the state.
//   * 4 x i32: elements 2 and 3 are fused into one i64.
// Returns nullptr for any other layout.
Value extractState64(OpBuilder &builder, Location loc, Value store) {
  auto storeTy = cast<ShapedType>(store.getType());
  if (storeTy.getRank() != 1)
    return nullptr;

  const int64_t numWords = storeTy.getDimSize(0);
  Type elementTy = storeTy.getElementType();
  IntegerType i64Ty = builder.getIntegerType(64);

  // TODO(#14859): Properly handle 128-bit storage keys. For now the 3 x i64
  // layout is treated exactly like the 2 x i64 one.
  if ((numWords == 2 || numWords == 3) && elementTy.isInteger(64)) {
    Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
    Value rawState = builder.create<tensor::ExtractOp>(loc, store, one);
    return builder.create<arith::BitcastOp>(loc, i64Ty, rawState);
  }

  if (numWords == 4 && elementTy.isInteger(32)) {
    Value two = builder.create<arith::ConstantIndexOp>(loc, 2);
    Value three = builder.create<arith::ConstantIndexOp>(loc, 3);

    Value word2 = builder.create<tensor::ExtractOp>(loc, store, two);
    Value word3 = builder.create<tensor::ExtractOp>(loc, store, three);

    // Element 3 supplies the low 32 bits and element 2 the high 32 bits.
    // NOTE(review): setState64 in this file writes the low half to element 2
    // and the high half to element 3 -- confirm the asymmetry is intended.
    ArithOpBuilder fused =
        fuseI32s(/*low=*/ArithOpBuilder(builder, loc, word3),
                 /*high=*/ArithOpBuilder(builder, loc, word2));
    return builder.create<arith::BitcastOp>(loc, i64Ty, fused.val());
  }

  return nullptr;
}

// Writes the 64-bit `state` back into the RNG state tensor `store` and
// returns the updated tensor.
//
// Supported layouts (rank-1 only):
//   * 2 x i64 / 3 x i64: `state` is inserted at element 1.
//   * 4 x i32: `state` is split; the low half goes to element 2 and the high
//     half to element 3.
// Returns nullptr for any other layout.
Value setState64(OpBuilder &b, Location loc, Value store, Value state) {
  auto storeTy = cast<ShapedType>(store.getType());
  if (storeTy.getRank() != 1)
    return nullptr;

  const int64_t numWords = storeTy.getDimSize(0);
  Type elementTy = storeTy.getElementType();

  // TODO(#14859): Properly handle 128-bit storage keys. For now the 3 x i64
  // layout is treated exactly like the 2 x i64 one.
  if ((numWords == 2 || numWords == 3) && elementTy.isInteger(64)) {
    Value stored = b.create<arith::BitcastOp>(loc, elementTy, state);
    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
    return b.create<tensor::InsertOp>(loc, storeTy, stored, store,
                                      ValueRange{one});
  }

  if (numWords == 4 && elementTy.isInteger(32)) {
    Value two = b.create<arith::ConstantIndexOp>(loc, 2);
    Value three = b.create<arith::ConstantIndexOp>(loc, 3);
    std::pair<ArithOpBuilder, ArithOpBuilder> halves =
        splitI64(ArithOpBuilder(b, loc, state));
    Value lowWord =
        b.create<arith::BitcastOp>(loc, elementTy, halves.first.val());
    Value highWord =
        b.create<arith::BitcastOp>(loc, elementTy, halves.second.val());
    Value withLow = b.create<tensor::InsertOp>(loc, storeTy, lowWord, store,
                                               ValueRange{two});
    return b.create<tensor::InsertOp>(loc, storeTy, highWord, withLow,
                                      ValueRange{three});
  }

  return nullptr;
}

// Reshapes `src` to `destTy` using tensor expand/collapse ops.
//
// An expand is emitted when the destination shape collapses onto the source
// shape; a collapse is emitted when the original source shape collapses onto
// the destination shape (e.g. a rank-0 target). If neither reassociation can
// be computed, `src` is returned unchanged.
Value reshapeToTarget(OpBuilder &builder, Location loc, ShapedType destTy,
                      Value src) {
  // Both checks are made against the type `src` had on entry.
  auto originalTy = cast<ShapedType>(src.getType());

  if (auto expandIndices = getReassociationIndicesForCollapse(
          destTy.getShape(), originalTy.getShape())) {
    src = builder.create<tensor::ExpandShapeOp>(loc, destTy, src,
                                                expandIndices.value());
  }

  if (auto collapseIndices = getReassociationIndicesForCollapse(
          originalTy.getShape(), destTy.getShape())) {
    src = builder.create<tensor::CollapseShapeOp>(loc, destTy, src,
                                                  collapseIndices.value());
  }

  return src;
}

// Computes the intermediate shape used when generating ThreeFry values two at
// a time, together with the index of the dimension that was halved.
//
// The halved dimension is the first even-sized one; if every dimension is
// odd, the first largest dimension is used and its half is rounded up. When
// the halved dimension is the last one, a trailing unit dimension is appended.
// Rank-0 types are returned unchanged with a halved dimension of 0.
std::pair<ShapedType, int64_t> threeFry32Shape(ShapedType resultTy) {
  if (resultTy.getRank() == 0)
    return {resultTy, 0};

  ArrayRef<int64_t> dims = resultTy.getShape();

  // Fallback choice: position of the largest dimension.
  uint64_t halfDim = std::max_element(dims.begin(), dims.end()) - dims.begin();

  // Preferred choice: the first dimension with an even size.
  for (int idx = 0, rank = dims.size(); idx < rank; ++idx) {
    const bool isEven = (dims[idx] & 0x1) == 0;
    if (isEven) {
      halfDim = idx;
      break;
    }
  }

  llvm::SmallVector<int64_t> halvedShape(dims);
  const bool halvedLastDim = halfDim == (halvedShape.size() - 1);
  halvedShape[halfDim] = (halvedShape[halfDim] + 1) / 2;
  if (halvedLastDim)
    halvedShape.push_back(1);

  return {RankedTensorType::get(halvedShape, resultTy.getElementType()),
          halfDim};
}

/// This implementation generates a tensor of ThreeFry random numbers whose
/// elements are at most 32 bits wide (wider words are truncated to the result
/// element type). It matches the XLA implementation bit-exact and includes an
/// inefficient method of concatenating / slicing the pairs of generated
/// numbers.
///
/// We should consider dropping the complex slicing and simply generating
/// 2x the values, then downcast to a 32-bit. It substantially simplifies
/// the computation and avoids the concat / slice behavior.
///
/// On success `result` holds the random tensor of type `resultTy` and `store`
/// is replaced by the state tensor with its counter advanced. Fails if the
/// layout of `store` is not supported by extractState64 / extractKey32.
LogicalResult generateLinalgThreeFry32(OpBuilder &builder, Location loc,
                                       ShapedType resultTy, Value &store,
                                       Value &result) {
  Type resultETy = resultTy.getElementType();

  // Extract the stateful values as an i64 and increment the state ahead.
  Value initialState = extractState64(builder, loc, store);
  if (!initialState)
    return failure();

  std::pair<Value, Value> keys = extractKey32(builder, loc, store);
  if (!keys.first || !keys.second)
    return failure();

  ArithOpBuilder key0(builder, loc, keys.first);
  ArithOpBuilder key1(builder, loc, keys.second);

  // Compute the intermediate type we use to compute three fry values, including
  // the dimension that was halved.
  auto pair = threeFry32Shape(resultTy);
  ShapedType intermediateType = pair.first;
  int64_t halfDim = pair.second;
  int64_t count = intermediateType.getNumElements();

  // Each ThreeFry invocation consumes one counter value and yields two words,
  // so the state advances by the number of invocations.
  Value countVal =
      builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
  Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);

  // Generate two 1D tensors, one for each word produced per invocation.
  Value destLeft = builder.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>({count}), resultETy);
  Value destRight = builder.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>({count}), resultETy);

  ShapedType destTy = llvm::cast<ShapedType>(destLeft.getType());

  SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(1));
  SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);

  linalg::GenericOp generic = builder.create<linalg::GenericOp>(
      loc, TypeRange{destTy, destTy},
      /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{destLeft, destRight},
      /*indexingMaps=*/indexingMaps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange) {
        // Grab three fry results and write to each array.
        auto split = runThreeFry2xi32(
            key0, key1, ArithOpBuilder(b, nestedLoc, initialState));
        const unsigned resultBits = resultETy.getIntOrFloatBitWidth();
        auto first = split.first.truncI(resultBits);
        auto second = split.second.truncI(resultBits);
        // Use the region's location for the terminator, like the other ops
        // created inside the body.
        b.create<linalg::YieldOp>(nestedLoc,
                                  ValueRange{first.val(), second.val()});
      });

  // A single-element result only needs the first generated word.
  if (resultTy.getNumElements() == 1) {
    result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0));
    store = setState64(builder, loc, store, newState);
    return success();
  }

  // Reshape to the target size and concatenate on the dimension following the
  // half dimension.
  Value random0 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(0));
  Value random1 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(1));
  Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>(
      loc, ValueRange{random0, random1},
      builder.getI64IntegerAttr(halfDim + 1));

  // Collapse the concat dimension back into the parent. An odd halved
  // dimension was rounded up, so the collapsed size is rounded up to even.
  llvm::SmallVector<int64_t> collapseShape(resultTy.getShape());
  collapseShape[halfDim] =
      collapseShape[halfDim] + (collapseShape[halfDim] & 1);
  Value reshape = builder.create<mlir::stablehlo::ReshapeOp>(
      loc, resultTy.clone(collapseShape), concatenate);

  // Slice to only the required results.
  llvm::SmallVector<int64_t> offset(resultTy.getRank(), 0);
  llvm::SmallVector<int64_t> stride(resultTy.getRank(), 1);
  Value slice = builder.create<mlir::stablehlo::SliceOp>(
      loc, resultTy, reshape, builder.getDenseI64ArrayAttr(offset),
      builder.getDenseI64ArrayAttr(resultTy.getShape()),
      builder.getDenseI64ArrayAttr(stride));

  // Set the new tensor values.
  store = setState64(builder, loc, store, newState);
  result = slice;

  return success();
}

/// Generates a tensor of 64-bit ThreeFry random numbers.
///
/// One ThreeFry invocation is emitted per result element and its two 32-bit
/// words are fused into a single i64. On success `result` holds the random
/// tensor of type `resultTy` and `store` is replaced by the state tensor with
/// its counter advanced by the element count. Fails if the layout of `store`
/// is not supported by extractState64 / extractKey32.
LogicalResult generateLinalgThreeFry64(OpBuilder &builder, Location loc,
                                       ShapedType resultTy, Value &store,
                                       Value &result) {
  const int64_t numElements = resultTy.getNumElements();
  Type elementTy = resultTy.getElementType();

  // Pull the counter and the key words out of the state tensor.
  Value initialState = extractState64(builder, loc, store);
  if (!initialState)
    return failure();

  std::pair<Value, Value> keyWords = extractKey32(builder, loc, store);
  if (!(keyWords.first && keyWords.second))
    return failure();

  ArithOpBuilder key0(builder, loc, keyWords.first);
  ArithOpBuilder key1(builder, loc, keyWords.second);

  // The state advances by one per generated i64.
  Value increment = builder.create<arith::ConstantOp>(
      loc, builder.getI64IntegerAttr(numElements));
  Value advancedState =
      builder.create<arith::AddIOp>(loc, initialState, increment);

  // Flat 1D buffer that receives the random values.
  Value flatDest = builder.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>({numElements}), elementTy);
  ShapedType flatTy = llvm::cast<ShapedType>(flatDest.getType());

  SmallVector<AffineMap> maps(1, builder.getMultiDimIdentityMap(1));
  SmallVector<utils::IteratorType> iteratorTypes(
      1, utils::IteratorType::parallel);

  auto bodyBuilder = [&](OpBuilder &b, Location nestedLoc, ValueRange) {
    // Run ThreeFry for this element and fuse its two words into one i64.
    auto words = runThreeFry2xi32(
        key0, key1, ArithOpBuilder(b, nestedLoc, initialState));
    Value fused = fuseI32s(words.first, words.second).val();
    b.create<linalg::YieldOp>(nestedLoc, fused);
  };

  auto random = builder.create<linalg::GenericOp>(
      loc, flatTy, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{flatDest},
      /*indexingMaps=*/maps, iteratorTypes, bodyBuilder);

  store = setState64(builder, loc, store, advancedState);
  result = reshapeToTarget(builder, loc, resultTy, random.getResult(0));
  return success();
}

// The two key words used by the Philox rounds, as {first, second}.
using PhiloxKey = std::pair<ArithOpBuilder, ArithOpBuilder>;
// The four words of Philox working state that each round transforms.
using PhiloxState = std::array<ArithOpBuilder, 4>;

// Multiplies two 32-bit integers as a 64-bit product and returns its words as
// {high 32 bits, low 32 bits}.
// Per the paper, mulhi and mullo of the same arguments can be computed
// simultaneously in a single instruction on x86 architectures.
std::pair<ArithOpBuilder, ArithOpBuilder> multiplyHilo(ArithOpBuilder counter,
                                                       ArithOpBuilder key) {
  // Widen both operands so the full product is representable.
  ArithOpBuilder wideCounter = counter.extendUI(64);
  ArithOpBuilder wideKey = key.extendUI(64);
  ArithOpBuilder wideProduct = wideCounter * wideKey;

  ArithOpBuilder shiftAmount =
      wideCounter.constantI(/*value=*/32, /*bits=*/64);
  ArithOpBuilder highWord = (wideProduct >> shiftAmount).truncI(32);
  ArithOpBuilder lowWord = wideProduct.truncI(32);
  return {highWord, lowWord};
}

// Applies one Philox round to the four state words `x` using `key`.
PhiloxState philoxRound(PhiloxState x, PhiloxKey key) {
  // Philox-specific multipliers for words 0 and 2.
  ArithOpBuilder multiplier0 = x[0].constantI(0xD2511F53, 32);
  ArithOpBuilder multiplier1 = x[2].constantI(0xCD9E8D57, 32);

  // Each product is returned as {high word, low word}.
  auto product0 = multiplyHilo(x[0], multiplier0);
  auto product1 = multiplyHilo(x[2], multiplier1);

  // The high words are mixed with the odd state words and the key; the low
  // words pass through unchanged.
  ArithOpBuilder mixed0 = product1.first ^ x[1] ^ key.first;
  ArithOpBuilder mixed2 = product0.first ^ x[3] ^ key.second;

  PhiloxState next = {mixed0, product1.second, mixed2, product0.second};
  return next;
}

// Advances the Philox key for the next round by adding a fixed constant to
// each key word.
PhiloxKey raiseKey(PhiloxKey key) {
  // Philox-specific key increments; both constants are created through the
  // first key word's builder.
  ArithOpBuilder increment0 = key.first.constantI(0x9E3779B9, 32);
  ArithOpBuilder increment1 = key.first.constantI(0xBB67AE85, 32);

  ArithOpBuilder raised0 = key.first + increment0;
  ArithOpBuilder raised1 = key.second + increment1;
  return PhiloxKey{raised0, raised1};
}

// Emits the Philox 4x32 counter-based PRNG.
// The Philox PRNG has been proposed in:
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
//
// The counter is `linalg.index 0` (cast to i64) plus `state`, so this must be
// called while building the body of a linalg op. Returns the four output
// words.
std::array<ArithOpBuilder, 4> runPhilox4x32(PhiloxKey key,
                                            ArithOpBuilder state) {
  // Per-element counter: iteration index offset by the starting state.
  ArithOpBuilder counter = state.linalgIndex(0).indexCast(64);
  counter = counter + state;

  // Split the counter into the 2xi32 words that seed the Philox state.
  std::pair<ArithOpBuilder, ArithOpBuilder> counterWords = splitI64(counter);

  // We initialize the state as such to match the XLA implementation.
  PhiloxState words = {counterWords.first, counterWords.second, key.first,
                       key.second};

  // We perform 10 rounds to match the XLA implementation. The key is raised
  // after every round, including the last one.
  constexpr int kNumRounds = 10;
  int round = 0;
  while (round < kNumRounds) {
    words = philoxRound(words, key);
    key = raiseKey(key);
    ++round;
  }
  return words;
}

// Generates a tensor of type `resultTy` containing random bits produced by the
// Philox algorithm, for element types of at most 32 bits (each generated word
// is truncated to the result element type).
//
// Every Philox invocation yields four words, so ceil(numElements / 4)
// invocations are emitted into four parallel 1D buffers which are then
// concatenated, reshaped and sliced down to `resultTy`. On success `result`
// holds the random tensor and `store` is replaced by the state tensor with
// its counter advanced. Fails if the layout of `store` is not supported by
// extractState64 / extractKey32.
LogicalResult generateLinalgPhilox32(OpBuilder &builder, Location loc,
                                     ShapedType resultTy, Value &store,
                                     Value &result) {
  Type resultETy = resultTy.getElementType();

  Value initialState = extractState64(builder, loc, store);
  if (!initialState)
    return failure();

  std::pair<Value, Value> keys = extractKey32(builder, loc, store);
  if (!keys.first || !keys.second)
    return failure();

  int64_t numElements = resultTy.getNumElements();
  // Number of Philox invocations; each one produces four words.
  int64_t count = (numElements + 3) / 4;
  ShapedType intermediateType =
      RankedTensorType::get({count, 1}, resultTy.getElementType());
  int64_t concatDim = 1;

  // The state advances by one per Philox invocation.
  Value countVal =
      builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
  Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);

  // Set up four outputs, one per word produced by each invocation.
  Value dest0 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
                                                resultETy);
  Value dest1 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
                                                resultETy);
  Value dest2 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
                                                resultETy);
  Value dest3 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
                                                resultETy);

  // Free-function cast, matching the rest of this file; the member
  // `Type::cast<>` form is deprecated in MLIR.
  ShapedType destTy = llvm::cast<ShapedType>(dest0.getType());

  SmallVector<AffineMap> indexingMaps(4, builder.getMultiDimIdentityMap(1));
  SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);

  linalg::GenericOp generic = builder.create<linalg::GenericOp>(
      loc, TypeRange{destTy, destTy, destTy, destTy},
      /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{dest0, dest1, dest2, dest3},
      /*indexingMaps=*/indexingMaps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange) {
        auto output =
            runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first),
                                    ArithOpBuilder(b, nestedLoc, keys.second)},
                          ArithOpBuilder(b, nestedLoc, initialState));
        const unsigned resultBits = resultETy.getIntOrFloatBitWidth();
        auto out0 = output[0].truncI(resultBits);
        auto out1 = output[1].truncI(resultBits);
        auto out2 = output[2].truncI(resultBits);
        auto out3 = output[3].truncI(resultBits);
        // Use the region's location for the terminator, like the other ops
        // created inside the body.
        b.create<linalg::YieldOp>(
            nestedLoc,
            ValueRange{out0.val(), out1.val(), out2.val(), out3.val()});
      });

  // A single-element result only needs the first generated word.
  if (numElements == 1) {
    result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0));
    store = setState64(builder, loc, store, newState);
    return success();
  }

  Value r0 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(0));
  Value r1 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(1));
  Value r2 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(2));
  Value r3 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(3));

  Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>(
      loc, ValueRange{r0, r1, r2, r3}, builder.getI64IntegerAttr(concatDim));

  // Collapse the concat dimension back into the parent.
  llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape());
  collapseShape[0] = collapseShape[0] * 4;
  Value reshapeIntermediate = builder.create<mlir::stablehlo::ReshapeOp>(
      loc, resultTy.clone(collapseShape), concatenate);

  // Slice to only the required results.
  collapseShape[0] = numElements;

  auto sliceResultTy = intermediateType.clone(collapseShape);
  llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0);
  llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1);
  Value slice = builder.create<mlir::stablehlo::SliceOp>(
      loc, sliceResultTy, reshapeIntermediate,
      builder.getDenseI64ArrayAttr(offset),
      builder.getDenseI64ArrayAttr(collapseShape),
      builder.getDenseI64ArrayAttr(stride));
  Value reshapeResult =
      builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice);

  // Set the new tensor values.
  store = setState64(builder, loc, store, newState);
  result = reshapeResult;

  return success();
}

// Generates a tensor of type `resultTy` containing 64-bit random values
// produced by the Philox algorithm.
//
// Every Philox invocation yields four 32-bit words which are fused pairwise
// into two i64 values, so ceil(numElements / 2) invocations are emitted into
// two parallel 1D buffers which are then concatenated, reshaped and sliced
// down to `resultTy`. On success `result` holds the random tensor and `store`
// is replaced by the state tensor with its counter advanced. Fails if the
// layout of `store` is not supported by extractState64 / extractKey32.
LogicalResult generateLinalgPhilox64(OpBuilder &builder, Location loc,
                                     ShapedType resultTy, Value &store,
                                     Value &result) {
  Type resultETy = resultTy.getElementType();

  Value initialState = extractState64(builder, loc, store);
  if (!initialState)
    return failure();

  std::pair<Value, Value> keys = extractKey32(builder, loc, store);
  if (!keys.first || !keys.second)
    return failure();

  int64_t numElements = resultTy.getNumElements();
  // Number of Philox invocations; each one produces two i64 values.
  int64_t count = (numElements + 1) / 2;
  ShapedType intermediateType =
      RankedTensorType::get({count, 1}, resultTy.getElementType());
  int64_t concatDim = 1;

  // The state advances by one per Philox invocation.
  Value countVal =
      builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(count));
  Value newState = builder.create<arith::AddIOp>(loc, initialState, countVal);

  // Set up two outputs, one per i64 produced by each invocation.
  Value dest0 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
                                                resultETy);
  Value dest1 = builder.create<tensor::EmptyOp>(loc, ArrayRef<int64_t>({count}),
                                                resultETy);
  // Free-function cast, matching the rest of this file; the member
  // `Type::cast<>` form is deprecated in MLIR.
  ShapedType destTy = llvm::cast<ShapedType>(dest0.getType());

  SmallVector<AffineMap> indexingMaps(2, builder.getMultiDimIdentityMap(1));
  SmallVector<utils::IteratorType> iterators(1, utils::IteratorType::parallel);

  linalg::GenericOp generic = builder.create<linalg::GenericOp>(
      loc, TypeRange{destTy, destTy},
      /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{dest0, dest1},
      /*indexingMaps=*/indexingMaps, iterators,
      [&](OpBuilder &b, Location nestedLoc, ValueRange) {
        auto output =
            runPhilox4x32(PhiloxKey{ArithOpBuilder(b, nestedLoc, keys.first),
                                    ArithOpBuilder(b, nestedLoc, keys.second)},
                          ArithOpBuilder(b, nestedLoc, initialState));
        // Words 0/1 form the first i64 and words 2/3 the second; the first
        // word of each pair supplies the low 32 bits.
        Value result1 = fuseI32s(output[0], output[1]).val();
        Value result2 = fuseI32s(output[2], output[3]).val();
        // Use the region's location for the terminator, like the other ops
        // created inside the body.
        b.create<linalg::YieldOp>(nestedLoc, ValueRange{result1, result2});
      });

  // A single-element result only needs the first generated value.
  if (numElements == 1) {
    result = reshapeToTarget(builder, loc, resultTy, generic.getResult(0));
    store = setState64(builder, loc, store, newState);
    return success();
  }

  Value r0 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(0));
  Value r1 =
      reshapeToTarget(builder, loc, intermediateType, generic.getResult(1));
  Value concatenate = builder.create<mlir::stablehlo::ConcatenateOp>(
      loc, ValueRange{r0, r1}, builder.getI64IntegerAttr(concatDim));

  // Collapse the concat dimension back into the parent.
  llvm::SmallVector<int64_t> collapseShape(intermediateType.getShape());
  collapseShape[0] = collapseShape[0] * 2;
  Value reshapeIntermediate = builder.create<mlir::stablehlo::ReshapeOp>(
      loc, resultTy.clone(collapseShape), concatenate);

  // Slice to only the required results.
  collapseShape[0] = numElements;

  auto sliceResultTy = intermediateType.clone(collapseShape);
  llvm::SmallVector<int64_t> offset(sliceResultTy.getRank(), 0);
  llvm::SmallVector<int64_t> stride(sliceResultTy.getRank(), 1);
  Value slice = builder.create<mlir::stablehlo::SliceOp>(
      loc, sliceResultTy, reshapeIntermediate,
      builder.getDenseI64ArrayAttr(offset),
      builder.getDenseI64ArrayAttr(collapseShape),
      builder.getDenseI64ArrayAttr(stride));
  Value reshapeResult =
      builder.create<mlir::stablehlo::ReshapeOp>(loc, resultTy, slice);

  // Set the new tensor values.
  store = setState64(builder, loc, store, newState);
  result = reshapeResult;

  return success();
}

// Dispatches ThreeFry generation on the result element bit width: 64-bit
// elements use the 64-bit generator; 32/16/8-bit elements use the 32-bit
// generator, which truncates to the element type. On success `state` and
// `result` are updated by the selected generator. Fails for any other element
// type, for dynamically shaped results, or if the selected generator fails.
LogicalResult generateLinalgThreeFry(OpBuilder &builder, Location loc,
                                     ShapedType resultTy, Value &state,
                                     Value &result) {
  Type eTy = resultTy.getElementType();

  // getIntOrFloatBitWidth() is only valid on integer/float types, and the
  // generators size their buffers with getNumElements(), which requires a
  // static shape. Reject everything else instead of asserting.
  if (!eTy.isIntOrFloat() || !resultTy.hasStaticShape()) {
    return failure();
  }

  unsigned bitwidth = eTy.getIntOrFloatBitWidth();

  if (bitwidth == 64) {
    return generateLinalgThreeFry64(builder, loc, resultTy, state, result);
  }
  if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
    return generateLinalgThreeFry32(builder, loc, resultTy, state, result);
  }

  return failure();
}

// Dispatches Philox generation on the result element bit width: 64-bit
// elements use the 64-bit generator; 32/16/8-bit elements use the 32-bit
// generator. On success `state` and `result` are updated by the selected
// generator. Fails for any other element type, for dynamically shaped
// results, or if the selected generator fails.
LogicalResult generateLinalgPhilox(OpBuilder &builder, Location loc,
                                   ShapedType resultTy, Value &state,
                                   Value &result) {
  Type eTy = resultTy.getElementType();

  // getIntOrFloatBitWidth() is only valid on integer/float types, and the
  // generators size their buffers with getNumElements(), which requires a
  // static shape. Reject everything else instead of asserting.
  if (!eTy.isIntOrFloat() || !resultTy.hasStaticShape()) {
    return failure();
  }

  unsigned bitwidth = eTy.getIntOrFloatBitWidth();
  if (bitwidth == 64) {
    return generateLinalgPhilox64(builder, loc, resultTy, state, result);
  }

  // The 32 bit implementation truncates to result eTy.
  if (bitwidth == 32 || bitwidth == 16 || bitwidth == 8) {
    return generateLinalgPhilox32(builder, loc, resultTy, state, result);
  }

  return failure();
}

/// Lowers `stablehlo.rng_bit_generator` to Linalg using the ThreeFry or Philox
/// generators defined above. The op is replaced by the pair
/// {updated state, random tensor}. The DEFAULT algorithm is lowered with
/// Philox.
struct RngBitGeneratorConverter final
    : OpConversionPattern<mlir::stablehlo::RngBitGeneratorOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::RngBitGeneratorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value state = adaptor.getInitialState();
    // Result #1 is the random tensor; its converted type drives generation.
    auto resultTy = dyn_cast_or_null<ShapedType>(
        getTypeConverter()->convertType(op.getResult(1).getType()));
    if (!resultTy) {
      return rewriter.notifyMatchFailure(op, "type conversion failed");
    }

    const mlir::stablehlo::RngAlgorithm algorithm = op.getRngAlgorithm();

    if (algorithm == mlir::stablehlo::RngAlgorithm::THREE_FRY) {
      Value random;
      if (failed(
              generateLinalgThreeFry(rewriter, loc, resultTy, state, random))) {
        return rewriter.notifyMatchFailure(
            op, "unsupported state layout or result type for ThreeFry");
      }
      rewriter.replaceOp(op, {state, random});
      return success();
    }

    if (algorithm == mlir::stablehlo::RngAlgorithm::PHILOX ||
        algorithm == mlir::stablehlo::RngAlgorithm::DEFAULT) {
      Value random;
      if (failed(
              generateLinalgPhilox(rewriter, loc, resultTy, state, random))) {
        return rewriter.notifyMatchFailure(
            op, "unsupported state layout or result type for Philox");
      }
      rewriter.replaceOp(op, {state, random});
      return success();
    }

    return rewriter.notifyMatchFailure(op, "unsupported rng algorithm");
  }
};

/// Lowers `stablehlo.rng` with a UNIFORM distribution and floating-point
/// bounds to a `linalg.generic` op. Each output element is computed from its
/// own indices with a linear congruential generator (LCG) and then scaled into
/// the range given by the min/max operands.
///
/// The match fails when the distribution is not uniform, when the min/max
/// operands are not shaped types with float elements, or when the result type
/// does not convert to a ShapedType.
struct RngUniformConversion final
    : OpConversionPattern<mlir::stablehlo::RngOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::RngOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // We only handle uniform distributions.
    if (op.getRngDistribution() != mlir::stablehlo::RngDistribution::UNIFORM) {
      return rewriter.notifyMatchFailure(
          op, "only uniform rng distributions are supported");
    }
    // TODO(raikonenfnu): Handle other element types as well.
    auto minTy = dyn_cast<ShapedType>(adaptor.getA().getType());
    auto maxTy = dyn_cast<ShapedType>(adaptor.getB().getType());
    // dyn_cast yields a null type on mismatch; bail out instead of
    // dereferencing it below.
    if (!minTy || !maxTy) {
      return rewriter.notifyMatchFailure(
          op, "expected min/max for rng op to be ShapedType");
    }
    if (!isa<FloatType>(minTy.getElementType()) ||
        !isa<FloatType>(maxTy.getElementType())) {
      return rewriter.notifyMatchFailure(
          op, "expected min/max for rng op to be FloatType");
    }
    auto targetTy = dyn_cast_or_null<ShapedType>(
        getTypeConverter()->convertType(op.getResult().getType()));
    if (!targetTy) {
      return rewriter.notifyMatchFailure(
          op, "expected target shape of rng op to be ShapedType");
    }
    auto loc = op.getLoc();
    Value emptyTensor =
        getEmptyTensorFor(rewriter, loc, targetTy, op, adaptor.getOperands());
    // Creates index map using target matrix's rank. The two inputs (min and
    // max) use a map with no results; the output uses the identity map.
    auto targetRank = targetTy.getRank();
    SmallVector<AffineMap, 3> indexingMaps(
        2, AffineMap::get(targetRank, /*symbolCount=*/0,
                          SmallVector<AffineExpr>({}), rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(targetRank));
    const int kInitialSeed = 0;

    // Generic region with LCG Algorithm that make use of element index from:
    // https://reviews.llvm.org/D101364
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/targetTy,
        /*inputs=*/
        ValueRange{adaptor.getA(), adaptor.getB()},
        /*outputs=*/emptyTensor, indexingMaps,
        getParallelAndReductionIterators(/*nLoops=*/targetRank,
                                         /*nReduction=*/0),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          // args[0] is the min bound and args[1] is the max bound.
          Value lcgState = b.create<arith::ConstantOp>(
              loc, b.getI32IntegerAttr(kInitialSeed));
          Value multiplier =
              b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(1103515245));
          Value incrementStep =
              b.create<arith::ConstantOp>(loc, b.getI32IntegerAttr(12345));
          // For output matrix with rank N:
          // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr
          // ...
          // tempN = (cast(I32, index(D.(N-1))) + tempN_1) * mult + incr
          for (int i = 0; i < targetRank; i++) {
            Value ind = b.create<linalg::IndexOp>(loc, i);
            Value castInd =
                b.create<arith::IndexCastOp>(loc, b.getI32Type(), ind);
            Value addRes = b.create<arith::AddIOp>(loc, castInd, lcgState);
            Value multRes = b.create<arith::MulIOp>(loc, addRes, multiplier);
            lcgState = b.create<arith::AddIOp>(loc, multRes, incrementStep);
          }
          // Scaling = (max - min) * const(T, 2.3283064E-10), where T is the
          // type of the min bound and the constant is approximately 2^-32.
          // This is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)).
          // NOTE(review): the constant may not be representable for float
          // types narrower than f32 -- confirm behavior for those types.
          Value epsilon = b.create<arith::ConstantOp>(
              loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10));
          Value range = b.create<arith::SubFOp>(loc, args[1], args[0]);
          Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
          // Res = cast(T, tempN) * scaling + min, with tempN interpreted as an
          // unsigned integer.
          Value updateCast = b.create<arith::UIToFPOp>(
              loc, targetTy.getElementType(), lcgState);
          Value scaleUpdate = b.create<arith::MulFOp>(loc, updateCast, scale);
          Value res = b.create<arith::AddFOp>(loc, scaleUpdate, args[0]);
          b.create<linalg::YieldOp>(loc, res);
        },
        linalg::getPrunedAttributeList(op));
    rewriter.replaceOp(op, linalgOp.getResults());
    return success();
  }
};
} // namespace

namespace detail {
/// Adds the StableHLO random-number lowering patterns defined in this file to
/// `patterns`, each constructed with `typeConverter` and `context`.
void populateStableHloRandomToLinalgConversionPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    RewritePatternSet *patterns) {
  // stablehlo.rng_bit_generator -> ThreeFry / Philox Linalg generators.
  patterns->add<RngBitGeneratorConverter>(typeConverter, context);
  // stablehlo.rng (uniform distribution) -> linalg.generic.
  patterns->add<RngUniformConversion>(typeConverter, context);
}
} // namespace detail
} // namespace mlir::iree_compiler::stablehlo
