110 std::vector<raster::GDALRasterWrapper *> rasters,
111 std::vector<std::vector<int>> bands,
112 std::vector<std::vector<int>> strataCounts,
113 std::string filename,
116 std::string tempFolder,
117 std::map<std::string, std::string> driverOptions)
121 int height = rasters[0]->getHeight();
122 int width = rasters[0]->getWidth();
123 double *geotransform = rasters[0]->getGeotransform();
124 std::string projection = (rasters[0]->getDataset()->GetProjectionRef());
125 if (projection ==
"") {
126 throw std::runtime_error(
"could not get projection from the first raster argument.");
129 OGRSpatialReference
srs;
130 srs.importFromWkt(projection.c_str());
132 for (
size_t i = 1; i < rasters.size(); i++) {
133 if (rasters[i]->getHeight() != height) {
134 std::string err =
"raster with index " + std::to_string(i) +
" has a different height from the raster at index 0.";
135 throw std::runtime_error(err);
138 if (rasters[i]->getWidth() != width) {
139 std::string err =
"raster with index " + std::to_string(i) +
" has a different width from the raster at index 0.";
140 throw std::runtime_error(err);
143 double *checkGeotransform = rasters[i]->getGeotransform();
144 for (
int j = 0; j < 6; j++) {
145 if (geotransform[i] != checkGeotransform[i]) {
146 std::string err =
"raster with index " + std::to_string(i) +
" has a different geotransform from the raster at index 0.";
147 throw std::runtime_error(err);
151 OGRSpatialReference checkSrs;
152 checkSrs.importFromWkt(rasters[i]->getDataset()->GetProjectionRef());
153 if (!
srs.IsSame(&checkSrs)) {
154 std::string err =
"raster with index " + std::to_string(i) +
" has a different projection from the raster at index 0.";
155 throw std::runtime_error(err);
159 std::vector<helper::RasterBandMetaData> stratBands;
160 std::vector<int> numStrataPerBand;
162 std::vector<helper::VRTBandDatasetInfo> VRTBandInfo(1);
164 bool isMEMDataset = !largeRaster && filename ==
"";
165 bool isVRTDataset = largeRaster && filename ==
"";
167 std::vector<std::mutex> stratDatasetMutexes(rasters.size());
168 std::mutex mapBandMutex;
171 GDALDataset *p_dataset =
nullptr;
172 if (isMEMDataset || isVRTDataset) {
173 std::string driver = isMEMDataset ?
"MEM" :
"VRT";
177 std::filesystem::path filepath = filename;
178 std::string extension = filepath.extension().string();
180 if (extension ==
".tif") {
184 throw std::runtime_error(
"sgs only supports .tif files right now");
189 std::vector<size_t> multipliers(1, 1);
190 for (
size_t i = 0; i < rasters.size(); i++) {
193 for (
size_t j = 0; j < bands[i].size(); j++) {
194 int band = bands[i][j];
195 int strataCount = strataCounts[i][j];
196 numStrataPerBand.push_back(strataCount);
201 stratBand.
p_band = p_band;
205 stratBand.
nan = p_band->GetNoDataValue();
206 stratBand.
p_mutex = &stratDatasetMutexes[i];
208 stratBands.push_back(stratBand);
211 multipliers.push_back(multipliers.back() * strataCount);
215 size_t bandCount = stratBands.size();
216 size_t maxStrata = multipliers.back();
217 multipliers.pop_back();
219 mapBand.
name =
"strat_map";
220 mapBand.
xBlockSize = stratBands[0].xBlockSize;
221 mapBand.
yBlockSize = stratBands[0].yBlockSize;
222 mapBand.
p_mutex = &mapBandMutex;
227 else if (isVRTDataset) {
238 bool useTiles = mapBand.
xBlockSize != width &&
242 VSIMalloc3(height, width, mapBand.
size) :
260 pybind11::gil_scoped_acquire acquire;
261 boost::asio::thread_pool pool(threadCount);
263 int xBlockSize = stratBands[0].xBlockSize;
264 int yBlockSize = stratBands[0].yBlockSize;
266 int xBlocks = (width + xBlockSize - 1) / xBlockSize;
267 int yBlocks = (height + yBlockSize - 1) / yBlockSize;
268 int chunkSize = yBlocks / threadCount;
270 for (
int yBlockStart = 0; yBlockStart < yBlocks; yBlockStart += chunkSize) {
271 int yBlockEnd = std::min(yBlockStart + chunkSize, yBlocks);
273 boost::asio::post(pool, [
285 std::vector<void *> stratBuffers(bandCount);
286 for (
size_t band = 0; band < bandCount; band++) {
287 stratBuffers[band] = VSIMalloc3(xBlockSize, yBlockSize, stratBands[band].size);
289 void *p_mapBuffer = VSIMalloc3(xBlockSize, yBlockSize, mapBand.
size);
292 std::vector<int> intNoDataValues(bandCount);
293 for (
size_t band = 0; band < bandCount; band++) {
294 intNoDataValues[band] =
static_cast<int>(stratBands[band].nan);
297 for (
int yBlock = yBlockStart; yBlock < yBlockEnd; yBlock++) {
298 for (
int xBlock = 0; xBlock < xBlocks; xBlock++) {
300 stratBands[0].p_mutex->lock();
301 stratBands[0].p_band->GetActualBlockSize(xBlock, yBlock, &xValid, &yValid);
302 stratBands[0].p_mutex->unlock();
304 for (
size_t band = 0; band < bandCount; band++) {
318 for (
int y = 0; y < yValid; y++) {
319 for (
int x = 0; x < xValid; x++) {
320 size_t index =
static_cast<size_t>(x + y * xBlockSize);
324 for (
size_t band = 0; band < bandCount; band++) {
326 isNan =
strat == intNoDataValues[band];
329 if (
strat >= numStrataPerBand[band]) {
330 std::string bandName = stratBands[band].p_band->GetDescription();
331 std::string errmsg =
"the num_strata indicated for band " + bandName +
" is less than or equal to one of the values in that band.";
332 throw std::runtime_error(errmsg);
336 std::string bandName = stratBands[band].p_band->GetDescription();
337 std::string errmsg =
"a negative strata value of " + std::to_string(
strat) +
" was found in band " + bandName +
", and is not marked as a nodata value.";
338 throw std::runtime_error(errmsg);
341 mappedStrat +=
strat * multipliers[band];
366 for (
size_t band = 0; band < bandCount; band++) {
367 VSIFree(stratBuffers[band]);
369 VSIFree(p_mapBuffer);
374 pybind11::gil_scoped_release release;
377 std::vector<int> intNoDataValues(bandCount);
378 for (
size_t band = 0; band < bandCount; band++) {
379 intNoDataValues[band] =
static_cast<int>(stratBands[band].nan);
382 size_t pixelCount =
static_cast<size_t>(height) *
static_cast<size_t>(width);
383 for (
size_t index = 0; index < pixelCount; index++) {
387 for (
size_t band = 0; (band < bandCount) && !isNan; band++) {
389 isNan =
strat == intNoDataValues[band];
391 mappedStrat +=
strat * multipliers[band];
398 if (!isVRTDataset && !isMEMDataset) {
399 CPLErr err = mapBand.
p_band->RasterIO(
413 throw std::runtime_error(
"error writing band to file.");
420 GDALClose(VRTBandInfo[0].p_dataset);