diff --git a/src/cudecomp.cc b/src/cudecomp.cc index 94a3b3e..b2283d2 100644 --- a/src/cudecomp.cc +++ b/src/cudecomp.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -166,6 +167,35 @@ static void checkGridDesc(cudecompGridDesc_t grid_desc) { if (!grid_desc || !grid_desc->initialized) { THROW_INVALID_USAGE("invalid grid descriptor"); } } +static cudecompResult_t handleException(const BaseException& e) { + std::cerr << e.what(); + return e.getResult(); +} + +static cudecompResult_t handleUnexpectedException(const std::exception& e) { + std::cerr << "CUDECOMP:ERROR: Internal error. (" << e.what() << ")\n"; + return CUDECOMP_RESULT_INTERNAL_ERROR; +} + +static cudecompResult_t handleUnexpectedException() { + std::cerr << "CUDECOMP:ERROR: Internal error. (unknown exception)\n"; + return CUDECOMP_RESULT_INTERNAL_ERROR; +} + +#define CUDECOMP_CATCH_C_API_ERRORS(...) \ + catch (const cudecomp::BaseException& e) { \ + __VA_ARGS__; \ + return handleException(e); \ + } \ + catch (const std::exception& e) { \ + __VA_ARGS__; \ + return handleUnexpectedException(e); \ + } \ + catch (...) { \ + __VA_ARGS__; \ + return handleUnexpectedException(); \ + } + static void checkConfig(cudecompHandle_t handle, const cudecompGridDescConfig_t* config, bool autotune_transpose, bool autotune_halos) { if (!autotune_transpose) { checkTransposeCommBackend(config->transpose_comm_backend); } @@ -572,12 +602,8 @@ cudecompResult_t cudecompInit(cudecompHandle_t* handle_in, MPI_Comm mpi_comm) { cudecomp_initialized = true; *handle_in = handle; - - } catch (const BaseException& e) { - if (handle) { delete handle; } - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS(if (handle) { delete handle; }) return CUDECOMP_RESULT_SUCCESS; }; @@ -621,17 +647,18 @@ cudecompResult_t cudecompFinalize(cudecompHandle_t handle) { delete handle; cudecomp_initialized = false; - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } cudecompResult_t cudecompInit_F(cudecompHandle_t* handle_in, MPI_Fint mpi_comm_f) { - MPI_Comm mpi_comm = MPI_Comm_f2c(mpi_comm_f); - return cudecompInit(handle_in, mpi_comm); + using namespace cudecomp; + try { + MPI_Comm mpi_comm = MPI_Comm_f2c(mpi_comm_f); + return cudecompInit(handle_in, mpi_comm); + } + CUDECOMP_CATCH_C_API_ERRORS() } cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDesc_t* grid_desc_in, @@ -642,6 +669,7 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes cudecompGridDesc_t grid_desc = nullptr; try { checkHandle(handle); + if (!grid_desc_in) { THROW_INVALID_USAGE("grid_desc argument cannot be null"); } if (!config) { THROW_INVALID_USAGE("config argument cannot be null"); } // Check some autotuning options @@ -873,18 +901,15 @@ cudecompResult_t cudecompGridDescCreate(cudecompHandle_t handle, cudecompGridDes } } } - - } catch (const cudecomp::BaseException& e) { - if (grid_desc) { delete grid_desc; } - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS(if (grid_desc) { delete grid_desc; }) return CUDECOMP_RESULT_SUCCESS; } cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDesc_t grid_desc) { using namespace cudecomp; try { + checkHandle(handle); checkGridDesc(grid_desc); if (grid_desc->row_comm_info.mpi_comm != MPI_COMM_NULL) { @@ -970,11 +995,8 @@ cudecompResult_t cudecompGridDescDestroy(cudecompHandle_t handle, cudecompGridDe delete grid_desc; grid_desc = nullptr; - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1004,11 +1026,8 @@ cudecompResult_t cudecompGridDescConfigSetDefaults(cudecompGridDescConfig_t* con // Halo Options config->halo_comm_backend = CUDECOMP_HALO_COMM_MPI; - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1048,11 +1067,8 @@ cudecompResult_t cudecompGridDescAutotuneOptionsSetDefaults(cudecompGridDescAuto options->halo_periods[i] = false; options->halo_padding[i] = 0; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1110,11 +1126,8 @@ cudecompResult_t cudecompGetPencilInfo(cudecompHandle_t handle, cudecompGridDesc } pencil_info->size *= pencil_info->shape[ord]; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1142,11 +1155,8 @@ cudecompResult_t cudecompGetGridDescConfig(cudecompHandle_t handle, cudecompGrid } } } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1168,11 +1178,8 @@ cudecompResult_t cudecompGetTransposeWorkspaceSize(cudecompHandle_t handle, cude int64_t wsize_zy = alignCountToBytes(max_pencil_size_z, CUDECOMP_WORKSPACE_ALIGN_BYTES) + max_pencil_size_y; *workspace_size = std::max({wsize_xy, wsize_yx, wsize_yz, wsize_zy}); - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1198,10 +1205,8 @@ cudecompResult_t cudecompGetHaloWorkspaceSize(cudecompHandle_t handle, cudecompG 4 * alignCountToBytes(shape_g[0] * shape_g[1] * pinfo.halo_extents[2], CUDECOMP_WORKSPACE_ALIGN_BYTES); *workspace_size = std::max({halo_size_x, halo_size_y, halo_size_z}); - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1303,11 +1308,8 @@ cudecompResult_t cudecompMalloc(cudecompHandle_t handle, cudecompGridDesc_t grid } #endif } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1359,10 +1361,8 @@ cudecompResult_t cudecompFree(cudecompHandle_t handle, cudecompGridDesc_t grid_d if (buffer) { CHECK_CUDA(cudaFree(buffer)); } } } - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1402,11 +1402,8 @@ cudecompResult_t cudecompGetDataTypeSize(cudecompDataType_t dtype, int64_t* dtyp case CUDECOMP_FLOAT_COMPLEX: *dtype_size = 8; break; case CUDECOMP_DOUBLE_COMPLEX: *dtype_size = 16; break; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1447,11 +1444,8 @@ cudecompResult_t cudecompGetShiftedRank(cudecompHandle_t handle, cudecompGridDes int global_peer = getGlobalRank(handle, grid_desc, comm_axis, comm_peer); *shifted_rank = global_peer; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1491,10 +1485,8 @@ cudecompResult_t cudecompTransposeXToY(cudecompHandle_t handle, cudecompGridDesc output_halo_extents, input_padding, output_padding, stream); break; } - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1534,10 +1526,8 @@ cudecompResult_t cudecompTransposeYToZ(cudecompHandle_t handle, cudecompGridDesc output_halo_extents, input_padding, output_padding, stream); break; } - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1577,10 +1567,8 @@ cudecompResult_t cudecompTransposeZToY(cudecompHandle_t handle, cudecompGridDesc output_halo_extents, input_padding, output_padding, stream); break; } - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1620,10 +1608,8 @@ cudecompResult_t cudecompTransposeYToX(cudecompHandle_t handle, cudecompGridDesc output_halo_extents, input_padding, output_padding, stream); break; } - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1665,11 +1651,8 @@ cudecompResult_t cudecompUpdateHalosX(cudecompHandle_t handle, cudecompGridDesc_ padding, stream); break; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1711,11 +1694,8 @@ cudecompResult_t cudecompUpdateHalosY(cudecompHandle_t handle, cudecompGridDesc_ padding, stream); break; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } @@ -1757,10 +1737,9 @@ cudecompResult_t cudecompUpdateHalosZ(cudecompHandle_t handle, cudecompGridDesc_ padding, stream); break; } - - } catch (const cudecomp::BaseException& e) { - std::cerr << e.what(); - return e.getResult(); } + CUDECOMP_CATCH_C_API_ERRORS() return CUDECOMP_RESULT_SUCCESS; } + +#undef CUDECOMP_CATCH_C_API_ERRORS