Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1233,8 +1233,8 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol,
//===----------------------------------------------------------------------===//

def LoadOp : MemRef_Op<"load",
[TypesMatchWith<"result type matches element type of 'memref'",
"memref", "result",
[TypesMatchWith<"result type matches element type of 'base'",
"base", "result",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
Expand Down Expand Up @@ -1273,7 +1273,7 @@ def LoadOp : MemRef_Op<"load",
}];

let arguments = (ins Arg<AnyMemRef, "the reference to load from",
[MemRead]>:$memref,
[MemRead]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
Expand Down Expand Up @@ -1310,6 +1310,8 @@ def LoadOp : MemRef_Op<"load",
let results = (outs AnyType:$result);

let extraClassDeclaration = [{
::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); }
::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); }
Value getMemRef() { return getOperand(0); }
void setMemRef(Value value) { setOperand(0, value); }
MemRefType getMemRefType() {
Expand All @@ -1320,7 +1322,7 @@ def LoadOp : MemRef_Op<"load",
let hasFolder = 1;
let hasVerifier = 1;

let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base)";
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2010,8 +2012,8 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
//===----------------------------------------------------------------------===//

def MemRef_StoreOp : MemRef_Op<"store",
[TypesMatchWith<"type of 'value' matches element type of 'memref'",
"memref", "value",
[TypesMatchWith<"type of 'valueToStore' matches element type of 'base'",
"base", "valueToStore",
"::llvm::cast<MemRefType>($_self).getElementType()">,
MemRefsNormalizable,
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
Expand Down Expand Up @@ -2046,9 +2048,9 @@ def MemRef_StoreOp : MemRef_Op<"store",
```
}];

let arguments = (ins AnyType:$value,
let arguments = (ins AnyType:$valueToStore,
Arg<AnyMemRef, "the reference to store to",
[MemWrite]>:$memref,
[MemWrite]>:$base,
Variadic<Index>:$indices,
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal,
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment);
Expand All @@ -2070,7 +2072,10 @@ def MemRef_StoreOp : MemRef_Op<"store",
];

let extraClassDeclaration = [{
Value getValueToStore() { return getOperand(0); }
::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); }
::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); }

Value getValue() { return getOperand(0); }

Value getMemRef() { return getOperand(1); }
void setMemRef(Value value) { setOperand(1, value); }
Expand All @@ -2083,7 +2088,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
let hasVerifier = 1;

let assemblyFormat = [{
$value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
$valueToStore `,` $base `[` $indices `]` attr-dict `:` type($base)
}];
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
}

auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getBase());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}
Expand All @@ -362,15 +362,15 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
auto arrayValue =
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
dyn_cast<TypedValue<emitc::ArrayType>>(operands.getBase());
if (!arrayValue) {
return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
}

auto subscript = emitc::SubscriptOp::create(
rewriter, op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
operands.getValueToStore());
return success();
}
};
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
// Per memref.load spec, the indices must be in-bounds:
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type,
adaptor.getMemref(),
adaptor.getIndices(), kNoWrapFlags);
Value dataPtr =
getStridedElementPtr(rewriter, loadOp.getLoc(), type, adaptor.getBase(),
adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
loadOp, typeConverter->convertType(type.getElementType()), dataPtr,
loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal());
Expand All @@ -965,11 +965,11 @@ struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
// 0 <= idx < dim_size, and additionally all offsets are non-negative,
// hence inbounds and nuw are used when lowering to llvm.getelementptr.
Value dataPtr =
getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(),
getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getBase(),
adaptor.getIndices(), kNoWrapFlags);
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, adaptor.getValue(), dataPtr,
op.getAlignment().value_or(0),
false, op.getNontemporal());
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(
op, adaptor.getValueToStore(), dataPtr, op.getAlignment().value_or(0),
false, op.getNontemporal());
return success();
}
};
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,

const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
adaptor.getIndices(), loc, rewriter);

if (!accessChain)
Expand Down Expand Up @@ -682,7 +682,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
"failed to lower memref in image storage class to storage buffer");

Value loadPtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getBase(),
adaptor.getIndices(), loadOp.getLoc(), rewriter);

if (!loadPtr)
Expand Down Expand Up @@ -743,7 +743,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
return rewriter.notifyMatchFailure(
loadOp, "failed to lower memref in non-image storage class to image");

Value loadPtr = adaptor.getMemref();
Value loadPtr = adaptor.getBase();
auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp);
if (failed(memoryRequirements))
return rewriter.notifyMatchFailure(
Expand Down Expand Up @@ -824,7 +824,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
auto loc = storeOp.getLoc();
auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(),
spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
adaptor.getIndices(), loc, rewriter);

if (!accessChain)
Expand Down Expand Up @@ -874,7 +874,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
storeOp, "failed to determine memory requirements");

auto [memoryAccess, alignment] = *memoryRequirements;
Value storeVal = adaptor.getValue();
Value storeVal = adaptor.getValueToStore();
if (isBool)
storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, accessChain, storeVal,
Expand Down Expand Up @@ -915,7 +915,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
clearBitsMask =
rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
Value storeVal =
shiftValue(loc, adaptor.getValueToStore(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
srcBits, dstBits, rewriter);
std::optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
Expand Down Expand Up @@ -1020,7 +1021,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
if (memrefType.getElementType().isSignlessInteger())
return rewriter.notifyMatchFailure(storeOp, "signless int");
auto storePtr = spirv::getElementPtr(
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getMemref(),
*getTypeConverter<SPIRVTypeConverter>(), memrefType, adaptor.getBase(),
adaptor.getIndices(), storeOp.getLoc(), rewriter);

if (!storePtr)
Expand All @@ -1033,7 +1034,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,

auto [memoryAccess, alignment] = *memoryRequirements;
rewriter.replaceOpWithNewOp<spirv::StoreOp>(
storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment);
storeOp, storePtr, adaptor.getValueToStore(), memoryAccess, alignment);
return success();
}

Expand Down
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
LogicalResult
matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getMemRefType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
Expand All @@ -298,7 +298,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// Special case 0-rank memref loads.
Value bitsLoad;
if (convertedType.getRank() == 0) {
bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(),
bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getBase(),
ValueRange{});
} else {
// Linearize the indices of the original load instruction. Do not account
Expand All @@ -307,7 +307,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

Value newLoad = memref::LoadOp::create(
rewriter, loc, adaptor.getMemref(),
rewriter, loc, adaptor.getBase(),
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
dstBits));

Expand Down Expand Up @@ -414,7 +414,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto convertedType = cast<MemRefType>(adaptor.getMemref().getType());
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
int srcBits = op.getMemRefType().getElementTypeBitWidth();
int dstBits = convertedType.getElementTypeBitWidth();
auto dstIntegerType = rewriter.getIntegerType(dstBits);
Expand All @@ -426,7 +426,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
Location loc = op.getLoc();

// Pad the input value with 0s on the left.
Value input = adaptor.getValue();
Value input = adaptor.getValueToStore();
if (!input.getType().isInteger()) {
input = arith::BitcastOp::create(
rewriter, loc,
Expand All @@ -440,7 +440,7 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign,
extendedInput, adaptor.getMemref(),
extendedInput, adaptor.getBase(),
ValueRange{});
rewriter.eraseOp(op);
return success();
Expand All @@ -460,10 +460,10 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {

// Clear destination bits
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi,
writeMask, adaptor.getMemref(), storeIndices);
writeMask, adaptor.getBase(), storeIndices);
// Write srcs bits to destination
memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori,
alignedVal, adaptor.getMemref(), storeIndices);
alignedVal, adaptor.getBase(), storeIndices);
rewriter.eraseOp(op);
return success();
}
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
op.getMemRefType()));

rewriter.replaceOpWithNewOp<memref::LoadOp>(
op, newResTy, adaptor.getMemref(), adaptor.getIndices(),
op.getNontemporal());
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, newResTy, adaptor.getBase(),
adaptor.getIndices(),
op.getNontemporal());
return success();
}
};
Expand All @@ -90,7 +90,7 @@ struct ConvertMemRefStore final : OpConversionPattern<memref::StoreOp> {
op.getMemRefType()));

rewriter.replaceOpWithNewOp<memref::StoreOp>(
op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(),
op, adaptor.getValueToStore(), adaptor.getBase(), adaptor.getIndices(),
op.getNontemporal());
return success();
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/python/dialects/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def testParallelMemcpy():

with InsertionPoint(loop_block):
idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0])
val = memref.load(memref=copied, indices=[idx])
memref.store(value=val, memref=created, indices=[idx])
val = memref.load(base=copied, indices=[idx])
memref.store(value_to_store=val, base=created, indices=[idx])
openacc.YieldOp([])

openacc.YieldOp([])
Expand Down