diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 0bf22928f6900..cf5c93fc136d9 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1233,8 +1233,8 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol, //===----------------------------------------------------------------------===// def LoadOp : MemRef_Op<"load", - [TypesMatchWith<"result type matches element type of 'memref'", - "memref", "result", + [TypesMatchWith<"result type matches element type of 'base'", + "base", "result", "::llvm::cast($_self).getElementType()">, MemRefsNormalizable, DeclareOpInterfaceMethods, @@ -1273,7 +1273,7 @@ def LoadOp : MemRef_Op<"load", }]; let arguments = (ins Arg:$memref, + [MemRead]>:$base, Variadic:$indices, DefaultValuedOptionalAttr:$nontemporal, OptionalAttr>:$alignment); @@ -1310,6 +1310,8 @@ def LoadOp : MemRef_Op<"load", let results = (outs AnyType:$result); let extraClassDeclaration = [{ + ::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); } + ::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); } Value getMemRef() { return getOperand(0); } void setMemRef(Value value) { setOperand(0, value); } MemRefType getMemRefType() { @@ -1320,7 +1322,7 @@ def LoadOp : MemRef_Op<"load", let hasFolder = 1; let hasVerifier = 1; - let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; + let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base)"; } //===----------------------------------------------------------------------===// @@ -2010,8 +2012,8 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [ //===----------------------------------------------------------------------===// def MemRef_StoreOp : MemRef_Op<"store", - [TypesMatchWith<"type of 'value' matches element type of 'memref'", - "memref", "value", + [TypesMatchWith<"type of 'valueToStore' matches element type of 'base'", + "base", "valueToStore", "::llvm::cast($_self).getElementType()">, MemRefsNormalizable, DeclareOpInterfaceMethods, @@ -2046,9 +2048,9 @@ def MemRef_StoreOp : MemRef_Op<"store", ``` }]; - let arguments = (ins AnyType:$value, + let arguments = (ins AnyType:$valueToStore, Arg:$memref, + [MemWrite]>:$base, Variadic:$indices, DefaultValuedOptionalAttr:$nontemporal, OptionalAttr>:$alignment); @@ -2070,7 +2072,10 @@ def MemRef_StoreOp : MemRef_Op<"store", ]; let extraClassDeclaration = [{ - Value getValueToStore() { return getOperand(0); } + ::mlir::TypedValue<::mlir::MemRefType> getMemref() { return getBase(); } + ::mlir::OpOperand &getMemrefMutable() { return getBaseMutable(); } + + Value getValue() { return getOperand(0); } Value getMemRef() { return getOperand(1); } void setMemRef(Value value) { setOperand(1, value); } @@ -2083,7 +2088,7 @@ def MemRef_StoreOp : MemRef_Op<"store", let hasVerifier = 1; let assemblyFormat = [{ - $value `,` $memref `[` $indices `]` attr-dict `:` type($memref) + $valueToStore `,` $base `[` $indices `]` attr-dict `:` type($base) }]; } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 0a382d812f362..caefe0bde3cff 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -342,7 +342,7 @@ struct ConvertLoad final : public OpConversionPattern { } auto arrayValue = - dyn_cast>(operands.getMemref()); + dyn_cast>(operands.getBase()); if (!arrayValue) { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } @@ -362,7 +362,7 @@ struct ConvertStore final : public OpConversionPattern { matchAndRewrite(memref::StoreOp op, OpAdaptor operands, ConversionPatternRewriter &rewriter) const override { auto arrayValue = - dyn_cast>(operands.getMemref()); + dyn_cast>(operands.getBase()); if (!arrayValue) { return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); } @@ -370,7 +370,7 @@ struct ConvertStore final : public OpConversionPattern { auto subscript = emitc::SubscriptOp::create( rewriter, op.getLoc(), arrayValue, operands.getIndices()); rewriter.replaceOpWithNewOp(op, subscript, - operands.getValue()); + operands.getValueToStore()); return success(); } }; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 91a0c4b55fa84..af1e25add0167 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -941,9 +941,9 @@ struct LoadOpLowering : public LoadStoreOpLowering { // Per memref.load spec, the indices must be in-bounds: // 0 <= idx < dim_size, and additionally all offsets are non-negative, // hence inbounds and nuw are used when lowering to llvm.getelementptr. - Value dataPtr = getStridedElementPtr(rewriter, loadOp.getLoc(), type, - adaptor.getMemref(), - adaptor.getIndices(), kNoWrapFlags); + Value dataPtr = + getStridedElementPtr(rewriter, loadOp.getLoc(), type, adaptor.getBase(), + adaptor.getIndices(), kNoWrapFlags); rewriter.replaceOpWithNewOp( loadOp, typeConverter->convertType(type.getElementType()), dataPtr, loadOp.getAlignment().value_or(0), false, loadOp.getNontemporal()); @@ -965,11 +965,11 @@ struct StoreOpLowering : public LoadStoreOpLowering { // 0 <= idx < dim_size, and additionally all offsets are non-negative, // hence inbounds and nuw are used when lowering to llvm.getelementptr. Value dataPtr = - getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getMemref(), + getStridedElementPtr(rewriter, op.getLoc(), type, adaptor.getBase(), adaptor.getIndices(), kNoWrapFlags); - rewriter.replaceOpWithNewOp(op, adaptor.getValue(), dataPtr, - op.getAlignment().value_or(0), - false, op.getNontemporal()); + rewriter.replaceOpWithNewOp( + op, adaptor.getValueToStore(), dataPtr, op.getAlignment().value_or(0), + false, op.getNontemporal()); return success(); } }; diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index a90dcc8cc3ef1..03e896ce5f9b6 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -555,7 +555,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, const auto &typeConverter = *getTypeConverter(); Value accessChain = - spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), + spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(), adaptor.getIndices(), loc, rewriter); if (!accessChain) @@ -682,7 +682,7 @@ LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, "failed to lower memref in image storage class to storage buffer"); Value loadPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, adaptor.getMemref(), + *getTypeConverter(), memrefType, adaptor.getBase(), adaptor.getIndices(), loadOp.getLoc(), rewriter); if (!loadPtr) @@ -743,7 +743,7 @@ ImageLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, return rewriter.notifyMatchFailure( loadOp, "failed to lower memref in non-image storage class to image"); - Value loadPtr = adaptor.getMemref(); + Value loadPtr = adaptor.getBase(); auto memoryRequirements = calculateMemoryRequirements(loadPtr, loadOp); if (failed(memoryRequirements)) return rewriter.notifyMatchFailure( @@ -824,7 +824,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, auto loc = storeOp.getLoc(); auto &typeConverter = *getTypeConverter(); Value accessChain = - spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), + spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(), adaptor.getIndices(), loc, rewriter); if (!accessChain) @@ -874,7 +874,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, storeOp, "failed to determine memory requirements"); auto [memoryAccess, alignment] = *memoryRequirements; - Value storeVal = adaptor.getValue(); + Value storeVal = adaptor.getValueToStore(); if (isBool) storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); rewriter.replaceOpWithNewOp(storeOp, accessChain, storeVal, @@ -915,7 +915,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, clearBitsMask = rewriter.createOrFold(loc, dstType, clearBitsMask); - Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter); + Value storeVal = + shiftValue(loc, adaptor.getValueToStore(), offset, mask, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); std::optional scope = getAtomicOpScope(memrefType); @@ -1020,7 +1021,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, if (memrefType.getElementType().isSignlessInteger()) return rewriter.notifyMatchFailure(storeOp, "signless int"); auto storePtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, adaptor.getMemref(), + *getTypeConverter(), memrefType, adaptor.getBase(), adaptor.getIndices(), storeOp.getLoc(), rewriter); if (!storePtr) @@ -1033,7 +1034,7 @@ StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, auto [memoryAccess, alignment] = *memoryRequirements; rewriter.replaceOpWithNewOp( - storeOp, storePtr, adaptor.getValue(), memoryAccess, alignment); + storeOp, storePtr, adaptor.getValueToStore(), memoryAccess, alignment); return success(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index 09d4ffa61738a..1030faa212f11 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -284,7 +284,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto convertedType = cast(adaptor.getMemref().getType()); + auto convertedType = cast(adaptor.getBase().getType()); auto convertedElementType = convertedType.getElementType(); auto oldElementType = op.getMemRefType().getElementType(); int srcBits = oldElementType.getIntOrFloatBitWidth(); @@ -298,7 +298,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { // Special case 0-rank memref loads. Value bitsLoad; if (convertedType.getRank() == 0) { - bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(), + bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getBase(), ValueRange{}); } else { // Linearize the indices of the original load instruction. Do not account @@ -307,7 +307,7 @@ struct ConvertMemRefLoad final : OpConversionPattern { rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef()); Value newLoad = memref::LoadOp::create( - rewriter, loc, adaptor.getMemref(), + rewriter, loc, adaptor.getBase(), getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits, dstBits)); @@ -414,7 +414,7 @@ struct ConvertMemrefStore final : OpConversionPattern { LogicalResult matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto convertedType = cast(adaptor.getMemref().getType()); + auto convertedType = cast(adaptor.getBase().getType()); int srcBits = op.getMemRefType().getElementTypeBitWidth(); int dstBits = convertedType.getElementTypeBitWidth(); auto dstIntegerType = rewriter.getIntegerType(dstBits); @@ -426,7 +426,7 @@ struct ConvertMemrefStore final : OpConversionPattern { Location loc = op.getLoc(); // Pad the input value with 0s on the left. - Value input = adaptor.getValue(); + Value input = adaptor.getValueToStore(); if (!input.getType().isInteger()) { input = arith::BitcastOp::create( rewriter, loc, @@ -440,7 +440,7 @@ struct ConvertMemrefStore final : OpConversionPattern { // Special case 0-rank memref stores. No need for masking. if (convertedType.getRank() == 0) { memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign, - extendedInput, adaptor.getMemref(), + extendedInput, adaptor.getBase(), ValueRange{}); rewriter.eraseOp(op); return success(); @@ -460,10 +460,10 @@ struct ConvertMemrefStore final : OpConversionPattern { // Clear destination bits memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi, - writeMask, adaptor.getMemref(), storeIndices); + writeMask, adaptor.getBase(), storeIndices); // Write srcs bits to destination memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori, - alignedVal, adaptor.getMemref(), storeIndices); + alignedVal, adaptor.getBase(), storeIndices); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp index 6f815ae46904c..79282ecd79d5e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateWideInt.cpp @@ -66,9 +66,9 @@ struct ConvertMemRefLoad final : OpConversionPattern { op->getLoc(), llvm::formatv("failed to convert memref type: {0}", op.getMemRefType())); - rewriter.replaceOpWithNewOp( - op, newResTy, adaptor.getMemref(), adaptor.getIndices(), - op.getNontemporal()); + rewriter.replaceOpWithNewOp(op, newResTy, adaptor.getBase(), + adaptor.getIndices(), + op.getNontemporal()); return success(); } }; @@ -90,7 +90,7 @@ struct ConvertMemRefStore final : OpConversionPattern { op.getMemRefType())); rewriter.replaceOpWithNewOp( - op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), + op, adaptor.getValueToStore(), adaptor.getBase(), adaptor.getIndices(), op.getNontemporal()); return success(); } diff --git a/mlir/test/python/dialects/openacc.py b/mlir/test/python/dialects/openacc.py index 8f2142a74c7a1..d3af869889e10 100644 --- a/mlir/test/python/dialects/openacc.py +++ b/mlir/test/python/dialects/openacc.py @@ -121,8 +121,8 @@ def testParallelMemcpy(): with InsertionPoint(loop_block): idx = arith.index_cast(out=IndexType.get(), in_=loop_block.arguments[0]) - val = memref.load(memref=copied, indices=[idx]) - memref.store(value=val, memref=created, indices=[idx]) + val = memref.load(base=copied, indices=[idx]) + memref.store(value_to_store=val, base=created, indices=[idx]) openacc.YieldOp([]) openacc.YieldOp([])