// Source blob: fb055286df46b3021250efdeee5c122a02476752
//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Lower/CUDA.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#define DEBUG_TYPE "flang-lower-cuda"
/// Build the fir.field_index coordinate(s) that address the component named
/// after \p sym inside the record type \p recTy, and return the type of that
/// component.
///
/// The component is first looked up directly in \p recTy. If it is not found
/// there, each component of \p recTy that is itself a fir::RecordType is
/// searched (one level deep only); on a hit, the coordinate of the enclosing
/// component is appended first, followed by the coordinate of the component
/// inside it.
///
/// \param builder      builder used to create the fir.field_index operations.
/// \param loc          location attached to the created operations.
/// \param sym          symbol whose name identifies the component.
/// \param recTy        record type in which the component is searched.
/// \param coordinates  output vector; the created field indices are appended.
/// \returns the type of the located component. When nothing was located and
///          \p coordinates is empty, the situation is reported through the
///          TODO macro.
mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
    llvm::SmallVector<mlir::Value> &coordinates) {
  // getFieldIndex() signals "no such field" with this sentinel.
  constexpr unsigned notFound = std::numeric_limits<unsigned>::max();
  // Materialize the component name once instead of once per lookup.
  const std::string symName = sym.name().ToString();

  mlir::Type fieldTy;
  unsigned fieldIdx = recTy.getFieldIndex(symName);
  if (fieldIdx != notFound) {
    // Field found in the base record type. Bind by reference so the
    // component name is not copied.
    const auto &field = recTy.getTypeList()[fieldIdx];
    fieldTy = field.second;
    mlir::Value fieldIndex = fir::FieldIndexOp::create(
        builder, loc, fir::FieldType::get(fieldTy.getContext()), field.first,
        recTy,
        /*typeParams=*/mlir::ValueRange{});
    coordinates.push_back(fieldIndex);
  } else {
    // Field not found in base record type, search in potential
    // record type components.
    for (const auto &component : recTy.getTypeList()) {
      auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second);
      if (!childRecTy)
        continue;
      fieldIdx = childRecTy.getFieldIndex(symName);
      if (fieldIdx == notFound)
        continue;
      // Coordinate of the record-typed component inside the base type.
      mlir::Value parentFieldIndex = fir::FieldIndexOp::create(
          builder, loc, fir::FieldType::get(childRecTy.getContext()),
          component.first, recTy,
          /*typeParams=*/mlir::ValueRange{});
      coordinates.push_back(parentFieldIndex);
      // Coordinate of the requested component inside that nested record.
      const auto &childField = childRecTy.getTypeList()[fieldIdx];
      fieldTy = childField.second;
      mlir::Value childFieldIndex = fir::FieldIndexOp::create(
          builder, loc, fir::FieldType::get(fieldTy.getContext()),
          childField.first, childRecTy,
          /*typeParams=*/mlir::ValueRange{});
      coordinates.push_back(childFieldIndex);
      // Only the first nested record containing the name is used.
      break;
    }
  }
  // NOTE(review): this check detects a failed lookup only if the caller
  // passed an empty `coordinates` vector -- confirm against call sites.
  if (coordinates.empty())
    TODO(loc, "device resident component in complex derived-type hierarchy");
  return fieldTy;
}
/// Translate the CUDA data attribute carried by the ultimate symbol of
/// \p sym into the corresponding cuf::DataAttributeAttr created in
/// \p mlirContext. The (possibly empty) optional returned by
/// GetCUDADataAttr is forwarded unchanged to cuf::getDataAttribute.
cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
    mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
  const Fortran::semantics::Symbol &ultimateSym = sym.GetUltimate();
  return cuf::getDataAttribute(
      mlirContext, Fortran::semantics::GetCUDADataAttr(&ultimateSym));
}
/// Return the hlfir.elemental operation that produces \p rhs when that
/// elemental looks like a plain element-wise conversion, i.e. its body
/// contains exactly one hlfir.designate, exactly one fir.load and exactly one
/// fir.convert. Two producer shapes are recognized:
///   - \p rhs is the result of the hlfir.elemental itself;
///   - \p rhs is an hlfir.declare whose memref comes from an hlfir.associate
///     whose source is the hlfir.elemental.
///
/// \returns the matching elemental operation, or a null operation when \p rhs
///          has no defining operation (e.g. a block argument) or does not
///          match either shape.
hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
  auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
    auto *body = elOp.getBody();
    return llvm::hasSingleElement(body->getOps<hlfir::DesignateOp>()) &&
           llvm::hasSingleElement(body->getOps<fir::LoadOp>()) &&
           llvm::hasSingleElement(body->getOps<fir::ConvertOp>());
  };

  // A value without a defining operation (block argument) cannot match;
  // bail out before any cast, since mlir::dyn_cast must not be given null.
  mlir::Operation *rhsDef = rhs.getDefiningOp();
  if (!rhsDef)
    return {};

  if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhsDef)) {
    mlir::Operation *memrefDef = declOp.getMemref().getDefiningOp();
    if (!memrefDef)
      return {};
    if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>(memrefDef))
      // The associate source may itself be a block argument.
      if (auto elOp = mlir::dyn_cast_or_null<hlfir::ElementalOp>(
              associateOp.getSource().getDefiningOp()))
        if (isConversionElementalOp(elOp))
          return elOp;
  }

  if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhsDef))
    if (isConversionElementalOp(elOp))
      return elOp;
  return {};
}
/// Return true when \p addr is the result of an hlfir.declare whose memref is
/// produced by a fir.address_of operation and whose CUDA data attribute is
/// not `pinned`. Every other case (no defining operation, a different
/// producer, or a pinned declaration) yields false.
bool Fortran::lower::hasDoubleDescriptor(mlir::Value addr) {
  auto declareOp =
      mlir::dyn_cast_or_null<hlfir::DeclareOp>(addr.getDefiningOp());
  if (!declareOp)
    return false;

  mlir::Operation *memrefProducer = declareOp.getMemref().getDefiningOp();
  if (!mlir::isa_and_nonnull<fir::AddrOfOp>(memrefProducer))
    return false;

  // Pinned declarations are explicitly excluded.
  const auto dataAttr = declareOp.getDataAttr();
  const bool isPinned = dataAttr && *dataAttr == cuf::DataAttribute::Pinned;
  return !isPinned;
}